IVGCVSW-5636 'Implement NNAPI caching functions'

* Fixed test failures.

!armnn:6617

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I9989ece8999d67dd40dfcf69b73f4d80f71687a4
diff --git a/1.2/ArmnnDriverImpl.cpp b/1.2/ArmnnDriverImpl.cpp
index b3bc5cd..3274a8a 100644
--- a/1.2/ArmnnDriverImpl.cpp
+++ b/1.2/ArmnnDriverImpl.cpp
@@ -315,6 +315,14 @@
             NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
             return V1_0::ErrorStatus::NONE;
         }
+
+        if (dataCacheHandle[0]->data[0] < 0)
+        {
+            ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
+            NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
+            return V1_0::ErrorStatus::NONE;
+        }
+
         int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
         if (dataCacheFileAccessMode != O_RDWR)
         {
@@ -420,6 +428,13 @@
         return V1_0::ErrorStatus::GENERAL_FAILURE;
     }
 
+    if (dataCacheHandle[0]->data[0] < 0)
+    {
+        ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the cache data, fd < 0");
+        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb);
+        return V1_0::ErrorStatus::GENERAL_FAILURE;
+    }
+
     int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
     if (dataCacheFileAccessMode != O_RDWR)
     {
@@ -441,16 +456,12 @@
         if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
         {
             unsigned long bufferSize = statBuffer.st_size;
-            if (bufferSize <= 0)
+            if (bufferSize != dataSize)
             {
                 ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid data to deserialize!");
                 FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid data to deserialize!", cb);
                 return V1_0::ErrorStatus::GENERAL_FAILURE;
             }
-            if (bufferSize > dataSize)
-            {
-                offset = bufferSize - dataSize;
-            }
         }
     }
     std::vector<uint8_t> dataCacheData(dataSize);
@@ -489,17 +500,19 @@
                 if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
                 {
                     long modelDataSize = statBuffer.st_size;
-                    if (modelDataSize > 0)
+                    if (modelDataSize <= 0)
                     {
-                        std::vector<uint8_t> modelData(modelDataSize);
-                        pread(cachedFd, modelData.data(), modelData.size(), 0);
-                        hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+                        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Wrong cached model size!", cb);
+                        return V1_0::ErrorStatus::NONE;
+                    }
+                    std::vector<uint8_t> modelData(modelDataSize);
+                    pread(cachedFd, modelData.data(), modelData.size(), 0);
+                    hashValue ^= CacheDataHandlerInstance().Hash(modelData);
 
-                        // For GpuAcc numberOfCachedFiles is 1
-                        if (backend == armnn::Compute::GpuAcc)
-                        {
-                            gpuAccCachedFd = cachedFd;
-                        }
+                    // For GpuAcc numberOfCachedFiles is 1
+                    if (backend == armnn::Compute::GpuAcc)
+                    {
+                        gpuAccCachedFd = cachedFd;
                     }
                 }
                 index += numberOfCacheFiles;
@@ -507,7 +520,7 @@
         }
     }
 
-    if (!CacheDataHandlerInstance().Validate(token, hashValue))
+    if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
     {
         ALOGW("ArmnnDriverImpl::prepareModelFromCache: ValidateHash() failed!");
         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ValidateHash Failed!", cb);
@@ -515,7 +528,18 @@
     }
 
     // Deserialize the network..
-    auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+    armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
+    try
+    {
+        network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+    }
+    catch (std::exception& e)
+    {
+        std::stringstream message;
+        message << "Exception (" << e.what() << ") caught from Deserializer.";
+        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
+        return V1_0::ErrorStatus::GENERAL_FAILURE;
+    }
 
     // Optimize the network
     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
diff --git a/1.3/ArmnnDriverImpl.cpp b/1.3/ArmnnDriverImpl.cpp
index e1d65f9..c8b1d96 100644
--- a/1.3/ArmnnDriverImpl.cpp
+++ b/1.3/ArmnnDriverImpl.cpp
@@ -328,6 +328,14 @@
             NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release());
             return V1_3::ErrorStatus::NONE;
         }
+
+        if (dataCacheHandle[0]->data[0] < 0)
+        {
+            ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
+            NotifyCallbackAndCheck(cb, V1_3::ErrorStatus::NONE, preparedModel.release());
+            return V1_3::ErrorStatus::NONE;
+        }
+
         int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
         if (dataCacheFileAccessMode != O_RDWR)
         {
@@ -435,6 +443,13 @@
         return V1_3::ErrorStatus::GENERAL_FAILURE;
     }
 
+    if (dataCacheHandle[0]->data[0] < 0)
+    {
+        ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Cannot read from the cache data, fd < 0");
+        cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+        return V1_3::ErrorStatus::GENERAL_FAILURE;
+    }
+
     int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
     if (dataCacheFileAccessMode != O_RDWR)
     {
@@ -456,16 +471,12 @@
         if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
         {
             unsigned long bufferSize = statBuffer.st_size;
-            if (bufferSize <= 0)
+            if (bufferSize != dataSize)
             {
                 ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Invalid data to deserialize!");
                 cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
                 return V1_3::ErrorStatus::GENERAL_FAILURE;
             }
-            if (bufferSize > dataSize)
-            {
-                offset = bufferSize - dataSize;
-            }
         }
     }
     std::vector<uint8_t> dataCacheData(dataSize);
@@ -504,17 +515,20 @@
                 if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
                 {
                     long modelDataSize = statBuffer.st_size;
-                    if (modelDataSize > 0)
+                    if (modelDataSize <= 0)
                     {
-                        std::vector<uint8_t> modelData(modelDataSize);
-                        pread(cachedFd, modelData.data(), modelData.size(), 0);
-                        hashValue ^= CacheDataHandlerInstance().Hash(modelData);
+                        ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3(): Wrong cached model size!");
+                        cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+                        return V1_3::ErrorStatus::NONE;
+                    }
+                    std::vector<uint8_t> modelData(modelDataSize);
+                    pread(cachedFd, modelData.data(), modelData.size(), 0);
+                    hashValue ^= CacheDataHandlerInstance().Hash(modelData);
 
-                        // For GpuAcc numberOfCachedFiles is 1
-                        if (backend == armnn::Compute::GpuAcc)
-                        {
-                            gpuAccCachedFd = cachedFd;
-                        }
+                    // For GpuAcc numberOfCachedFiles is 1
+                    if (backend == armnn::Compute::GpuAcc)
+                    {
+                        gpuAccCachedFd = cachedFd;
                     }
                 }
                 index += numberOfCacheFiles;
@@ -522,7 +536,7 @@
         }
     }
 
-    if (!CacheDataHandlerInstance().Validate(token, hashValue))
+    if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
     {
         ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: ValidateHash() failed!");
         cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
@@ -530,7 +544,17 @@
     }
 
     // Deserialize the network..
-    auto network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+    armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
+    try
+    {
+        network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
+    }
+    catch (std::exception&)
+    {
+        ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_3: Exception caught from Deserializer!");
+        cb->notify_1_3(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
+        return V1_3::ErrorStatus::GENERAL_FAILURE;
+    }
 
     // Optimize the network
     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
diff --git a/CacheDataHandler.cpp b/CacheDataHandler.cpp
index 3688162..5f3a307 100644
--- a/CacheDataHandler.cpp
+++ b/CacheDataHandler.cpp
@@ -18,19 +18,22 @@
 
 void CacheDataHandler::Register(const HidlToken token, const size_t hashValue, const size_t cacheSize)
 {
-    if (m_CacheDataMap.find(hashValue) != m_CacheDataMap.end())
+    if (m_CacheDataMap.find(hashValue) != m_CacheDataMap.end()
+                        && m_CacheDataMap.at(hashValue).GetToken() == token
+                        && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize)
     {
-        ALOGV("CacheHandler::Register() Token has been already registered.");
+        ALOGV("CacheHandler::Register() Hash value has already registered.");
         return;
     }
     CacheHandle cacheHandle(token, cacheSize);
     m_CacheDataMap.insert({hashValue, cacheHandle});
 }
 
-bool CacheDataHandler::Validate(const HidlToken token, const size_t hashValue) const
+bool CacheDataHandler::Validate(const HidlToken token, const size_t hashValue, const size_t cacheSize) const
 {
     return (m_CacheDataMap.find(hashValue) != m_CacheDataMap.end()
-                             && m_CacheDataMap.at(hashValue).GetToken() == token);
+                             && m_CacheDataMap.at(hashValue).GetToken() == token
+                             && m_CacheDataMap.at(hashValue).GetCacheSize() == cacheSize);
 }
 
 size_t CacheDataHandler::Hash(std::vector<uint8_t>& cacheData)
@@ -38,7 +41,7 @@
     std::size_t hash = cacheData.size();
     for (auto& i : cacheData)
     {
-        hash ^= std::hash<unsigned int>{}(i);
+        hash = ((hash << 5) - hash) + i;
     }
     return hash;
 }
diff --git a/CacheDataHandler.hpp b/CacheDataHandler.hpp
index cea73d2..5b1b295 100644
--- a/CacheDataHandler.hpp
+++ b/CacheDataHandler.hpp
@@ -48,7 +48,7 @@
 
     void Register(const HidlToken token, const size_t hashValue, const size_t cacheSize);
 
-    bool Validate(const HidlToken token, const size_t hashValue) const;
+    bool Validate(const HidlToken token, const size_t hashValue, const size_t cacheSize) const;
 
     size_t Hash(std::vector<uint8_t>& cacheData);