IVGCVSW-5665 Basic NN Driver support for next OS Version

 * Fully qualify the NN HAL types used here (V1_2::Timing, V1_2::OutputShape,
   V1_2::MeasureTiming, V1_0::RequestArgument) with their HAL version
   namespaces.
 * Guard Android S specific handling with ARMNN_ANDROID_S: include
   LegacyUtils.h, read output pool sizes via getMemory().size and map
   request pools with setRunTimePoolInfosFromCanonicalMemories().

Signed-off-by: Kevin May <kevin.may@arm.com>
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I1e1db52322092c6b1b7ac6183c3adc90aabcec24
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index c2148ba..a2148c2 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -16,12 +16,16 @@
 #include <cassert>
 #include <cinttypes>
 
+#ifdef ARMNN_ANDROID_S
+#include <LegacyUtils.h>
+#endif
+
 using namespace android;
 using namespace android::hardware;
 
 namespace {
 
-static const Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
+static const V1_2::Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
 using namespace armnn_driver;
 using TimePoint = std::chrono::steady_clock::time_point;
 
@@ -38,8 +42,8 @@
 
 void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback,
                             V1_0::ErrorStatus errorStatus,
-                            std::vector<OutputShape>,
-                            const Timing,
+                            std::vector<V1_2::OutputShape>,
+                            const V1_2::Timing,
                             std::string callingFunction)
 {
     Return<void> returned = callback->notify(errorStatus);
@@ -53,8 +57,8 @@
 
 void NotifyCallbackAndCheck(const ::android::sp<V1_2::IExecutionCallback>& callback,
                             V1_0::ErrorStatus errorStatus,
-                            std::vector<OutputShape> outputShapes,
-                            const Timing timing,
+                            std::vector<V1_2::OutputShape> outputShapes,
+                            const V1_2::Timing timing,
                             std::string callingFunction)
 {
     Return<void> returned = callback->notify_1_2(errorStatus, outputShapes, timing);
@@ -66,7 +70,7 @@
     }
 }
 
-bool ValidateRequestArgument(const RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
+bool ValidateRequestArgument(const V1_0::RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
 {
     if (requestArg.dimensions.size() != 0)
     {
@@ -91,7 +95,7 @@
     return true;
 }
 
-armnn::Tensor GetTensorForRequestArgument(const RequestArgument& requestArg,
+armnn::Tensor GetTensorForRequestArgument(const V1_0::RequestArgument& requestArg,
                                           const armnn::TensorInfo& tensorInfo,
                                           const std::vector<::android::nn::RunTimePoolInfo>& requestPools)
 {
@@ -178,20 +182,20 @@
     }
 
     auto cb = [callback](V1_0::ErrorStatus errorStatus,
-                         std::vector<OutputShape> outputShapes,
-                         const Timing& timing,
+                         std::vector<V1_2::OutputShape> outputShapes,
+                         const V1_2::Timing& timing,
                          std::string callingFunction)
     {
         NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
     };
 
-    return Execute(request, MeasureTiming::NO, cb);
+    return Execute(request, V1_2::MeasureTiming::NO, cb);
 }
 
 template<typename HalVersion>
 Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute_1_2(
         const V1_0::Request& request,
-        MeasureTiming measureTiming,
+        V1_2::MeasureTiming measureTiming,
         const sp<V1_2::IExecutionCallback>& callback)
 {
     if (callback.get() == nullptr)
@@ -201,8 +205,8 @@
     }
 
     auto cb = [callback](V1_0::ErrorStatus errorStatus,
-                         std::vector<OutputShape> outputShapes,
-                         const Timing& timing,
+                         std::vector<V1_2::OutputShape> outputShapes,
+                         const V1_2::Timing& timing,
                          std::string callingFunction)
     {
         NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
@@ -240,7 +244,7 @@
 template<typename HalVersion>
 Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForOutputs(
     armnn::OutputTensors& outputs,
-    std::vector<OutputShape> &outputShapes,
+    std::vector<V1_2::OutputShape> &outputShapes,
     const V1_0::Request& request,
     const std::vector<android::nn::RunTimePoolInfo>& memPools)
 {
@@ -265,13 +269,23 @@
             return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
         }
 
+#if !defined(ARMNN_ANDROID_S)
         const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
         if (bufferSize < outputSize)
         {
             ALOGW("ArmnnPreparedModel_1_2::Execute failed: bufferSize < outputSize");
             return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
         }
-
+#else
+        const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getMemory().size;
+        if (bufferSize < outputSize)
+        {
+            ALOGW("ArmnnPreparedModel_1_2::Execute failed: bufferSize (%s) < outputSize (%s)",
+                  std::to_string(bufferSize).c_str(), std::to_string(outputSize).c_str());
+            outputShapes[i].isSufficient = false;
+            return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
+        }
+#endif
         outputs.emplace_back(i, outputTensor);
         outputShapes[i] = ComputeShape(outputTensorInfo);
     }
@@ -287,12 +301,15 @@
                                          const V1_0::Request& request,
                                          CallbackAsync_1_2 callback)
 {
+#if !defined(ARMNN_ANDROID_S)
     if (!setRunTimePoolInfosFromHidlMemories(&memPools, request.pools))
+#else
+    if (!setRunTimePoolInfosFromCanonicalMemories(&memPools, uncheckedConvert(request.pools)))
+#endif
     {
         callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
         return V1_0::ErrorStatus::GENERAL_FAILURE;
     }
-
     // add the inputs and outputs with their data
     try
     {
@@ -302,7 +319,7 @@
             return V1_0::ErrorStatus::GENERAL_FAILURE;
         }
 
-        std::vector<OutputShape> outputShapes(request.outputs.size());
+        std::vector<V1_2::OutputShape> outputShapes(request.outputs.size());
 
         auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
         if (errorStatus != V1_0::ErrorStatus::NONE)
@@ -332,7 +349,7 @@
 
 template<typename HalVersion>
 Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request,
-                                                                      MeasureTiming measureTiming,
+                                                                      V1_2::MeasureTiming measureTiming,
                                                                       executeSynchronously_cb cb)
 {
     ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
@@ -346,7 +363,7 @@
 
     TimePoint driverStart;
 
-    if (measureTiming == MeasureTiming::YES)
+    if (measureTiming == V1_2::MeasureTiming::YES)
     {
         driverStart = Now();
     }
@@ -359,8 +376,8 @@
     }
 
     auto cbWrapper = [cb](V1_0::ErrorStatus errorStatus,
-                          std::vector<OutputShape> outputShapes,
-                          const Timing& timing,
+                          std::vector<V1_2::OutputShape> outputShapes,
+                          const V1_2::Timing& timing,
                           std::string)
         {
             cb(errorStatus, outputShapes, timing);
@@ -405,7 +422,7 @@
 
     DumpTensorsIfRequired("Input", inputTensors);
 
-    std::vector<OutputShape> outputShapes(outputTensors.size());
+    std::vector<V1_2::OutputShape> outputShapes(outputTensors.size());
     for (unsigned int i = 0; i < outputTensors.size(); i++)
     {
         std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
@@ -418,14 +435,14 @@
     // run it
     try
     {
-        if (cb.ctx.measureTimings == MeasureTiming::YES)
+        if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
         {
             deviceStart = Now();
         }
 
         armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
 
-        if (cb.ctx.measureTimings == MeasureTiming::YES)
+        if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
         {
             deviceEnd = Now();
         }
@@ -454,10 +471,10 @@
 
     DumpTensorsIfRequired("Output", outputTensors);
 
-    if (cb.ctx.measureTimings == MeasureTiming::YES)
+    if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
     {
         driverEnd = Now();
-        Timing timing;
+        V1_2::Timing timing;
         timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
         timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.ctx.driverStart);
         ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
@@ -494,10 +511,10 @@
         outputTensors.emplace_back(i, outputTensor);
     }
 
-    auto nullCallback = [](V1_0::ErrorStatus, std::vector<OutputShape>, const Timing&, std::string) {};
+    auto nullCallback = [](V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, const V1_2::Timing&, std::string) {};
     CallbackContext_1_2 callbackContext;
     callbackContext.callback = nullCallback;
-    callbackContext.ctx.measureTimings = MeasureTiming::NO;
+    callbackContext.ctx.measureTimings = V1_2::MeasureTiming::NO;
     auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
     return ExecuteGraph(memPools,
                         inputTensors,
@@ -507,11 +524,11 @@
 
 template<typename HalVersion>
 Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_0::Request& request,
-                                                                       MeasureTiming measureTiming,
+                                                                       V1_2::MeasureTiming measureTiming,
                                                                        CallbackAsync_1_2 callback)
 {
     ExecutionContext_1_2 ctx;
-    if (measureTiming == MeasureTiming::YES)
+    if (measureTiming == V1_2::MeasureTiming::YES)
     {
         ctx.measureTimings = measureTiming;
         ctx.driverStart = Now();