IVGCVSW-7063 'Support Library NNAPI Caching'

* Fixed caching issue.

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Ic7b3e0bd4438b2fd1b3dbfa86b6c89d625bbf9dd
diff --git a/shim/sl/CMakeLists.txt b/shim/sl/CMakeLists.txt
index 81c97f9..0ba6390 100644
--- a/shim/sl/CMakeLists.txt
+++ b/shim/sl/CMakeLists.txt
@@ -474,8 +474,6 @@
         canonical/ArmnnDriver.hpp
         canonical/ArmnnDriverImpl.cpp
         canonical/ArmnnDriverImpl.hpp
-        canonical/CacheDataHandler.cpp
-        canonical/CacheDataHandler.hpp
         canonical/CanonicalUtils.cpp
         canonical/CanonicalUtils.hpp
         canonical/ConversionUtils.cpp
diff --git a/shim/sl/canonical/ArmnnDriverImpl.cpp b/shim/sl/canonical/ArmnnDriverImpl.cpp
index 8706c38..0c98a16 100644
--- a/shim/sl/canonical/ArmnnDriverImpl.cpp
+++ b/shim/sl/canonical/ArmnnDriverImpl.cpp
@@ -5,7 +5,6 @@
 
 #include "ArmnnDriverImpl.hpp"
 #include "ArmnnPreparedModel.hpp"
-#include "CacheDataHandler.hpp"
 #include "ModelToINetworkTransformer.hpp"
 #include "SystemPropertiesUtils.hpp"
 
@@ -62,6 +61,16 @@
              /* whilePerformance */ defaultPerfInfo };
 }
 
+size_t Hash(std::vector<uint8_t>& cacheData)
+{
+    std::size_t hash = cacheData.size();
+    for (auto& i : cacheData)
+    {
+        hash = ((hash << 5) - hash) + i;
+    }
+    return hash;
+}
+
 } // anonymous namespace
 
 using namespace android::nn;
@@ -87,33 +96,6 @@
     return valid;
 }
 
-bool ArmnnDriverImpl::ValidateDataCacheHandle(const std::vector<SharedHandle>& dataCacheHandle, const size_t dataSize)
-{
-    bool valid = true;
-    // DataCacheHandle size should always be 1 for ArmNN model
-    if (dataCacheHandle.size() != 1)
-    {
-        return !valid;
-    }
-
-    if (dataSize == 0)
-    {
-        return !valid;
-    }
-
-    struct stat statBuffer;
-    if (fstat(*dataCacheHandle[0], &statBuffer) == 0)
-    {
-        unsigned long bufferSize = statBuffer.st_size;
-        if (bufferSize != dataSize)
-        {
-            return !valid;
-        }
-    }
-
-    return ValidateSharedHandle(dataCacheHandle[0]);
-}
-
 GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel(
     const armnn::IRuntimePtr& runtime,
     const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
@@ -274,8 +256,7 @@
     size_t hashValue = 0;
     if (dataCacheHandle.size() == 1 )
     {
-        write(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size());
-        hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
+        hashValue = Hash(dataCacheData);
     }
 
     // Cache the model data
@@ -296,16 +277,20 @@
                         {
                             std::vector<uint8_t> modelData(modelDataSize);
                             pread(*modelCacheHandle[i], modelData.data(), modelData.size(), 0);
-                            hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+                            hashValue ^= Hash(modelData);
                         }
                     }
                 }
             }
         }
     }
-    if (hashValue != 0)
+    if (dataCacheHandle.size() == 1 && hashValue != 0)
     {
-        CacheDataHandlerInstance().Register(token, hashValue, dataCacheData.size());
+        std::vector<uint8_t> theHashValue(sizeof(hashValue));
+        ::memcpy(theHashValue.data(), &hashValue, sizeof(hashValue));
+
+        write(*dataCacheHandle[0], theHashValue.data(), theHashValue.size());
+        pwrite(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), theHashValue.size());
     }
 
     bool executeWithDummyInputs = (std::find(options.GetBackends().begin(),
@@ -374,13 +359,30 @@
     }
 
     // Validate dataCacheHandle
-    auto dataSize = CacheDataHandlerInstance().GetCacheSize(token);
-    if (!ValidateDataCacheHandle(dataCacheHandle, dataSize))
+    if (dataCacheHandle.size() != 1)
     {
         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
                             << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
     }
 
+    if (!ValidateSharedHandle(dataCacheHandle[0]))
+    {
+        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
+                << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
+    }
+
+    size_t cachedDataSize = 0;
+    struct stat dataStatBuffer;
+    if (fstat(*dataCacheHandle[0], &dataStatBuffer) == 0)
+    {
+        cachedDataSize = dataStatBuffer.st_size;
+    }
+    if (cachedDataSize == 0)
+    {
+        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
+                << "ArmnnDriverImpl::prepareModelFromCache(): Not valid cached data!";
+    }
+
     // Check if model files cached they match the expected value
     unsigned int numberOfCachedModelFiles = 0;
     for (auto& backend : options.GetBackends())
@@ -393,10 +395,14 @@
                            << "ArmnnDriverImpl::prepareModelFromCache(): Model cache handle size does not match.";
     }
 
+    // Read the hashValue
+    std::vector<uint8_t> hashValue(sizeof(size_t));
+    pread(*dataCacheHandle[0], hashValue.data(), hashValue.size(), 0);
+
     // Read the model
-    std::vector<uint8_t> dataCacheData(dataSize);
-    pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), 0);
-    auto hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
+    std::vector<uint8_t> dataCacheData(cachedDataSize - hashValue.size());
+    pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), hashValue.size());
+    auto calculatedHashValue = Hash(dataCacheData);
 
     int gpuAccCachedFd = -1;
     if (modelCacheHandle.size() > 0)
@@ -423,7 +429,7 @@
                     {
                         std::vector<uint8_t> modelData(modelDataSize);
                         pread(cachedFd, modelData.data(), modelData.size(), 0);
-                        hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+                        calculatedHashValue ^= Hash(modelData);
 
                         if (backend == armnn::Compute::GpuAcc)
                         {
@@ -436,7 +442,9 @@
         }
     }
 
-    if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
+    std::vector<uint8_t> calculatedHashData(sizeof(calculatedHashValue));
+    ::memcpy(calculatedHashData.data(), &calculatedHashValue, sizeof(calculatedHashValue));
+    if (hashValue != calculatedHashData)
     {
         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
                 << "ArmnnDriverImpl::prepareModelFromCache(): ValidateHash() failed!";
@@ -529,12 +537,13 @@
         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
     }
 
-    return std::make_shared<const ArmnnPreparedModel>(netId,
+    auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId,
                                                       runtime.get(),
                                                       options.GetRequestInputsAndOutputsDumpDir(),
                                                       options.IsGpuProfilingEnabled(),
                                                       Priority::MEDIUM,
                                                       true);
+    return std::move(preparedModel);
 }
 
 const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& runtime)
diff --git a/shim/sl/canonical/CacheDataHandler.cpp b/shim/sl/canonical/CacheDataHandler.cpp
deleted file mode 100644
index 930a8e4..0000000
--- a/shim/sl/canonical/CacheDataHandler.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "CacheDataHandler.hpp"
-
-#include <log/log.h>
-
-namespace armnn_driver
-{
-
-CacheDataHandler& CacheDataHandlerInstance()
-{
-    static CacheDataHandler instance;
-    return instance;
-}
-
-void CacheDataHandler::Register(const android::nn::CacheToken token, const size_t hashValue, const size_t cacheSize)
-{
-    if (!m_CacheDataMap.empty()
-            && m_CacheDataMap.find(hashValue) != m_CacheDataMap.end()
-            && m_CacheDataMap.at(hashValue).GetToken() == token
-            && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize)
-    {
-        return;
-    }
-    CacheHandle cacheHandle(token, cacheSize);
-    m_CacheDataMap.insert({hashValue, cacheHandle});
-}
-
-bool CacheDataHandler::Validate(const android::nn::CacheToken token,
-                                const size_t hashValue,
-                                const size_t cacheSize) const
-{
-    return (!m_CacheDataMap.empty()
-            && m_CacheDataMap.find(hashValue) != m_CacheDataMap.end()
-            && m_CacheDataMap.at(hashValue).GetToken() == token
-            && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize);
-}
-
-size_t CacheDataHandler::Hash(std::vector<uint8_t>& cacheData)
-{
-    std::size_t hash = cacheData.size();
-    for (auto& i : cacheData)
-    {
-        hash = ((hash << 5) - hash) + i;
-    }
-    return hash;
-}
-
-size_t CacheDataHandler::GetCacheSize(android::nn::CacheToken token)
-{
-    for (auto i = m_CacheDataMap.begin(); i != m_CacheDataMap.end(); ++i)
-    {
-        if (i->second.GetToken() == token)
-        {
-            return i->second.GetCacheSize();
-        }
-    }
-    return 0;
-}
-
-void CacheDataHandler::Clear()
-{
-    m_CacheDataMap.clear();
-}
-
-} // armnn_driver
diff --git a/shim/sl/canonical/CacheDataHandler.hpp b/shim/sl/canonical/CacheDataHandler.hpp
deleted file mode 100644
index 95464a9..0000000
--- a/shim/sl/canonical/CacheDataHandler.hpp
+++ /dev/null
@@ -1,64 +0,0 @@
-//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <vector>
-#include <unordered_map>
-
-#include <nnapi/Types.h>
-
-namespace armnn_driver
-{
-
-class CacheHandle
-{
-public:
-    CacheHandle(const android::nn::CacheToken token, const size_t cacheSize)
-    : m_CacheToken(token), m_CacheSize(cacheSize) {}
-
-    ~CacheHandle() {};
-
-    android::nn::CacheToken GetToken() const
-    {
-        return m_CacheToken;
-    }
-
-    size_t GetCacheSize() const
-    {
-        return m_CacheSize;
-    }
-
-private:
-    const android::nn::CacheToken m_CacheToken;
-    const size_t m_CacheSize;
-};
-
-class CacheDataHandler
-{
-public:
-    CacheDataHandler() {}
-    ~CacheDataHandler() {}
-
-    void Register(const android::nn::CacheToken token, const size_t hashValue, const size_t cacheSize);
-
-    bool Validate(const android::nn::CacheToken token, const size_t hashValue, const size_t cacheSize) const;
-
-    size_t Hash(std::vector<uint8_t>& cacheData);
-
-    size_t GetCacheSize(android::nn::CacheToken token);
-
-    void Clear();
-
-private:
-    CacheDataHandler(const CacheDataHandler&) = delete;
-    CacheDataHandler& operator=(const CacheDataHandler&) = delete;
-
-    std::unordered_map<size_t, CacheHandle> m_CacheDataMap;
-};
-
-CacheDataHandler& CacheDataHandlerInstance();
-
-} // armnn_driver
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 90558bb..a405cb9 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -1396,6 +1396,7 @@
                               weightsShape[0],
                               weightsShape[1],
                               weightsShape[2]*weightsShape[3]});
+        weightsInfo.SetConstant(true);
 
         armnn::ConstTensor weightsPermuted(weightsInfo, permuteBuffer.get());
 
@@ -1412,6 +1413,7 @@
         layer = m_Network->AddConstantLayer(input, layerName.c_str());
 
         armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+        outputTensorInfo.SetConstant(true);
         layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
     }