blob: 05ccef4a9ca6f8921016b197bf15409f8d018063 [file] [log] [blame]
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01001//
Tracy Narine10403ec2023-11-28 11:55:08 +00002// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +00008#include <Layer.hpp>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01009#include <armnn/Tensor.hpp>
10#include <armnn/Types.hpp>
11
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000012#include "common/include/ProfilingGuid.hpp"
13
14#include <tosa_serialization_handler.h>
Tracy Narine10403ec2023-11-28 11:55:08 +000015#include <version.h>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010016
17using namespace armnn;
18using namespace tosa;
19
// Shared "main" name string, available to every file that includes this header.
// NOTE(review): presumably the name given to the top-level TOSA block/region - confirm at the call sites.
const std::string mainName{"main"};
21
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010022// Function to return Tosa datatype from input ArmNN datatype.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010023inline DType ArmNNToDType(const DataType& type)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010024{
25 switch (type)
26 {
27 case DataType::Float16:
Matthew Sloyanda824cc2022-10-10 12:43:20 +010028 return DType_FP16;
Narumol Prangnawaratad323af2023-09-29 17:00:38 +010029 case DataType::BFloat16:
30 return DType_BF16;
Matthew Sloyanda824cc2022-10-10 12:43:20 +010031 case DataType::Float32:
32 return DType_FP32;
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010033 case DataType::QAsymmU8:
34 return DType_UINT8;
35 case DataType::QSymmS8:
36 case DataType::QAsymmS8:
37 return DType_INT8;
38 case DataType::QSymmS16:
39 return DType_INT16;
40 case DataType::Signed32:
41 return DType_INT32;
42 case DataType::Signed64:
43 // No signed 64, only DType_INT48.
44 return DType_UNKNOWN;
45 case DataType::Boolean:
46 return DType_BOOL;
47 default:
48 return DType_UNKNOWN;
49 }
50}
51
52// Function to return Tosa tensor shape from input ArmNN tensor shape.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010053inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010054{
55 std::vector<int32_t> returnShape;
56 for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
57 {
58 returnShape.push_back(static_cast<int32_t>(shape[i]));
59 }
60 return returnShape;
61}
62
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000063// Function that generates unique name using the layer type, input slot and layer guid.
64inline std::string GenerateUniqueName(const Layer& layer, uint32_t layerSlot)
65{
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000066 std::string guid = std::to_string(layer.GetGuid());
67 std::string slotAndGuid = std::to_string(layerSlot) + "_" + guid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000068
Cathal Corbettb30e6552022-12-07 11:50:50 +000069 switch (layer.GetType())
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000070 {
Cathal Corbettb30e6552022-12-07 11:50:50 +000071 case LayerType::Input:
72 return "input" + slotAndGuid;
73 case LayerType::Output:
74 return "output" + slotAndGuid;
75 case LayerType::Constant:
76 return "constant_" + guid;
77 default:
78 return "intermediate" + slotAndGuid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000079 }
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000080}
81
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +000082// Function that generates unique output name using the layer type, input slot and layer guid.
83inline std::string GenerateUniqueOutputName(const Layer& layer, uint32_t layerSlot)
84{
85 Layer& connectedLayer = layer.GetOutputSlot().GetConnection(0)->GetOwningLayer();
86
87 // Get the layer connected to the output slot, if output use that layer and id,
88 // otherwise use current layer and id.
89 if(connectedLayer.GetType() == LayerType::Output)
90 {
91 return GenerateUniqueName(connectedLayer, layerSlot);
92 }
93 else
94 {
95 return GenerateUniqueName(layer, layerSlot);
96 }
97}
98
// Counter behind GetUniqueTosaMappingID().
// Declared 'inline' (C++17) rather than 'static' so that every translation unit including this header
// shares one counter. A 'static' variable would give each translation unit a private copy, so the same
// id could be handed out more than once, and the inline function below would refer to a different
// object in each translation unit.
// NOTE(review): increments are not synchronised - confirm callers use this from a single thread.
inline int uniqueTosaMappingID = 0;

// Function to return unique int as a string to ensure uniqueness between all input, output and block names.
// Each call increments the counter first, so the first value returned is "1".
inline std::string GetUniqueTosaMappingID()
{
    return std::to_string(++uniqueTosaMappingID);
}
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000105
Cathal Corbettb30e6552022-12-07 11:50:50 +0000106// Function to return Tosa DType as string.
107inline std::string TosaDTypeToString(DType tosaDType)
108{
109 switch (tosaDType)
110 {
111 case DType_UNKNOWN:
112 return "DType_UNKNOWN";
113 case DType_BOOL:
114 return "DType_BOOL";
115 case DType_UINT8:
116 return "DType_UINT8";
117 case DType_INT4:
118 return "DType_INT4";
119 case DType_INT8:
120 return "DType_INT8";
121 case DType_INT16:
122 return "DType_INT16";
123 case DType_INT32:
124 return "DType_INT32";
125 case DType_INT48:
126 return "DType_INT48";
127 case DType_FP32:
128 return "DType_FP32";
129 case DType_UINT16:
130 return "DType_UINT16";
131 case DType_FP16:
132 return "DType_FP16";
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100133 case DType_BF16:
134 return "DType_BF16";
Cathal Corbettb30e6552022-12-07 11:50:50 +0000135 }
136 return "";
137}
138
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000139// Function to return Tosa Op as string.
140inline std::string TosaOpToString(Op tosaOp)
141{
142 switch (tosaOp)
143 {
144 case Op_ADD:
145 return "Op_ADD";
146 case Op_AVG_POOL2D:
147 return "Op_AVG_POOL2D";
148 case Op_MAX_POOL2D:
149 return "Op_MAX_POOL2D";
150 case Op_PAD:
151 return "Op_PAD";
152 case Op_UNKNOWN:
153 return "Op_UNKNOWN";
154 case Op_ARGMAX:
155 return "Op_ARGMAX";
156 case Op_CONV2D:
157 return "Op_CONV2D";
158 case Op_CONV3D:
159 return "Op_CONV3D";
160 case Op_DEPTHWISE_CONV2D:
161 return "Op_DEPTHWISE_CONV2D";
162 case Op_FULLY_CONNECTED:
163 return "Op_FULLY_CONNECTED";
164 case Op_MATMUL:
165 return "Op_MATMUL";
166 case Op_TRANSPOSE_CONV2D:
167 return "Op_TRANSPOSE_CONV2D";
168 case Op_CLAMP:
169 return "Op_CLAMP";
170 case Op_RESERVED:
171 return "Op_RESERVED";
172 case Op_SIGMOID:
173 return "Op_SIGMOID";
174 case Op_TANH:
175 return "Op_TANH";
176 case Op_ARITHMETIC_RIGHT_SHIFT:
177 return "Op_ARITHMETIC_RIGHT_SHIFT";
178 case Op_BITWISE_AND:
179 return "Op_BITWISE_AND";
180 case Op_BITWISE_OR:
181 return "Op_BITWISE_OR";
182 case Op_BITWISE_XOR:
183 return "Op_BITWISE_XOR";
184 case Op_INTDIV:
185 return "Op_INTDIV";
186 case Op_LOGICAL_AND:
187 return "Op_LOGICAL_AND";
188 case Op_LOGICAL_LEFT_SHIFT:
189 return "Op_LOGICAL_LEFT_SHIFT";
190 case Op_LOGICAL_RIGHT_SHIFT:
191 return "Op_LOGICAL_RIGHT_SHIFT";
192 case Op_LOGICAL_OR:
193 return "Op_LOGICAL_OR";
194 case Op_LOGICAL_XOR:
195 return "Op_LOGICAL_XOR";
196 case Op_MAXIMUM:
197 return "Op_MAXIMUM";
198 case Op_MINIMUM:
199 return "Op_MINIMUM";
200 case Op_MUL:
201 return "Op_MUL";
202 case Op_POW:
203 return "Op_POW";
204 case Op_SUB:
205 return "Op_SUB";
206 case Op_TABLE:
207 return "Op_TABLE";
208 case Op_ABS:
209 return "Op_ABS";
210 case Op_BITWISE_NOT:
211 return "Op_BITWISE_NOT";
212 case Op_CEIL:
213 return "Op_CEIL";
214 case Op_CLZ:
215 return "Op_CLZ";
216 case Op_EXP:
217 return "Op_EXP";
218 case Op_FLOOR:
219 return "Op_FLOOR";
220 case Op_LOG:
221 return "Op_LOG";
222 case Op_LOGICAL_NOT:
223 return "Op_LOGICAL_NOT";
224 case Op_NEGATE:
225 return "Op_NEGATE";
226 case Op_RECIPROCAL:
227 return "Op_RECIPROCAL";
228 case Op_RSQRT:
229 return "Op_RSQRT";
230 case Op_SELECT:
231 return "Op_SELECT";
232 case Op_EQUAL:
233 return "Op_EQUAL";
234 case Op_GREATER:
235 return "Op_GREATER";
236 case Op_GREATER_EQUAL:
237 return "Op_GREATER_EQUAL";
238 case Op_REDUCE_ANY:
239 return "Op_REDUCE_ANY";
240 case Op_REDUCE_ALL:
241 return "Op_REDUCE_ALL";
242 case Op_REDUCE_MAX:
243 return "Op_REDUCE_MAX";
244 case Op_REDUCE_MIN:
245 return "Op_REDUCE_MIN";
246 case Op_REDUCE_PRODUCT:
247 return "Op_REDUCE_PRODUCT";
248 case Op_REDUCE_SUM:
249 return "Op_REDUCE_SUM";
250 case Op_CONCAT:
251 return "Op_CONCAT";
252 case Op_RESHAPE:
253 return "Op_RESHAPE";
254 case Op_REVERSE:
255 return "Op_REVERSE";
256 case Op_SLICE:
257 return "Op_SLICE";
258 case Op_TILE:
259 return "Op_TILE";
260 case Op_TRANSPOSE:
261 return "Op_TRANSPOSE";
262 case Op_GATHER:
263 return "Op_GATHER";
264 case Op_SCATTER:
265 return "Op_SCATTER";
266 case Op_RESIZE:
267 return "Op_RESIZE";
268 case Op_CAST:
269 return "Op_CAST";
270 case Op_RESCALE:
271 return "Op_RESCALE";
272 case Op_CONST:
273 return "Op_CONST";
274 case Op_IDENTITY:
275 return "Op_IDENTITY";
276 case Op_CUSTOM:
277 return "Op_CUSTOM";
278 case Op_COND_IF:
279 return "Op_COND_IF";
280 case Op_WHILE_LOOP:
281 return "Op_WHILE_LOOP";
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100282 case Op_FFT2D:
283 return "Op_FFT2D";
284 case Op_RFFT2D:
285 return "Op_RFFT2D";
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000286 }
287 return "";
288}
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000289
290inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle>& tensorHandle)
291{
Teresa Charlin3fbad942022-12-15 10:35:37 +0000292 tosa_err_t error = tosa_err_t::TOSA_OK;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000293 std::vector<uint8_t> uint8Data;
294 auto tensorInfo = tensorHandle->GetTensorInfo();
295
296 switch (tensorInfo.GetDataType())
297 {
298 case DataType::Float32:
299 {
300 std::vector<float> data(tensorInfo.GetNumElements());
301 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
302
303 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
304 break;
305 }
306 case DataType::Float16:
307 {
308 std::vector<float> data(tensorInfo.GetNumElements());
309 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
310
311 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
312 break;
313 }
314 case DataType::QSymmS8:
315 case DataType::QAsymmS8:
316 {
317 std::vector<int8_t> data(tensorInfo.GetNumElements());
318 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
319
320 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
321 break;
322 }
323 case DataType::QAsymmU8:
324 {
325 memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
326 break;
327 }
328 case DataType::QSymmS16:
329 {
330 std::vector<int16_t> data(tensorInfo.GetNumElements());
331 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
332
333 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
334 break;
335 }
336 case DataType::Signed32:
337 {
338 std::vector<int32_t> data(tensorInfo.GetNumElements());
339 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
340
341 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
342 break;
343 }
344 default:
345 {
346 throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
347 }
348 }
349
350 if(error != tosa_err_t::TOSA_OK)
351 {
352 throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
353 }
354
355 tensorHandle->Unmap();
356 return uint8Data;
357}
Teresa Charlinca5a23a2023-12-15 14:20:47 +0000358
Teresa Charlinca5a23a2023-12-15 14:20:47 +0000359inline std::vector<uint8_t> CreateConstTosaData(const void* value,
360 DType dtype,
361 const std::vector<int32_t>& shape)
362{
363 std::vector<uint8_t> uint8Data;
364 tosa_err_t error = tosa_err_t::TOSA_OK;
365
366 unsigned int numElements = 1;
367 for (auto s : shape)
368 {
369 if (s < 0)
370 {
371 throw armnn::Exception("CreateConstTosaData: negative shape elements unhandled.");
372 }
373 numElements = numElements * static_cast<unsigned int>(s);
374 }
375
376 switch (dtype)
377 {
378 case DType::DType_FP32:
379 {
380 std::vector<float> data(numElements, *static_cast<const float*>(value));
381 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
382 break;
383 }
384 case DType::DType_FP16:
385 {
386 std::vector<float> data(numElements, *static_cast<const float*>(value));
387 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
388 break;
389 }
390 case DType::DType_INT48:
391 {
392 std::vector<int64_t> data(numElements, *static_cast<const int64_t*>(value));
393 error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data);
394 break;
395 }
396 case DType::DType_INT32:
397 {
398 std::vector<int32_t> data(numElements, *static_cast<const int32_t*>(value));
399 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
400 break;
401 }
402 case DType::DType_INT16:
403 {
404 std::vector<int16_t> data(numElements, *static_cast<const int16_t*>(value));
405 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
406 break;
407 }
408 case DType::DType_INT8:
409 {
410 std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
411 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
412 break;
413 }
414 case DType::DType_INT4:
415 {
416 std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
417 error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data);
418 break;
419 }
420 case DType::DType_BOOL:
421 {
422 std::vector<bool> data(numElements, *static_cast<const bool*>(value));
423 error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data);
424 break;
425 }
426 default:
427 {
428 throw armnn::Exception("CreateConstTosaData: An unsupported data type was encountered.");
429 }
430 }
431
432 if(error != tosa_err_t::TOSA_OK)
433 {
434 throw armnn::Exception("CreateConstTosaData: An error occurred when converting constant data");
435 }
436
437 return uint8Data;
438}
439
440template<typename T>
441inline void CreateConstTosaOperator(const std::string& outputName,
442 const T value,
443 DType dtype,
444 const std::vector<int32_t>& shape,
445 TosaSerializationOperator*& op,
446 TosaSerializationTensor*& tensor)
447{
448 std::vector<uint8_t> uint8Data = CreateConstTosaData(static_cast<const void *>(&value), dtype, shape);
449
450 op = new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {outputName});
451 ARMNN_THROW_MSG_IF_FALSE(op, armnn::Exception, "CreateConstTosaOperator: failed to created operator");
452
453 tensor = new TosaSerializationTensor(outputName, shape, dtype, uint8Data);
454 ARMNN_THROW_MSG_IF_FALSE(tensor, armnn::Exception, "CreateConstTosaOperator: failed to created tensor");
Tracy Narine10403ec2023-11-28 11:55:08 +0000455}
456
// Macro to conditionally compile Tosa code.
// Evaluates to true when the version given by TOSA_REFERENCE_MODEL_VERSION_MAJOR/MINOR/PATCH is at least
// _major._minor._patch. The three components are compared in order of significance (major first,
// then minor, then patch), so for example 1.0.0 satisfies a requirement of 0.80.0.
// The expansion and every argument are parenthesised so the macro can be negated or combined with
// other conditions safely.
#define TOSA_FWD_COMPAT_VERSION(_major, _minor, _patch)               \
    ((TOSA_REFERENCE_MODEL_VERSION_MAJOR > (_major)) ||               \
     ((TOSA_REFERENCE_MODEL_VERSION_MAJOR == (_major)) &&             \
      ((TOSA_REFERENCE_MODEL_VERSION_MINOR > (_minor)) ||             \
       ((TOSA_REFERENCE_MODEL_VERSION_MINOR == (_minor)) &&           \
        (TOSA_REFERENCE_MODEL_VERSION_PATCH >= (_patch))))))