IVGCVSW-7063 'Support Library NNAPI Caching'

* Fixed caching issue: removed the dependency on CacheDataHandler so the
  support library no longer tracks cache entries in process memory. The
  model-data hash is now written into the data cache file itself (ahead of
  the serialized model data) and recomputed and compared when preparing a
  model from cache.

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Ic7b3e0bd4438b2fd1b3dbfa86b6c89d625bbf9dd
diff --git a/shim/sl/canonical/ArmnnDriverImpl.cpp b/shim/sl/canonical/ArmnnDriverImpl.cpp
index 8706c38..0c98a16 100644
--- a/shim/sl/canonical/ArmnnDriverImpl.cpp
+++ b/shim/sl/canonical/ArmnnDriverImpl.cpp
@@ -5,7 +5,6 @@
 
 #include "ArmnnDriverImpl.hpp"
 #include "ArmnnPreparedModel.hpp"
-#include "CacheDataHandler.hpp"
 #include "ModelToINetworkTransformer.hpp"
 #include "SystemPropertiesUtils.hpp"
 
@@ -62,6 +61,16 @@
              /* whilePerformance */ defaultPerfInfo };
 }
 
+size_t Hash(std::vector<uint8_t>& cacheData)
+{
+    std::size_t hash = cacheData.size();
+    for (auto& i : cacheData)
+    {
+        hash = ((hash << 5) - hash) + i;
+    }
+    return hash;
+}
+
 } // anonymous namespace
 
 using namespace android::nn;
@@ -87,33 +96,6 @@
     return valid;
 }
 
-bool ArmnnDriverImpl::ValidateDataCacheHandle(const std::vector<SharedHandle>& dataCacheHandle, const size_t dataSize)
-{
-    bool valid = true;
-    // DataCacheHandle size should always be 1 for ArmNN model
-    if (dataCacheHandle.size() != 1)
-    {
-        return !valid;
-    }
-
-    if (dataSize == 0)
-    {
-        return !valid;
-    }
-
-    struct stat statBuffer;
-    if (fstat(*dataCacheHandle[0], &statBuffer) == 0)
-    {
-        unsigned long bufferSize = statBuffer.st_size;
-        if (bufferSize != dataSize)
-        {
-            return !valid;
-        }
-    }
-
-    return ValidateSharedHandle(dataCacheHandle[0]);
-}
-
 GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel(
     const armnn::IRuntimePtr& runtime,
     const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
@@ -274,8 +256,7 @@
     size_t hashValue = 0;
     if (dataCacheHandle.size() == 1 )
     {
-        write(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size());
-        hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
+        hashValue = Hash(dataCacheData);
     }
 
     // Cache the model data
@@ -296,16 +277,20 @@
                         {
                             std::vector<uint8_t> modelData(modelDataSize);
                             pread(*modelCacheHandle[i], modelData.data(), modelData.size(), 0);
-                            hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+                            hashValue ^= Hash(modelData);
                         }
                     }
                 }
             }
         }
     }
-    if (hashValue != 0)
+    if (dataCacheHandle.size() == 1 && hashValue != 0)
     {
-        CacheDataHandlerInstance().Register(token, hashValue, dataCacheData.size());
+        std::vector<uint8_t> theHashValue(sizeof(hashValue));
+        ::memcpy(theHashValue.data(), &hashValue, sizeof(hashValue));
+
+        write(*dataCacheHandle[0], theHashValue.data(), theHashValue.size());
+        pwrite(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), theHashValue.size());
     }
 
     bool executeWithDummyInputs = (std::find(options.GetBackends().begin(),
@@ -374,13 +359,30 @@
     }
 
     // Validate dataCacheHandle
-    auto dataSize = CacheDataHandlerInstance().GetCacheSize(token);
-    if (!ValidateDataCacheHandle(dataCacheHandle, dataSize))
+    if (dataCacheHandle.size() != 1)
     {
         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
                             << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
     }
 
+    if (!ValidateSharedHandle(dataCacheHandle[0]))
+    {
+        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
+                << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
+    }
+
+    size_t cachedDataSize = 0;
+    struct stat dataStatBuffer;
+    if (fstat(*dataCacheHandle[0], &dataStatBuffer) == 0)
+    {
+        cachedDataSize = dataStatBuffer.st_size;
+    }
+    if (cachedDataSize == 0)
+    {
+        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
+                << "ArmnnDriverImpl::prepareModelFromCache(): Not valid cached data!";
+    }
+
     // Check if model files cached they match the expected value
     unsigned int numberOfCachedModelFiles = 0;
     for (auto& backend : options.GetBackends())
@@ -393,10 +395,14 @@
                            << "ArmnnDriverImpl::prepareModelFromCache(): Model cache handle size does not match.";
     }
 
+    // Read the hashValue
+    std::vector<uint8_t> hashValue(sizeof(size_t));
+    pread(*dataCacheHandle[0], hashValue.data(), hashValue.size(), 0);
+
     // Read the model
-    std::vector<uint8_t> dataCacheData(dataSize);
-    pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), 0);
-    auto hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
+    std::vector<uint8_t> dataCacheData(cachedDataSize - hashValue.size());
+    pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), hashValue.size());
+    auto calculatedHashValue = Hash(dataCacheData);
 
     int gpuAccCachedFd = -1;
     if (modelCacheHandle.size() > 0)
@@ -423,7 +429,7 @@
                     {
                         std::vector<uint8_t> modelData(modelDataSize);
                         pread(cachedFd, modelData.data(), modelData.size(), 0);
-                        hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+                        calculatedHashValue ^= Hash(modelData);
 
                         if (backend == armnn::Compute::GpuAcc)
                         {
@@ -436,7 +442,9 @@
         }
     }
 
-    if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
+    std::vector<uint8_t> calculatedHashData(sizeof(calculatedHashValue));
+    ::memcpy(calculatedHashData.data(), &calculatedHashValue, sizeof(calculatedHashValue));
+    if (hashValue != calculatedHashData)
     {
         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
                 << "ArmnnDriverImpl::prepareModelFromCache(): ValidateHash() failed!";
@@ -529,12 +537,13 @@
         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
     }
 
-    return std::make_shared<const ArmnnPreparedModel>(netId,
+    auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId,
                                                       runtime.get(),
                                                       options.GetRequestInputsAndOutputsDumpDir(),
                                                       options.IsGpuProfilingEnabled(),
                                                       Priority::MEDIUM,
                                                       true);
+    return std::move(preparedModel);
 }
 
 const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& runtime)