blob: 187384777dcf93532d95ce622c04cf22d1245efd [file] [log] [blame]
Finn Williamsb454c5c2021-02-09 15:56:23 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "SerializerTestUtils.hpp"
7#include "../Serializer.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009#include <doctest/doctest.h>
10
Finn Williamsb454c5c2021-02-09 15:56:23 +000011using armnnDeserializer::IDeserializer;
12
13LayerVerifierBase::LayerVerifierBase(const std::string& layerName,
14 const std::vector<armnn::TensorInfo>& inputInfos,
15 const std::vector<armnn::TensorInfo>& outputInfos)
16 : m_LayerName(layerName)
17 , m_InputTensorInfos(inputInfos)
18 , m_OutputTensorInfos(outputInfos)
19{}
20
21void LayerVerifierBase::ExecuteStrategy(const armnn::IConnectableLayer* layer,
22 const armnn::BaseDescriptor& descriptor,
23 const std::vector<armnn::ConstTensor>& constants,
24 const char* name,
25 const armnn::LayerBindingId id)
26{
27 armnn::IgnoreUnused(descriptor, constants, id);
28 switch (layer->GetType())
29 {
30 case armnn::LayerType::Input: break;
31 case armnn::LayerType::Output: break;
32 default:
33 {
34 VerifyNameAndConnections(layer, name);
35 }
36 }
37}
38
39
40void LayerVerifierBase::VerifyNameAndConnections(const armnn::IConnectableLayer* layer, const char* name)
41{
Sadik Armagan1625efc2021-06-10 18:24:34 +010042 CHECK(std::string(name) == m_LayerName.c_str());
Finn Williamsb454c5c2021-02-09 15:56:23 +000043
Sadik Armagan1625efc2021-06-10 18:24:34 +010044 CHECK(layer->GetNumInputSlots() == m_InputTensorInfos.size());
45 CHECK(layer->GetNumOutputSlots() == m_OutputTensorInfos.size());
Finn Williamsb454c5c2021-02-09 15:56:23 +000046
47 for (unsigned int i = 0; i < m_InputTensorInfos.size(); i++)
48 {
49 const armnn::IOutputSlot* connectedOutput = layer->GetInputSlot(i).GetConnection();
Sadik Armagan1625efc2021-06-10 18:24:34 +010050 CHECK(connectedOutput);
Finn Williamsb454c5c2021-02-09 15:56:23 +000051
52 const armnn::TensorInfo& connectedInfo = connectedOutput->GetTensorInfo();
Sadik Armagan1625efc2021-06-10 18:24:34 +010053 CHECK(connectedInfo.GetShape() == m_InputTensorInfos[i].GetShape());
Keith Davisb4dd5cc2022-04-07 11:32:00 +010054 CHECK(GetDataTypeName(connectedInfo.GetDataType()) == GetDataTypeName(m_InputTensorInfos[i].GetDataType()));
Finn Williamsb454c5c2021-02-09 15:56:23 +000055
Keith Davisb4dd5cc2022-04-07 11:32:00 +010056 if (connectedInfo.HasMultipleQuantizationScales())
57 {
58 CHECK(connectedInfo.GetQuantizationScales() == m_InputTensorInfos[i].GetQuantizationScales());
59 }
60 else
Cathal Corbett06902652022-04-14 17:55:11 +010061 {
62 CHECK(connectedInfo.GetQuantizationScale() == m_InputTensorInfos[i].GetQuantizationScale());
Cathal Corbett06902652022-04-14 17:55:11 +010063 }
Keith Davisb4dd5cc2022-04-07 11:32:00 +010064 CHECK(connectedInfo.GetQuantizationOffset() == m_InputTensorInfos[i].GetQuantizationOffset());
Finn Williamsb454c5c2021-02-09 15:56:23 +000065 }
66
67 for (unsigned int i = 0; i < m_OutputTensorInfos.size(); i++)
68 {
69 const armnn::TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
Sadik Armagan1625efc2021-06-10 18:24:34 +010070 CHECK(outputInfo.GetShape() == m_OutputTensorInfos[i].GetShape());
71 CHECK(GetDataTypeName(outputInfo.GetDataType()) == GetDataTypeName(m_OutputTensorInfos[i].GetDataType()));
Finn Williamsb454c5c2021-02-09 15:56:23 +000072
Sadik Armagan1625efc2021-06-10 18:24:34 +010073 CHECK(outputInfo.GetQuantizationScale() == m_OutputTensorInfos[i].GetQuantizationScale());
74 CHECK(outputInfo.GetQuantizationOffset() == m_OutputTensorInfos[i].GetQuantizationOffset());
Finn Williamsb454c5c2021-02-09 15:56:23 +000075 }
76}
77
78void LayerVerifierBase::VerifyConstTensors(const std::string& tensorName,
79 const armnn::ConstTensor* expectedPtr,
80 const armnn::ConstTensor* actualPtr)
81{
82 if (expectedPtr == nullptr)
83 {
Jan Eilers13d2e0d2021-09-28 15:11:28 +010084 CHECK_MESSAGE(actualPtr == nullptr, (tensorName + " should not exist"));
Finn Williamsb454c5c2021-02-09 15:56:23 +000085 }
86 else
87 {
Jan Eilers13d2e0d2021-09-28 15:11:28 +010088 CHECK_MESSAGE(actualPtr != nullptr, (tensorName + " should have been set"));
Finn Williamsb454c5c2021-02-09 15:56:23 +000089 if (actualPtr != nullptr)
90 {
91 const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
92 const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
93
Sadik Armagan1625efc2021-06-10 18:24:34 +010094 CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
Jan Eilers13d2e0d2021-09-28 15:11:28 +010095 (tensorName + " shapes don't match"));
Sadik Armagan1625efc2021-06-10 18:24:34 +010096 CHECK_MESSAGE(
Finn Williamsb454c5c2021-02-09 15:56:23 +000097 GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
Jan Eilers13d2e0d2021-09-28 15:11:28 +010098 (tensorName + " data types don't match"));
Finn Williamsb454c5c2021-02-09 15:56:23 +000099
Sadik Armagan1625efc2021-06-10 18:24:34 +0100100 CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
Jan Eilers13d2e0d2021-09-28 15:11:28 +0100101 (tensorName + " (GetNumBytes) data sizes do not match"));
Finn Williamsb454c5c2021-02-09 15:56:23 +0000102 if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
103 {
104 //check the data is identical
105 const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
106 const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
107 bool same = true;
108 for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
109 {
110 same = expectedData[i] == actualData[i];
111 if (!same)
112 {
113 break;
114 }
115 }
Jan Eilers13d2e0d2021-09-28 15:11:28 +0100116 CHECK_MESSAGE(same, (tensorName + " data does not match"));
Finn Williamsb454c5c2021-02-09 15:56:23 +0000117 }
118 }
119 }
120}
121
122void CompareConstTensor(const armnn::ConstTensor& tensor1, const armnn::ConstTensor& tensor2)
123{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100124 CHECK(tensor1.GetShape() == tensor2.GetShape());
125 CHECK(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
Finn Williamsb454c5c2021-02-09 15:56:23 +0000126
127 switch (tensor1.GetDataType())
128 {
129 case armnn::DataType::Float32:
130 CompareConstTensorData<const float*>(
131 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
132 break;
133 case armnn::DataType::QAsymmU8:
134 case armnn::DataType::Boolean:
135 CompareConstTensorData<const uint8_t*>(
136 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
137 break;
138 case armnn::DataType::QSymmS8:
139 CompareConstTensorData<const int8_t*>(
140 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
141 break;
142 case armnn::DataType::Signed32:
143 CompareConstTensorData<const int32_t*>(
144 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
145 break;
146 default:
147 // Note that Float16 is not yet implemented
Sadik Armagan1625efc2021-06-10 18:24:34 +0100148 MESSAGE("Unexpected datatype");
149 CHECK(false);
Finn Williamsb454c5c2021-02-09 15:56:23 +0000150 }
151}
152
153armnn::INetworkPtr DeserializeNetwork(const std::string& serializerString)
154{
155 std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
156 return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
157}
158
159std::string SerializeNetwork(const armnn::INetwork& network)
160{
161 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
162
163 serializer->Serialize(network);
164
165 std::stringstream stream;
166 serializer->SaveSerializedToStream(stream);
167
168 std::string serializerString{stream.str()};
169 return serializerString;
170}