blob: e24055371f223fb4fa2cb508058c90551ca865e9 [file] [log] [blame]
Cathal Corbettbd18eab2022-11-15 12:56:16 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <Layer.hpp>
9
10#include <tosaCommon/TosaMappings.hpp>
Cathal Corbettb30e6552022-12-07 11:50:50 +000011#include <tosaCommon/operatorMappings/TosaOperatorUtils.hpp>
Cathal Corbettbd18eab2022-11-15 12:56:16 +000012
13#include <doctest/doctest.h>
Cathal Corbettb30e6552022-12-07 11:50:50 +000014#include <numeric>
Cathal Corbettbd18eab2022-11-15 12:56:16 +000015
16using namespace armnn;
17using namespace tosa;
18
Cathal Corbettb30e6552022-12-07 11:50:50 +000019inline void VerifyTosaAttribute(const BaseDescriptor& descriptor,
20 const TosaAttributeBase* attribute,
21 std::vector<int32_t> inputShape,
22 std::vector<int32_t> outputShape,
23 LayerType type,
24 uint32_t mappingOpNumber = 0)
Cathal Corbettbd18eab2022-11-15 12:56:16 +000025{
26 switch (type)
27 {
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000028 case LayerType::Convolution2d:
29 {
30 auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor);
31 std::vector<int> pad = {static_cast<int>(conv2dDesc->m_PadTop),
32 static_cast<int>(conv2dDesc->m_PadBottom),
33 static_cast<int>(conv2dDesc->m_PadLeft),
34 static_cast<int>(conv2dDesc->m_PadRight)};
35
36 std::vector<int> dilation = {static_cast<int>(conv2dDesc->m_DilationY),
37 static_cast<int>(conv2dDesc->m_DilationX)};
38 std::vector<int> stride = {static_cast<int>(conv2dDesc->m_StrideY),
39 static_cast<int>(conv2dDesc->m_StrideX)};
40 TosaConvAttribute convAttribute(attribute);
41 CHECK(pad == convAttribute.pad());
42 CHECK(dilation == convAttribute.dilation());
43 CHECK(stride == convAttribute.stride());
44 break;
45 }
Cathal Corbettbd18eab2022-11-15 12:56:16 +000046 case LayerType::Pooling2d:
47 {
48 auto poolDesc = PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor);
49 std::vector<int> pad = {static_cast<int>(poolDesc->m_PadTop),
50 static_cast<int>(poolDesc->m_PadBottom),
51 static_cast<int>(poolDesc->m_PadLeft),
52 static_cast<int>(poolDesc->m_PadRight)};
53
54 bool avgPoolIgnoreValue =
55 (poolDesc->m_PoolType == PoolingAlgorithm::Average) &&
56 (poolDesc->m_PaddingMethod == PaddingMethod::IgnoreValue);
57 if (avgPoolIgnoreValue)
58 {
59 if (mappingOpNumber == 0)
60 {
61 if (poolDesc->m_DataLayout == DataLayout::NHWC)
62 {
63 pad = {0,
64 0,
65 static_cast<int>(poolDesc->m_PadTop),
66 static_cast<int>(poolDesc->m_PadBottom),
67 static_cast<int>(poolDesc->m_PadLeft),
68 static_cast<int>(poolDesc->m_PadRight),
69 0,
70 0
71 };
72 }
73 else
74 {
75 pad = {0,
76 0,
77 0,
78 0,
79 static_cast<int>(poolDesc->m_PadTop),
80 static_cast<int>(poolDesc->m_PadBottom),
81 static_cast<int>(poolDesc->m_PadLeft),
82 static_cast<int>(poolDesc->m_PadRight)
83 };
84 }
85
86 TosaPadAttribute padAttribute(attribute);
87
88 CHECK(pad == padAttribute.padding());
89 CHECK(0.0f == padAttribute.pad_const_fp());
90 CHECK(0 == padAttribute.pad_const_int());
91
92 break;
93 }
94 pad = {0, 0, 0, 0};
95 }
96
97 std::vector<int> kernel = {static_cast<int>(poolDesc->m_PoolHeight),
98 static_cast<int>(poolDesc->m_PoolWidth)};
99 std::vector<int> stride = {static_cast<int>(poolDesc->m_StrideY),
100 static_cast<int>(poolDesc->m_StrideX)};
101 TosaPoolAttribute poolAttribute(attribute);
102 CHECK(pad == poolAttribute.pad());
103 CHECK(kernel == poolAttribute.kernel());
104 CHECK(stride == poolAttribute.stride());
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000105 break;
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000106 }
Cathal Corbettb30e6552022-12-07 11:50:50 +0000107 case LayerType::Reshape:
108 {
109 auto reshapeDesc = PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor);
110 TosaReshapeAttribute reshapeAttribute(attribute);
111 std::vector<int32_t> shapeAttrib = reshapeAttribute.new_shape();
112
113 CHECK(GetTosaTensorShape(reshapeDesc->m_TargetShape) == shapeAttrib);
114 CHECK(outputShape == shapeAttrib);
115
116 auto numInputElements = std::accumulate(std::begin(inputShape),
117 std::end(inputShape),
118 1,
119 std::multiplies<int32_t>());
120 auto numAttributeShapeElements = std::accumulate(std::begin(shapeAttrib),
121 std::end(shapeAttrib),
122 1,
123 std::multiplies<int32_t>());
124 CHECK(numInputElements == numAttributeShapeElements);
Cathal Corbett3b9acd52022-12-09 12:17:27 +0000125
126 break;
127 }
128 case LayerType::Slice:
129 {
130 auto sliceDesc = PolymorphicDowncast<const SliceDescriptor*>(&descriptor);
131 TosaSliceAttribute reshapeAttribute(attribute);
132
133 std::vector<int32_t> begin(sliceDesc->m_Begin.begin(), sliceDesc->m_Begin.end());
134 std::vector<int32_t> size(sliceDesc->m_Size.begin(), sliceDesc->m_Size.end());
135
136 CHECK(begin == reshapeAttribute.start());
137 CHECK(size == reshapeAttribute.size());
138
139 CHECK(begin.size() == inputShape.size());
140 CHECK(size.size() == inputShape.size());
141
142 CHECK(begin.size() == outputShape.size());
143 CHECK(size.size() == outputShape.size());
144
145 break;
Cathal Corbettb30e6552022-12-07 11:50:50 +0000146 }
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +0000147 case LayerType::TransposeConvolution2d:
148 {
149 auto transposeConv2dDesc = PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor);
150 std::vector<int> outPad = {-static_cast<int>(transposeConv2dDesc->m_PadTop),
151 -static_cast<int>(transposeConv2dDesc->m_PadBottom),
152 -static_cast<int>(transposeConv2dDesc->m_PadLeft),
153 -static_cast<int>(transposeConv2dDesc->m_PadRight)};
154 std::vector<int> stride = {static_cast<int>(transposeConv2dDesc->m_StrideY),
155 static_cast<int>(transposeConv2dDesc->m_StrideX)};
156 TosaTransposeConvAttribute transposeConvAttribute(attribute);
157 CHECK(outPad == transposeConvAttribute.out_pad());
158 CHECK(stride == transposeConvAttribute.stride());
159 break;
160 }
Cathal Corbett0bb096d2022-12-22 13:09:38 +0000161 case LayerType::Transpose:
162 {
163 auto transposeDesc = PolymorphicDowncast<const TransposeDescriptor*>(&descriptor);
164 std::vector<int> outPerm(transposeDesc->m_DimMappings.begin(), transposeDesc->m_DimMappings.end());
165 TosaTransposeAttribute transposeAttribute(attribute);
166 CHECK(outPerm == transposeAttribute.perms());
167 break;
168 }
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000169 default:
170 break;
171 }
172 return;
173}
174
175inline void AssertTosaOneToOneMappingBasicBlock(TosaSerializationBasicBlock* basicBlock,
176 std::vector<std::vector<int32_t>> inputShape,
177 std::vector<std::vector<int32_t>> outputShape,
178 Op tosaOp,
179 Attribute tosaAttribute,
180 const BaseDescriptor& descriptor,
181 LayerType type,
182 DType dataType = DType_FP32)
183{
184 uint32_t numInputs = static_cast<uint32_t>(inputShape.size());
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000185 uint32_t numInputTensors = static_cast<uint32_t>(inputShape.size());
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000186 uint32_t numOutputs = static_cast<uint32_t>(outputShape.size());
187 std::string operatorString = TosaOpToString(tosaOp);
188
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000189 // The number of tensors in the block can be different if there are constant layers, as they are created separately.
190 if(type == LayerType::Convolution2d)
191 {
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +0000192 numInputTensors = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor)->m_BiasEnabled ? 3 : 2;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000193 }
194
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000195 std::string blockStr = operatorString + "_block_";
196 CHECK(basicBlock->GetName().find(blockStr) != std::string::npos);
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000197 CHECK(basicBlock->GetInputs().size() == numInputTensors);
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000198 CHECK(basicBlock->GetOutputs().size() == numOutputs);
199 CHECK(basicBlock->GetOperators().size() == 1);
200 CHECK(basicBlock->GetTensors().size() == (numInputs + numOutputs));
201
202 TosaSerializationOperator* op = basicBlock->GetOperators().at(0);
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000203 CHECK(op->GetInputTensorNames().size() == numInputTensors);
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000204 CHECK(op->GetOutputTensorNames().size() == numOutputs);
205
206 for (uint32_t i = 0; i < numInputs; i++)
207 {
208 std::basic_string<char> blockInputName = basicBlock->GetInputs()[i];
209 std::basic_string<char> operatorInputName = op->GetInputTensorNames()[i];
210 std::basic_string<char> tensorName = basicBlock->GetTensors()[i]->GetName();
211
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000212 std::string opStr = "input" + std::to_string(i) + "_";
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000213
214 CHECK(blockInputName == operatorInputName);
215 CHECK(tensorName == operatorInputName);
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000216 CHECK(blockInputName.find(opStr) != std::string::npos);
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000217 }
218
219 for (uint32_t i = 0; i < numOutputs; i++)
220 {
221 std::basic_string<char> blockOutputName = basicBlock->GetOutputs()[i];
222 std::basic_string<char> operatorOutputName = op->GetOutputTensorNames()[i];
223 std::basic_string<char> tensorName = basicBlock->GetTensors()[numInputs + i]->GetName();
224
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000225 std::string opStr = "output" + std::to_string(i) + "_";
226 if (tosaOp == Op_CONST)
227 {
228 opStr = "constant_";
229 }
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000230
231 CHECK(blockOutputName == operatorOutputName);
232 CHECK(tensorName == operatorOutputName);
233 CHECK(blockOutputName.find(opStr) != std::string::npos);
234 }
235
236 CHECK(op->GetAttributeType() == tosaAttribute);
237 CHECK(op->GetOp() == tosaOp);
238
239 for (uint32_t i = 0; i < numInputs; i++)
240 {
241 TosaSerializationTensor* tensor = basicBlock->GetTensors()[i];
242 CHECK(tensor->GetDtype() == dataType);
243 CHECK(tensor->GetData().size() == 0);
244 CHECK(tensor->GetShape() == inputShape[static_cast<unsigned long int>(i)]);
245 }
246
247 for (uint32_t i = 0; i < numOutputs; i++)
248 {
249 TosaSerializationTensor* tensor = basicBlock->GetTensors()[i + inputShape.size()];
250 CHECK(tensor->GetDtype() == dataType);
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000251 CHECK(tensor->GetShape() == outputShape[static_cast<unsigned long int>(i)]);
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000252 if (tosaOp != Op_CONST)
253 {
254 // Const tensors contain data.
255 CHECK(tensor->GetData().size() == 0);
256 }
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000257 }
258
Cathal Corbettb30e6552022-12-07 11:50:50 +0000259 std::vector<int32_t> input = {};
260 std::vector<int32_t> output = {};
261
262 if (!inputShape.empty())
263 {
264 input = inputShape[0];
265 }
266
267 if (!outputShape.empty())
268 {
269 output = outputShape[0];
270 }
271
272 VerifyTosaAttribute(descriptor,
273 op->GetAttribute(),
274 input,
275 output,
276 type);
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000277}