blob: ce4abbf9210eb01846ed5f4d77583af155c0117a [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefLayerSupport.hpp"
7#include <tosaCommon/TosaMappings.hpp>
8
9#include <armnn/Types.hpp>
10#include <armnn/utility/IgnoreUnused.hpp>
11#include <tosaCommon/TosaLayerSupportRules.hpp>
12#include <LayerSupportCommon.hpp>
13
14#include <vector>
15#include <array>
Cathal Corbettbd18eab2022-11-15 12:56:16 +000016#include <tuple>
Francis Murtagh9270d9e2022-08-12 13:54:17 +010017
18namespace armnn
19{
20
Cathal Corbettbd18eab2022-11-15 12:56:16 +000021static bool RunTosaLayerChecksSingleDataType(TosaSerializationOperator* op,
22 const std::vector<TosaSerializationTensor*>& inputs,
23 const std::vector<TosaSerializationTensor*>& outputs,
24 const std::vector<Attribute>& supportedAttributes,
25 const std::vector<DType>& supportedTypes,
26 Optional<string&> reasonIfUnsupported)
Matthew Sloyan2523b792022-11-14 10:18:01 +000027{
28 bool supported = true;
29
Cathal Corbettbd18eab2022-11-15 12:56:16 +000030 std::string opString = TosaOpToString(op->GetOp());
Matthew Sloyan2523b792022-11-14 10:18:01 +000031
32 // Check Attribute from operator (GetAttribute)
33 supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000034 std::string("TOSA Reference Operator: " + opString +
Matthew Sloyan2523b792022-11-14 10:18:01 +000035 " has an unsupported attribute.").c_str());
36
37 for (auto input : inputs)
38 {
39 std::string dataTypeCode = std::to_string(input->GetDtype());
40
41 // Check Dtype from tensor (GetDtype)
42 supported &= CheckSupportRule(TosaTypeAnyOf(input, supportedTypes),
43 reasonIfUnsupported,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000044 std::string("TOSA Reference Operator: " + opString + " for input: " +
Matthew Sloyan2523b792022-11-14 10:18:01 +000045 input->GetName() + " has an unsupported data type: " +
46 dataTypeCode).c_str());
47
48 // Check Shape from tensor (GetShape)
49 supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input),
50 reasonIfUnsupported,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000051 std::string("Tosa Reference Operator: " + opString + " for input: " +
Matthew Sloyan2523b792022-11-14 10:18:01 +000052 input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
53 }
54
55 for (auto output : outputs)
56 {
57 std::string dataTypeCode = std::to_string(output->GetDtype());
58
59 // Check Dtype from tensor (GetDtype)
60 supported &= CheckSupportRule(TosaTypeAnyOf(output, supportedTypes),
61 reasonIfUnsupported,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000062 std::string("TOSA Reference Operator: " + opString + " for output: " +
Matthew Sloyan2523b792022-11-14 10:18:01 +000063 output->GetName() + " has an unsupported data type: " +
64 dataTypeCode).c_str());
65
66 // Check Shape from tensor (GetShape)
67 supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output),
68 reasonIfUnsupported,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000069 std::string("Tosa Reference Operator: " + opString + " for output: " +
Matthew Sloyan2523b792022-11-14 10:18:01 +000070 output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
71 }
72
73 return supported;
74}
75
Cathal Corbettbd18eab2022-11-15 12:56:16 +000076static bool RunTosaLayerChecksInputOutputDataType(TosaSerializationOperator* op,
77 const std::vector<TosaSerializationTensor*>& inputs,
78 const std::vector<TosaSerializationTensor*>& outputs,
79 const std::vector<Attribute>& supportedAttributes,
80 const std::vector<std::tuple<DType,DType>>& supportedMappingTypes,
81 Optional<string&> reasonIfUnsupported)
82{
83 bool supported = true;
84
85 std::string opString = TosaOpToString(op->GetOp());
86
87 // Check Attribute from operator (GetAttribute)
88 supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
89 std::string("TOSA Reference Operator: " + opString +
90 " has an unsupported attribute.").c_str());
91
92 supported &= CheckSupportRule(TosaAssertSize(inputs, outputs), reasonIfUnsupported,
93 std::string("TOSA Reference Operator: " + opString +
94 " must have 1-to-1 mapping of inputs-to-outputs.").c_str());
95
96 for (uint32_t i = 0; i < inputs.size(); i++)
97 {
98 auto input = inputs[i];
99 auto output = outputs[i];
100 std::string inputDataTypeCode = std::to_string(input->GetDtype());
101 std::string outputDataTypeCode = std::to_string(output->GetDtype());
102 std::tuple<DType, DType> mappingType(input->GetDtype(), output->GetDtype());
103
104 // Check Dtype from tensor (GetDtype)
105 supported &= CheckSupportRule(TosaContainerContains(mappingType, supportedMappingTypes),
106 reasonIfUnsupported,
107 std::string("TOSA Reference Operator: " + opString + " for input: " +
108 input->GetName() + " and output: " + output->GetName() +
109 " has an unsupported input data type: " + inputDataTypeCode +
110 " to output data type: " + outputDataTypeCode).c_str());
111
112 // Check Shape from tensor (GetShape)
113 supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input),
114 reasonIfUnsupported,
115 std::string("Tosa Reference Operator: " + opString + " for input: " +
116 input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
117
118 // Check Shape from tensor (GetShape)
119 supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output),
120 reasonIfUnsupported,
121 std::string("Tosa Reference Operator: " + opString + " for output: " +
122 output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
123 }
124
125 return supported;
126}
127
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100128static bool IsTosaLayerSupported(TosaSerializationOperator* op,
129 const std::vector<TosaSerializationTensor*>& inputs,
130 const std::vector<TosaSerializationTensor*>& outputs,
131 Optional<string&> reasonIfUnsupported)
132{
133 switch(op->GetOp())
134 {
135 case tosa::Op_ADD:
136 {
Matthew Sloyan2523b792022-11-14 10:18:01 +0000137 std::vector<Attribute> supportedAttributes =
Matthew Sloyanda824cc2022-10-10 12:43:20 +0100138 {
139 Attribute_NONE
140 };
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100141
Matthew Sloyan2523b792022-11-14 10:18:01 +0000142 // Only Int32, Fp32 and Fp16 are currently supported by the TOSA Reference Model.
143 std::vector<DType> supportedTypes =
Matthew Sloyanda824cc2022-10-10 12:43:20 +0100144 {
Matthew Sloyanda824cc2022-10-10 12:43:20 +0100145 DType_INT32,
146 DType_FP16,
147 DType_FP32
148 };
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100149
Matthew Sloyan2523b792022-11-14 10:18:01 +0000150 // Check the attribute, data types and bounds for inputs and outputs.
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000151 return RunTosaLayerChecksSingleDataType(op,
152 inputs,
153 outputs,
154 supportedAttributes,
155 supportedTypes,
156 reasonIfUnsupported);
157 }
158 case tosa::Op_AVG_POOL2D:
159 {
160 std::vector<Attribute> supportedAttributes =
161 {
162 Attribute_PoolAttribute
163 };
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100164
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000165 std::vector<std::tuple<DType, DType>> supportedTypesMapping =
166 {
167 std::tuple<DType, DType>(DType_FP16, DType_FP16),
168 std::tuple<DType, DType>(DType_FP16, DType_FP32),
169 std::tuple<DType, DType>(DType_FP32, DType_FP32),
170 std::tuple<DType, DType>(DType_INT8, DType_INT32),
171 std::tuple<DType, DType>(DType_INT16, DType_INT32)
172 };
173
174 // Check the attribute, data types and bounds for inputs and outputs.
175 return RunTosaLayerChecksInputOutputDataType(op,
176 inputs,
177 outputs,
178 supportedAttributes,
179 supportedTypesMapping,
180 reasonIfUnsupported);
181 }
182 case tosa::Op_MAX_POOL2D:
183 {
184 std::vector<Attribute> supportedAttributes =
185 {
186 Attribute_PoolAttribute
187 };
188
189 std::vector<DType> supportedTypes =
190 {
191 DType_FP16,
192 DType_FP32,
193 DType_INT8,
194 DType_INT16
195 };
196
197 // Check the attribute, data types and bounds for inputs and outputs.
198 return RunTosaLayerChecksSingleDataType(op,
199 inputs,
200 outputs,
201 supportedAttributes,
202 supportedTypes,
203 reasonIfUnsupported);
204 }
205 case tosa::Op_PAD:
206 {
207 std::vector<Attribute> supportedAttributes =
208 {
209 Attribute_PadAttribute
210 };
211
212 std::vector<DType> supportedTypes =
213 {
214 DType_FP16,
215 DType_FP32,
216 DType_INT8,
217 DType_INT16,
218 DType_INT32,
219 DType_BOOL
220 };
221
222 // Check the attribute, data types and bounds for inputs and outputs.
223 return RunTosaLayerChecksSingleDataType(op,
224 inputs,
225 outputs,
226 supportedAttributes,
227 supportedTypes,
228 reasonIfUnsupported);
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100229 }
230 default:
231 SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend.");
232 return false;
233 }
234}
235
236bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
237 const std::vector<TensorInfo>& infos,
238 const BaseDescriptor& descriptor,
239 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
240 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
241 Optional<std::string&> reasonIfUnsupported) const
242{
243 IgnoreUnused(lstmParamsInfo);
244 IgnoreUnused(quantizedLstmInputParamsInfo);
245
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100246 std::vector<const TensorInfo*> inputInfos;
247 std::vector<const TensorInfo*> outputInfos;
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100248
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100249 switch (type)
250 {
251 case LayerType::Addition:
252 // Setup inputs and outputs
253 inputInfos.push_back(&infos[0]);
254 inputInfos.push_back(&infos[1]);
255 outputInfos.push_back(&infos[2]);
256 break;
257 case LayerType::Input:
258 case LayerType::Output:
259 return true;
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000260 case LayerType::Pooling2d:
261 // Setup inputs and outputs
262 inputInfos.push_back(&infos[0]);
263 outputInfos.push_back(&infos[1]);
264 break;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100265 default:
266 break;
267 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100268
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000269 auto mappings = GetTosaMapping(type, inputInfos, outputInfos, descriptor, false);
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100270 if (mappings->GetName() == "")
271 {
272 // There currently isn't a TOSA mapping for this layer, as the default was returned.
273 return false;
274 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100275
276 // Loop through block and get each tensor and operator
277 for (long unsigned int i = 0; i < mappings->GetOperators().size(); ++i)
278 {
279 // While looping over operators check for op_UNKNOWN which is unsupported
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100280 if (mappings->GetOperators()[i]->GetOp() == tosa::Op_UNKNOWN) { return false; }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100281
282 // Loop over operators and get GetInput/OutputTensorNames, loop over resulting names and
283 // use GetTensorByName to pass pointers to tensors on to the IsTosaLayerSupported()
284 std::vector<TosaSerializationTensor*> inputTensorsVect;
285 for (const auto& name : mappings->GetOperators()[i]->GetInputTensorNames())
286 {
287 inputTensorsVect.push_back(mappings->GetTensorByName(name));
288 }
289
290 std::vector<TosaSerializationTensor*> outputTensorsVect;
291 for (const auto& name : mappings->GetOperators()[i]->GetOutputTensorNames())
292 {
293 outputTensorsVect.push_back(mappings->GetTensorByName(name));
294 }
295
296 if (!IsTosaLayerSupported(mappings->GetOperators()[i],
297 inputTensorsVect,
298 outputTensorsVect,
299 reasonIfUnsupported))
300 {
301 return false;
302 }
303 }
304 return true;
305}
306
307} // namespace armnn