blob: caa11573c597e5a385663683a549baef1d7db36b [file] [log] [blame]
Finn Williamsb454c5c2021-02-09 15:56:23 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "SerializerTestUtils.hpp"
#include "../Serializer.hpp"

#include <doctest/doctest.h>

#include <algorithm>
#include <string>

using armnnDeserializer::IDeserializer;
12
13LayerVerifierBase::LayerVerifierBase(const std::string& layerName,
14 const std::vector<armnn::TensorInfo>& inputInfos,
15 const std::vector<armnn::TensorInfo>& outputInfos)
16 : m_LayerName(layerName)
17 , m_InputTensorInfos(inputInfos)
18 , m_OutputTensorInfos(outputInfos)
19{}
20
21void LayerVerifierBase::ExecuteStrategy(const armnn::IConnectableLayer* layer,
22 const armnn::BaseDescriptor& descriptor,
23 const std::vector<armnn::ConstTensor>& constants,
24 const char* name,
25 const armnn::LayerBindingId id)
26{
27 armnn::IgnoreUnused(descriptor, constants, id);
28 switch (layer->GetType())
29 {
30 case armnn::LayerType::Input: break;
31 case armnn::LayerType::Output: break;
32 default:
33 {
34 VerifyNameAndConnections(layer, name);
35 }
36 }
37}
38
39
40void LayerVerifierBase::VerifyNameAndConnections(const armnn::IConnectableLayer* layer, const char* name)
41{
Sadik Armagan1625efc2021-06-10 18:24:34 +010042 CHECK(std::string(name) == m_LayerName.c_str());
Finn Williamsb454c5c2021-02-09 15:56:23 +000043
Sadik Armagan1625efc2021-06-10 18:24:34 +010044 CHECK(layer->GetNumInputSlots() == m_InputTensorInfos.size());
45 CHECK(layer->GetNumOutputSlots() == m_OutputTensorInfos.size());
Finn Williamsb454c5c2021-02-09 15:56:23 +000046
47 for (unsigned int i = 0; i < m_InputTensorInfos.size(); i++)
48 {
49 const armnn::IOutputSlot* connectedOutput = layer->GetInputSlot(i).GetConnection();
Sadik Armagan1625efc2021-06-10 18:24:34 +010050 CHECK(connectedOutput);
Finn Williamsb454c5c2021-02-09 15:56:23 +000051
52 const armnn::TensorInfo& connectedInfo = connectedOutput->GetTensorInfo();
Sadik Armagan1625efc2021-06-10 18:24:34 +010053 CHECK(connectedInfo.GetShape() == m_InputTensorInfos[i].GetShape());
54 CHECK(
Finn Williamsb454c5c2021-02-09 15:56:23 +000055 GetDataTypeName(connectedInfo.GetDataType()) == GetDataTypeName(m_InputTensorInfos[i].GetDataType()));
56
Sadik Armagan1625efc2021-06-10 18:24:34 +010057 CHECK(connectedInfo.GetQuantizationScale() == m_InputTensorInfos[i].GetQuantizationScale());
58 CHECK(connectedInfo.GetQuantizationOffset() == m_InputTensorInfos[i].GetQuantizationOffset());
Finn Williamsb454c5c2021-02-09 15:56:23 +000059 }
60
61 for (unsigned int i = 0; i < m_OutputTensorInfos.size(); i++)
62 {
63 const armnn::TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
Sadik Armagan1625efc2021-06-10 18:24:34 +010064 CHECK(outputInfo.GetShape() == m_OutputTensorInfos[i].GetShape());
65 CHECK(GetDataTypeName(outputInfo.GetDataType()) == GetDataTypeName(m_OutputTensorInfos[i].GetDataType()));
Finn Williamsb454c5c2021-02-09 15:56:23 +000066
Sadik Armagan1625efc2021-06-10 18:24:34 +010067 CHECK(outputInfo.GetQuantizationScale() == m_OutputTensorInfos[i].GetQuantizationScale());
68 CHECK(outputInfo.GetQuantizationOffset() == m_OutputTensorInfos[i].GetQuantizationOffset());
Finn Williamsb454c5c2021-02-09 15:56:23 +000069 }
70}
71
72void LayerVerifierBase::VerifyConstTensors(const std::string& tensorName,
73 const armnn::ConstTensor* expectedPtr,
74 const armnn::ConstTensor* actualPtr)
75{
76 if (expectedPtr == nullptr)
77 {
Sadik Armagan1625efc2021-06-10 18:24:34 +010078 CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist");
Finn Williamsb454c5c2021-02-09 15:56:23 +000079 }
80 else
81 {
Sadik Armagan1625efc2021-06-10 18:24:34 +010082 CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set");
Finn Williamsb454c5c2021-02-09 15:56:23 +000083 if (actualPtr != nullptr)
84 {
85 const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
86 const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
87
Sadik Armagan1625efc2021-06-10 18:24:34 +010088 CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
Finn Williamsb454c5c2021-02-09 15:56:23 +000089 tensorName + " shapes don't match");
Sadik Armagan1625efc2021-06-10 18:24:34 +010090 CHECK_MESSAGE(
Finn Williamsb454c5c2021-02-09 15:56:23 +000091 GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
92 tensorName + " data types don't match");
93
Sadik Armagan1625efc2021-06-10 18:24:34 +010094 CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
Finn Williamsb454c5c2021-02-09 15:56:23 +000095 tensorName + " (GetNumBytes) data sizes do not match");
96 if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
97 {
98 //check the data is identical
99 const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
100 const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
101 bool same = true;
102 for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
103 {
104 same = expectedData[i] == actualData[i];
105 if (!same)
106 {
107 break;
108 }
109 }
Sadik Armagan1625efc2021-06-10 18:24:34 +0100110 CHECK_MESSAGE(same, tensorName + " data does not match");
Finn Williamsb454c5c2021-02-09 15:56:23 +0000111 }
112 }
113 }
114}
115
116void CompareConstTensor(const armnn::ConstTensor& tensor1, const armnn::ConstTensor& tensor2)
117{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100118 CHECK(tensor1.GetShape() == tensor2.GetShape());
119 CHECK(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
Finn Williamsb454c5c2021-02-09 15:56:23 +0000120
121 switch (tensor1.GetDataType())
122 {
123 case armnn::DataType::Float32:
124 CompareConstTensorData<const float*>(
125 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
126 break;
127 case armnn::DataType::QAsymmU8:
128 case armnn::DataType::Boolean:
129 CompareConstTensorData<const uint8_t*>(
130 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
131 break;
132 case armnn::DataType::QSymmS8:
133 CompareConstTensorData<const int8_t*>(
134 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
135 break;
136 case armnn::DataType::Signed32:
137 CompareConstTensorData<const int32_t*>(
138 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
139 break;
140 default:
141 // Note that Float16 is not yet implemented
Sadik Armagan1625efc2021-06-10 18:24:34 +0100142 MESSAGE("Unexpected datatype");
143 CHECK(false);
Finn Williamsb454c5c2021-02-09 15:56:23 +0000144 }
145}
146
147armnn::INetworkPtr DeserializeNetwork(const std::string& serializerString)
148{
149 std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
150 return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
151}
152
153std::string SerializeNetwork(const armnn::INetwork& network)
154{
155 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
156
157 serializer->Serialize(network);
158
159 std::stringstream stream;
160 serializer->SaveSerializedToStream(stream);
161
162 std::string serializerString{stream.str()};
163 return serializerString;
164}