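IVGCVSW-5665 Basic NN Driver support for next OS version

Fully qualify HAL types (Timing, OutputShape, MeasureTiming, ErrorStatus,
OptionalTimePoint, OptionalTimeoutDuration, RequestArgument) with their
V1_0/V1_2/V1_3 namespaces, and add ARMNN_ANDROID_S code paths (LegacyUtils.h,
uncheckedConvert, RunTimePoolInfo::getMemory()) for the updated memory-pool API.
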
Signed-off-by: Kevin May <kevin.may@arm.com>
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I1e1db52322092c6b1b7ac6183c3adc90aabcec24
diff --git a/ArmnnPreparedModel_1_3.cpp b/ArmnnPreparedModel_1_3.cpp
index aed4fa1..2970e8f 100644
--- a/ArmnnPreparedModel_1_3.cpp
+++ b/ArmnnPreparedModel_1_3.cpp
@@ -22,12 +22,16 @@
 #include <cassert>
 #include <cinttypes>
 
+#ifdef ARMNN_ANDROID_S
+#include <LegacyUtils.h>
+#endif
+
 using namespace android;
 using namespace android::hardware;
 
 namespace {
 
-static const Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
+static const V1_2::Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
 using namespace armnn_driver;
 using TimePoint = std::chrono::steady_clock::time_point;
 
@@ -44,8 +48,8 @@
 
 void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback,
                             V1_3::ErrorStatus errorStatus,
-                            std::vector<OutputShape>,
-                            const Timing,
+                            std::vector<V1_2::OutputShape>,
+                            const V1_2::Timing,
                             std::string callingFunction)
 {
     Return<void> returned = callback->notify(convertToV1_0(errorStatus));
@@ -59,8 +63,8 @@
 
 void NotifyCallbackAndCheck(const ::android::sp<V1_2::IExecutionCallback>& callback,
                             V1_3::ErrorStatus errorStatus,
-                            std::vector<OutputShape> outputShapes,
-                            const Timing timing,
+                            std::vector<V1_2::OutputShape> outputShapes,
+                            const V1_2::Timing timing,
                             std::string callingFunction)
 {
     Return<void> returned = callback->notify_1_2(convertToV1_0(errorStatus), outputShapes, timing);
@@ -74,8 +78,8 @@
 
 void NotifyCallbackAndCheck(const ::android::sp<V1_3::IExecutionCallback>& callback,
                             V1_3::ErrorStatus errorStatus,
-                            std::vector<OutputShape> outputShapes,
-                            const Timing timing,
+                            std::vector<V1_2::OutputShape> outputShapes,
+                            const V1_2::Timing timing,
                             std::string callingFunction)
 {
     Return<void> returned = callback->notify_1_3(errorStatus, outputShapes, timing);
@@ -87,7 +91,7 @@
     }
 }
 
-bool ValidateRequestArgument(const RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
+bool ValidateRequestArgument(const V1_0::RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
 {
     if (requestArg.dimensions.size() != 0)
     {
@@ -112,7 +116,7 @@
     return true;
 }
 
-armnn::Tensor GetTensorForRequestArgument(const RequestArgument& requestArg,
+armnn::Tensor GetTensorForRequestArgument(const V1_0::RequestArgument& requestArg,
                                           const armnn::TensorInfo& tensorInfo,
                                           const std::vector<::android::nn::RunTimePoolInfo>& requestPools)
 {
@@ -201,21 +205,21 @@
     }
 
     auto cb = [callback](V1_3::ErrorStatus errorStatus,
-                         std::vector<OutputShape> outputShapes,
-                         const Timing& timing,
+                         std::vector<V1_2::OutputShape> outputShapes,
+                         const V1_2::Timing& timing,
                          std::string callingFunction)
     {
         NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
     };
 
 
-    return convertToV1_0(Execute(convertToV1_3(request), MeasureTiming::NO, cb));
+    return convertToV1_0(Execute(convertToV1_3(request), V1_2::MeasureTiming::NO, cb));
 }
 
 template<typename HalVersion>
 Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::execute_1_2(
     const V1_0::Request& request,
-    MeasureTiming measureTiming,
+    V1_2::MeasureTiming measureTiming,
     const sp<V1_2::IExecutionCallback>& callback)
 {
     if (callback.get() == nullptr)
@@ -225,8 +229,8 @@
     }
 
     auto cb = [callback](V1_3::ErrorStatus errorStatus,
-                         std::vector<OutputShape> outputShapes,
-                         const Timing& timing,
+                         std::vector<V1_2::OutputShape> outputShapes,
+                         const V1_2::Timing& timing,
                          std::string callingFunction)
     {
         NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
@@ -238,7 +242,7 @@
 template<typename HalVersion>
 Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::execute_1_3(
         const V1_3::Request& request,
-        MeasureTiming measureTiming,
+        V1_2::MeasureTiming measureTiming,
         const V1_3::OptionalTimePoint&,
         const V1_3::OptionalTimeoutDuration&,
         const sp<V1_3::IExecutionCallback>& callback)
@@ -250,8 +254,8 @@
     }
 
     auto cb = [callback](V1_3::ErrorStatus errorStatus,
-                         std::vector<OutputShape> outputShapes,
-                         const Timing& timing,
+                         std::vector<V1_2::OutputShape> outputShapes,
+                         const V1_2::Timing& timing,
                          std::string callingFunction)
     {
         NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
@@ -266,7 +270,7 @@
 class ArmnnFencedExecutionCallback : public V1_3::IFencedExecutionCallback
 {
 public:
-    ArmnnFencedExecutionCallback(V1_3::ErrorStatus errorStatus, Timing timing, Timing fenceTiming)
+    ArmnnFencedExecutionCallback(V1_3::ErrorStatus errorStatus, V1_2::Timing timing, V1_2::Timing fenceTiming)
         : m_ErrorStatus(errorStatus), m_Timing(timing), m_FenceTiming(fenceTiming) {}
     ~ArmnnFencedExecutionCallback() {}
 
@@ -277,33 +281,33 @@
     }
 private:
     V1_3::ErrorStatus m_ErrorStatus;
-    Timing m_Timing;
-    Timing m_FenceTiming;
+    V1_2::Timing m_Timing;
+    V1_2::Timing m_FenceTiming;
 };
 
 template<typename HalVersion>
 Return<void> ArmnnPreparedModel_1_3<HalVersion>::executeFenced(const V1_3::Request& request,
                                                                const hidl_vec<hidl_handle>& fenceWaitFor,
-                                                               MeasureTiming measureTiming,
-                                                               const OptionalTimePoint& deadline,
-                                                               const OptionalTimeoutDuration& loopTimeoutDuration,
-                                                               const OptionalTimeoutDuration&,
+                                                               V1_2::MeasureTiming measureTiming,
+                                                               const V1_3::OptionalTimePoint& deadline,
+                                                               const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
+                                                               const V1_3::OptionalTimeoutDuration&,
                                                                executeFenced_cb cb)
 {
     ALOGV("ArmnnPreparedModel_1_3::executeFenced(...)");
     if (cb == nullptr)
     {
         ALOGE("ArmnnPreparedModel_1_3::executeFenced invalid callback passed");
-        cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
+        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
         return Void();
     }
 
-    if (deadline.getDiscriminator() != OptionalTimePoint::hidl_discriminator::none)
+    if (deadline.getDiscriminator() != V1_3::OptionalTimePoint::hidl_discriminator::none)
     {
         ALOGW("ArmnnPreparedModel_1_3::executeFenced parameter deadline is set but not supported.");
     }
 
-    if (loopTimeoutDuration.getDiscriminator() != OptionalTimeoutDuration::hidl_discriminator::none)
+    if (loopTimeoutDuration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none)
     {
         ALOGW("ArmnnPreparedModel_1_3::executeFenced parameter loopTimeoutDuration is set but not supported.");
     }
@@ -311,12 +315,12 @@
     if (!android::nn::validateRequest(request, m_Model, /*allowUnspecifiedOutput=*/false))
     {
         ALOGV("ArmnnPreparedModel_1_3::executeFenced outputs must be specified for fenced execution ");
-        cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
+        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
         return Void();
     }
 
     ExecutionContext_1_3 ctx;
-    if (measureTiming == MeasureTiming::YES)
+    if (measureTiming == V1_2::MeasureTiming::YES)
     {
         ctx.measureTimings = measureTiming;
         ctx.driverStart = Now();
@@ -339,20 +343,20 @@
         auto fenceNativeHandle = fenceWaitFor[index].getNativeHandle();
         if (!fenceNativeHandle)
         {
-            cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
+            cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
             return Void();
         }
 
         if (sync_wait(fenceNativeHandle->data[0], -1) < 0)
         {
             ALOGE("ArmnnPreparedModel_1_3::executeFenced sync fence failed.");
-            cb(ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
+            cb(V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
             return Void();
         }
     }
 
     TimePoint fenceExecutionStart;
-    if (measureTiming == MeasureTiming::YES)
+    if (measureTiming == V1_2::MeasureTiming::YES)
     {
         fenceExecutionStart = Now();
     }
@@ -368,14 +372,14 @@
     auto [status, outShapes, timings, message] = PrepareMemoryForIO(*inputs, *outputs, *memPools, request);
     if (status != V1_3::ErrorStatus::NONE)
     {
-        cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
+        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
         return Void();
     }
 
     ALOGV("ArmnnPreparedModel_1_3::executeFenced(...) before ExecuteGraph");
 
     // call it with nullCallback for now as we will report the error status from here..
-    auto nullCallback = [](V1_3::ErrorStatus, std::vector<OutputShape>, const Timing&, std::string) {};
+    auto nullCallback = [](V1_3::ErrorStatus, std::vector<V1_2::OutputShape>, const V1_2::Timing&, std::string) {};
     CallbackContext_1_3 cbCtx;
     cbCtx.callback = nullCallback;
     cbCtx.ctx = ctx;
@@ -388,9 +392,9 @@
     }
     ALOGV("ArmnnPreparedModel_1_3::executeFenced(...) after ExecuteGraph");
 
-    Timing timing = g_NoTiming;
-    Timing fenceTiming = g_NoTiming;
-    if (measureTiming == MeasureTiming::YES)
+    V1_2::Timing timing = g_NoTiming;
+    V1_2::Timing fenceTiming = g_NoTiming;
+    if (measureTiming == V1_2::MeasureTiming::YES)
     {
         fenceTiming.timeOnDevice = MicrosecondsDuration(ctx.deviceEnd, ctx.deviceStart);
         fenceTiming.timeInDriver = MicrosecondsDuration(ctx.driverEnd, fenceExecutionStart);
@@ -399,8 +403,8 @@
     }
 
     sp<ArmnnFencedExecutionCallback> armnnFencedExecutionCallback =
-        new ArmnnFencedExecutionCallback(ErrorStatus::NONE, timing, fenceTiming);
-    cb(ErrorStatus::NONE, hidl_handle(nullptr), armnnFencedExecutionCallback);
+        new ArmnnFencedExecutionCallback(V1_3::ErrorStatus::NONE, timing, fenceTiming);
+    cb(V1_3::ErrorStatus::NONE, hidl_handle(nullptr), armnnFencedExecutionCallback);
     return Void();
 }
 
@@ -433,7 +437,7 @@
 template<typename HalVersion>
 Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::PrepareMemoryForOutputs(
     armnn::OutputTensors& outputs,
-    std::vector<OutputShape> &outputShapes,
+    std::vector<V1_2::OutputShape> &outputShapes,
     const V1_3::Request& request,
     const std::vector<android::nn::RunTimePoolInfo>& memPools)
 {
@@ -478,7 +482,13 @@
             return V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
         }
 
-        const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
+        // RunTimePoolInfo reports the pool size through a different accessor when built with
+        // ARMNN_ANDROID_S, so the accessor is selected at compile time.
+#if !defined(ARMNN_ANDROID_S)
+        const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
+#else
+        const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getMemory().size;
+#endif
         if (bufferSize < outputSize)
         {
             ALOGW("ArmnnPreparedModel_1_3::Execute failed bufferSize (%s) < outputSize (%s)",
@@ -492,15 +502,19 @@
 }
 
 template<typename HalVersion>
-std::tuple<V1_3::ErrorStatus, hidl_vec<OutputShape>, Timing, std::string>
+std::tuple<V1_3::ErrorStatus, hidl_vec<V1_2::OutputShape>, V1_2::Timing, std::string>
     ArmnnPreparedModel_1_3<HalVersion>::PrepareMemoryForIO(armnn::InputTensors& inputs,
                                                            armnn::OutputTensors& outputs,
                                                            std::vector<android::nn::RunTimePoolInfo>& memPools,
                                                            const V1_3::Request& request)
 {
+#if !defined(ARMNN_ANDROID_S)
     if (!setRunTimePoolInfosFromMemoryPools(&memPools, request.pools))
+#else
+    if (!setRunTimePoolInfosFromMemoryPools(&memPools, uncheckedConvert(request.pools)))
+#endif
     {
-        return {ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
+        return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
     }
 
     // add the inputs and outputs with their data
@@ -508,10 +522,10 @@
     {
         if (PrepareMemoryForInputs(inputs, request, memPools) != V1_3::ErrorStatus::NONE)
         {
-            return {ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
+            return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
         }
 
-        std::vector<OutputShape> outputShapes(request.outputs.size());
+        std::vector<V1_2::OutputShape> outputShapes(request.outputs.size());
 
         auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
         if (errorStatus != V1_3::ErrorStatus::NONE)
@@ -522,12 +536,12 @@
     catch (armnn::Exception& e)
     {
         ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
-        return {ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
+        return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
     }
     catch (std::exception& e)
     {
         ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
-        return {ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
+        return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
     }
 
     return {V1_3::ErrorStatus::NONE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
@@ -538,7 +552,7 @@
 Return<void> ArmnnPreparedModel_1_3<HalVersion>::ExecuteSynchronously(const V1_3::Request& request,
                                                                       CallbackContext cbCtx)
 {
-    if (cbCtx.ctx.measureTimings == MeasureTiming::YES)
+    if (cbCtx.ctx.measureTimings == V1_2::MeasureTiming::YES)
     {
         cbCtx.ctx.driverStart = Now();
     }
@@ -587,7 +601,7 @@
 
 template<typename HalVersion>
 Return<void> ArmnnPreparedModel_1_3<HalVersion>::executeSynchronously(const V1_0::Request& request,
-                                                                      MeasureTiming measureTiming,
+                                                                      V1_2::MeasureTiming measureTiming,
                                                                       executeSynchronously_cb cb)
 {
     ALOGV("ArmnnPreparedModel_1_3::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
@@ -600,8 +614,8 @@
     }
 
     auto cbWrapper = [cb](V1_3::ErrorStatus errorStatus,
-                          std::vector<OutputShape> outputShapes,
-                          const Timing& timing,
+                          std::vector<V1_2::OutputShape> outputShapes,
+                          const V1_2::Timing& timing,
                           std::string)
     {
         cb(convertToV1_0(errorStatus), outputShapes, timing);
@@ -618,7 +632,7 @@
 template<typename HalVersion>
 Return<void>  ArmnnPreparedModel_1_3<HalVersion>::executeSynchronously_1_3(
         const V1_3::Request& request,
-        MeasureTiming measureTiming,
+        V1_2::MeasureTiming measureTiming,
         const V1_3::OptionalTimePoint& deadline,
         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
         executeSynchronously_1_3_cb cb)
@@ -632,20 +646,20 @@
         return Void();
     }
 
-    if (deadline.getDiscriminator() != OptionalTimePoint::hidl_discriminator::none)
+    if (deadline.getDiscriminator() != V1_3::OptionalTimePoint::hidl_discriminator::none)
     {
         ALOGW("ArmnnPreparedModel_1_3::executeSynchronously_1_3 parameter deadline is set but not supported.");
     }
 
-    if (loopTimeoutDuration.getDiscriminator() != OptionalTimeoutDuration::hidl_discriminator::none)
+    if (loopTimeoutDuration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none)
     {
         ALOGW(
            "ArmnnPreparedModel_1_3::executeSynchronously_1_3 parameter loopTimeoutDuration is set but not supported.");
     }
 
     auto cbWrapper = [cb](V1_3::ErrorStatus errorStatus,
-                          std::vector<OutputShape> outputShapes,
-                          const Timing& timing,
+                          std::vector<V1_2::OutputShape> outputShapes,
+                          const V1_2::Timing& timing,
                           std::string)
     {
         cb(errorStatus, outputShapes, timing);
@@ -695,7 +709,7 @@
 
     DumpTensorsIfRequired("Input", inputTensors);
 
-    std::vector<OutputShape> outputShapes(outputTensors.size());
+    std::vector<V1_2::OutputShape> outputShapes(outputTensors.size());
     for (unsigned int i = 0; i < outputTensors.size(); i++)
     {
         std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
@@ -708,14 +722,14 @@
     // run it
     try
     {
-        if (cb.ctx.measureTimings == MeasureTiming::YES)
+        if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
         {
             cb.ctx.deviceStart = Now();
         }
 
         armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
 
-        if (cb.ctx.measureTimings == MeasureTiming::YES)
+        if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
         {
             cb.ctx.deviceEnd = Now();
         }
@@ -743,10 +757,10 @@
 
     DumpTensorsIfRequired("Output", outputTensors);
 
-    if (cb.ctx.measureTimings == MeasureTiming::YES)
+    if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
     {
         cb.ctx.driverEnd = Now();
-        Timing timing;
+        V1_2::Timing timing;
         timing.timeOnDevice = MicrosecondsDuration(cb.ctx.deviceEnd, cb.ctx.deviceStart);
         timing.timeInDriver = MicrosecondsDuration(cb.ctx.driverEnd, cb.ctx.driverStart);
         ALOGV("ArmnnPreparedModel_1_3::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
@@ -783,10 +797,10 @@
         outputTensors.emplace_back(i, outputTensor);
     }
 
-    auto nullCallback = [](V1_3::ErrorStatus, std::vector<OutputShape>, const Timing&, std::string) {};
+    auto nullCallback = [](V1_3::ErrorStatus, std::vector<V1_2::OutputShape>, const V1_2::Timing&, std::string) {};
     CallbackContext_1_3 callbackContext;
     callbackContext.callback = nullCallback;
-    callbackContext.ctx.measureTimings = MeasureTiming::NO;
+    callbackContext.ctx.measureTimings = V1_2::MeasureTiming::NO;
     auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
 
     auto errorStatus = ExecuteGraph(memPools,
@@ -798,11 +812,11 @@
 
 template<typename HalVersion>
 Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::Execute(const V1_3::Request& request,
-                                                                       MeasureTiming measureTiming,
+                                                                       V1_2::MeasureTiming measureTiming,
                                                                        CallbackAsync_1_3 callback)
 {
     ExecutionContext_1_3 ctx;
-    if (measureTiming == MeasureTiming::YES)
+    if (measureTiming == V1_2::MeasureTiming::YES)
     {
         ctx.measureTimings = measureTiming;
         ctx.driverStart = Now();