blob: a362bde10d28abc7dac87161af58ee57d2245af0 [file] [log] [blame]
Cathal Corbettbd18eab2022-11-15 12:56:16 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <Layer.hpp>
9
10#include <tosaCommon/TosaMappings.hpp>
11
12#include <doctest/doctest.h>
13
14using namespace armnn;
15using namespace tosa;
16
17inline void VerifyTosaAttributeFromDescriptor(const BaseDescriptor& descriptor,
18 const TosaAttributeBase* attribute,
19 LayerType type,
20 uint32_t mappingOpNumber = 0)
21{
22 switch (type)
23 {
24 case LayerType::Pooling2d:
25 {
26 auto poolDesc = PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor);
27 std::vector<int> pad = {static_cast<int>(poolDesc->m_PadTop),
28 static_cast<int>(poolDesc->m_PadBottom),
29 static_cast<int>(poolDesc->m_PadLeft),
30 static_cast<int>(poolDesc->m_PadRight)};
31
32 bool avgPoolIgnoreValue =
33 (poolDesc->m_PoolType == PoolingAlgorithm::Average) &&
34 (poolDesc->m_PaddingMethod == PaddingMethod::IgnoreValue);
35 if (avgPoolIgnoreValue)
36 {
37 if (mappingOpNumber == 0)
38 {
39 if (poolDesc->m_DataLayout == DataLayout::NHWC)
40 {
41 pad = {0,
42 0,
43 static_cast<int>(poolDesc->m_PadTop),
44 static_cast<int>(poolDesc->m_PadBottom),
45 static_cast<int>(poolDesc->m_PadLeft),
46 static_cast<int>(poolDesc->m_PadRight),
47 0,
48 0
49 };
50 }
51 else
52 {
53 pad = {0,
54 0,
55 0,
56 0,
57 static_cast<int>(poolDesc->m_PadTop),
58 static_cast<int>(poolDesc->m_PadBottom),
59 static_cast<int>(poolDesc->m_PadLeft),
60 static_cast<int>(poolDesc->m_PadRight)
61 };
62 }
63
64 TosaPadAttribute padAttribute(attribute);
65
66 CHECK(pad == padAttribute.padding());
67 CHECK(0.0f == padAttribute.pad_const_fp());
68 CHECK(0 == padAttribute.pad_const_int());
69
70 break;
71 }
72 pad = {0, 0, 0, 0};
73 }
74
75 std::vector<int> kernel = {static_cast<int>(poolDesc->m_PoolHeight),
76 static_cast<int>(poolDesc->m_PoolWidth)};
77 std::vector<int> stride = {static_cast<int>(poolDesc->m_StrideY),
78 static_cast<int>(poolDesc->m_StrideX)};
79 TosaPoolAttribute poolAttribute(attribute);
80 CHECK(pad == poolAttribute.pad());
81 CHECK(kernel == poolAttribute.kernel());
82 CHECK(stride == poolAttribute.stride());
83 }
84 default:
85 break;
86 }
87 return;
88}
89
90inline void AssertTosaOneToOneMappingBasicBlock(TosaSerializationBasicBlock* basicBlock,
91 std::vector<std::vector<int32_t>> inputShape,
92 std::vector<std::vector<int32_t>> outputShape,
93 Op tosaOp,
94 Attribute tosaAttribute,
95 const BaseDescriptor& descriptor,
96 LayerType type,
97 DType dataType = DType_FP32)
98{
99 uint32_t numInputs = static_cast<uint32_t>(inputShape.size());
100 uint32_t numOutputs = static_cast<uint32_t>(outputShape.size());
101 std::string operatorString = TosaOpToString(tosaOp);
102
103 std::string blockStr = operatorString + "_block_";
104 CHECK(basicBlock->GetName().find(blockStr) != std::string::npos);
105 CHECK(basicBlock->GetInputs().size() == numInputs);
106 CHECK(basicBlock->GetOutputs().size() == numOutputs);
107 CHECK(basicBlock->GetOperators().size() == 1);
108 CHECK(basicBlock->GetTensors().size() == (numInputs + numOutputs));
109
110 TosaSerializationOperator* op = basicBlock->GetOperators().at(0);
111 CHECK(op->GetInputTensorNames().size() == numInputs);
112 CHECK(op->GetOutputTensorNames().size() == numOutputs);
113
114 for (uint32_t i = 0; i < numInputs; i++)
115 {
116 std::basic_string<char> blockInputName = basicBlock->GetInputs()[i];
117 std::basic_string<char> operatorInputName = op->GetInputTensorNames()[i];
118 std::basic_string<char> tensorName = basicBlock->GetTensors()[i]->GetName();
119
120 std::string opStr = operatorString + "_input" + std::to_string(i) + "_";
121
122 CHECK(blockInputName == operatorInputName);
123 CHECK(tensorName == operatorInputName);
124 CHECK(blockInputName.find(opStr) != std::string::npos);
125 }
126
127 for (uint32_t i = 0; i < numOutputs; i++)
128 {
129 std::basic_string<char> blockOutputName = basicBlock->GetOutputs()[i];
130 std::basic_string<char> operatorOutputName = op->GetOutputTensorNames()[i];
131 std::basic_string<char> tensorName = basicBlock->GetTensors()[numInputs + i]->GetName();
132
133 std::string opStr = operatorString + "_output" + std::to_string(i) + "_";
134
135 CHECK(blockOutputName == operatorOutputName);
136 CHECK(tensorName == operatorOutputName);
137 CHECK(blockOutputName.find(opStr) != std::string::npos);
138 }
139
140 CHECK(op->GetAttributeType() == tosaAttribute);
141 CHECK(op->GetOp() == tosaOp);
142
143 for (uint32_t i = 0; i < numInputs; i++)
144 {
145 TosaSerializationTensor* tensor = basicBlock->GetTensors()[i];
146 CHECK(tensor->GetDtype() == dataType);
147 CHECK(tensor->GetData().size() == 0);
148 CHECK(tensor->GetShape() == inputShape[static_cast<unsigned long int>(i)]);
149 }
150
151 for (uint32_t i = 0; i < numOutputs; i++)
152 {
153 TosaSerializationTensor* tensor = basicBlock->GetTensors()[i + inputShape.size()];
154 CHECK(tensor->GetDtype() == dataType);
155 CHECK(tensor->GetData().size() == 0);
156 CHECK(tensor->GetShape() == outputShape[static_cast<unsigned long int>(i)]);
157 }
158
159 VerifyTosaAttributeFromDescriptor(descriptor,
160 op->GetAttribute(),
161 type);
162}