blob: b7f14bf5b7253fa186848e1796bc66540146c228 [file] [log] [blame]
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01001//
Tracy Narine10403ec2023-11-28 11:55:08 +00002// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <Layer.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include "common/include/ProfilingGuid.hpp"

#include <tosa_serialization_handler.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010015
// NOTE(review): this file is a header (#pragma once), so these using-directives also apply to every
// file that includes it. The unqualified names used below (DataType, DType, Op, Layer, ...) rely on them.
using namespace armnn;
using namespace tosa;
18
// Shared "main" name constant.
// NOTE(review): presumably the name given to the top-level TOSA region/block - confirm against its users.
const std::string mainName("main");
20
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010021// Function to return Tosa datatype from input ArmNN datatype.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010022inline DType ArmNNToDType(const DataType& type)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010023{
24 switch (type)
25 {
26 case DataType::Float16:
Matthew Sloyanda824cc2022-10-10 12:43:20 +010027 return DType_FP16;
Narumol Prangnawaratad323af2023-09-29 17:00:38 +010028 case DataType::BFloat16:
29 return DType_BF16;
Matthew Sloyanda824cc2022-10-10 12:43:20 +010030 case DataType::Float32:
31 return DType_FP32;
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010032 case DataType::QAsymmU8:
33 return DType_UINT8;
34 case DataType::QSymmS8:
35 case DataType::QAsymmS8:
36 return DType_INT8;
37 case DataType::QSymmS16:
38 return DType_INT16;
39 case DataType::Signed32:
40 return DType_INT32;
41 case DataType::Signed64:
42 // No signed 64, only DType_INT48.
43 return DType_UNKNOWN;
44 case DataType::Boolean:
45 return DType_BOOL;
46 default:
47 return DType_UNKNOWN;
48 }
49}
50
John Mcloughlinceb44282024-04-23 16:47:04 +010051// Function to return ArmNN datatype from input Tosa datatype.
52inline DataType DtypeToArmNN(const DType type)
53{
54 switch (type)
55 {
56 case DType_FP16:
57 return DataType::Float16;
58 case DType_BF16:
59 return DataType::BFloat16;
60 case DType_FP32:
61 return DataType::Float32;
62 case DType_UINT8:
63 return DataType::QAsymmU8;
64 case DType_INT8:
65 return DataType::QSymmS8;
66 case DType_INT16:
67 return DataType::QSymmS16;
68 case DType_INT32:
69 return DataType::Signed32;
70 case DType_BOOL:
71 return DataType::Boolean;
72 default:
73 throw armnn::Exception("DtypeToArmNN: Unsupported tosa::DType in ArmNN.");
74 return DataType::Boolean;
75 }
76}
77
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010078// Function to return Tosa tensor shape from input ArmNN tensor shape.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010079inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010080{
81 std::vector<int32_t> returnShape;
82 for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
83 {
84 returnShape.push_back(static_cast<int32_t>(shape[i]));
85 }
86 return returnShape;
87}
88
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000089// Function that generates unique name using the layer type, input slot and layer guid.
90inline std::string GenerateUniqueName(const Layer& layer, uint32_t layerSlot)
91{
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000092 std::string guid = std::to_string(layer.GetGuid());
93 std::string slotAndGuid = std::to_string(layerSlot) + "_" + guid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000094
Cathal Corbettb30e6552022-12-07 11:50:50 +000095 switch (layer.GetType())
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000096 {
Cathal Corbettb30e6552022-12-07 11:50:50 +000097 case LayerType::Input:
98 return "input" + slotAndGuid;
99 case LayerType::Output:
100 return "output" + slotAndGuid;
101 case LayerType::Constant:
102 return "constant_" + guid;
103 default:
104 return "intermediate" + slotAndGuid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000105 }
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000106}
107
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +0000108// Function that generates unique output name using the layer type, input slot and layer guid.
109inline std::string GenerateUniqueOutputName(const Layer& layer, uint32_t layerSlot)
110{
111 Layer& connectedLayer = layer.GetOutputSlot().GetConnection(0)->GetOwningLayer();
112
113 // Get the layer connected to the output slot, if output use that layer and id,
114 // otherwise use current layer and id.
115 if(connectedLayer.GetType() == LayerType::Output)
116 {
117 return GenerateUniqueName(connectedLayer, layerSlot);
118 }
119 else
120 {
121 return GenerateUniqueName(layer, layerSlot);
122 }
123}
124
// Function to return unique int as a string to ensure uniqueness between all input, output and block names.
//
// The counter is a function-local static of this inline function, so there is exactly one counter for the
// whole program. A namespace-scope `static` in a header would instead give every translation unit its own
// counter, letting the same ID be handed out more than once.
// Returns: "1" on the first call, then "2", "3", ... The increment is not synchronised.
inline std::string GetUniqueTosaMappingID()
{
    static int uniqueTosaMappingID = 0;
    return std::to_string(++uniqueTosaMappingID);
}
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000131
Cathal Corbettb30e6552022-12-07 11:50:50 +0000132// Function to return Tosa DType as string.
133inline std::string TosaDTypeToString(DType tosaDType)
134{
135 switch (tosaDType)
136 {
137 case DType_UNKNOWN:
138 return "DType_UNKNOWN";
139 case DType_BOOL:
140 return "DType_BOOL";
141 case DType_UINT8:
142 return "DType_UINT8";
143 case DType_INT4:
144 return "DType_INT4";
145 case DType_INT8:
146 return "DType_INT8";
147 case DType_INT16:
148 return "DType_INT16";
149 case DType_INT32:
150 return "DType_INT32";
151 case DType_INT48:
152 return "DType_INT48";
153 case DType_FP32:
154 return "DType_FP32";
155 case DType_UINT16:
156 return "DType_UINT16";
157 case DType_FP16:
158 return "DType_FP16";
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100159 case DType_BF16:
160 return "DType_BF16";
Teresa Charlin571a4f72024-03-26 11:18:42 +0000161 case DType_SHAPE:
162 return "DType_SHAPE";
Cathal Corbettb30e6552022-12-07 11:50:50 +0000163 }
164 return "";
165}
166
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000167// Function to return Tosa Op as string.
168inline std::string TosaOpToString(Op tosaOp)
169{
170 switch (tosaOp)
171 {
172 case Op_ADD:
173 return "Op_ADD";
174 case Op_AVG_POOL2D:
175 return "Op_AVG_POOL2D";
176 case Op_MAX_POOL2D:
177 return "Op_MAX_POOL2D";
178 case Op_PAD:
179 return "Op_PAD";
180 case Op_UNKNOWN:
181 return "Op_UNKNOWN";
182 case Op_ARGMAX:
183 return "Op_ARGMAX";
184 case Op_CONV2D:
185 return "Op_CONV2D";
186 case Op_CONV3D:
187 return "Op_CONV3D";
188 case Op_DEPTHWISE_CONV2D:
189 return "Op_DEPTHWISE_CONV2D";
190 case Op_FULLY_CONNECTED:
191 return "Op_FULLY_CONNECTED";
192 case Op_MATMUL:
193 return "Op_MATMUL";
194 case Op_TRANSPOSE_CONV2D:
195 return "Op_TRANSPOSE_CONV2D";
196 case Op_CLAMP:
197 return "Op_CLAMP";
198 case Op_RESERVED:
199 return "Op_RESERVED";
200 case Op_SIGMOID:
201 return "Op_SIGMOID";
202 case Op_TANH:
203 return "Op_TANH";
204 case Op_ARITHMETIC_RIGHT_SHIFT:
205 return "Op_ARITHMETIC_RIGHT_SHIFT";
206 case Op_BITWISE_AND:
207 return "Op_BITWISE_AND";
208 case Op_BITWISE_OR:
209 return "Op_BITWISE_OR";
210 case Op_BITWISE_XOR:
211 return "Op_BITWISE_XOR";
212 case Op_INTDIV:
213 return "Op_INTDIV";
214 case Op_LOGICAL_AND:
215 return "Op_LOGICAL_AND";
216 case Op_LOGICAL_LEFT_SHIFT:
217 return "Op_LOGICAL_LEFT_SHIFT";
218 case Op_LOGICAL_RIGHT_SHIFT:
219 return "Op_LOGICAL_RIGHT_SHIFT";
220 case Op_LOGICAL_OR:
221 return "Op_LOGICAL_OR";
222 case Op_LOGICAL_XOR:
223 return "Op_LOGICAL_XOR";
224 case Op_MAXIMUM:
225 return "Op_MAXIMUM";
226 case Op_MINIMUM:
227 return "Op_MINIMUM";
228 case Op_MUL:
229 return "Op_MUL";
230 case Op_POW:
231 return "Op_POW";
232 case Op_SUB:
233 return "Op_SUB";
234 case Op_TABLE:
235 return "Op_TABLE";
236 case Op_ABS:
237 return "Op_ABS";
238 case Op_BITWISE_NOT:
239 return "Op_BITWISE_NOT";
240 case Op_CEIL:
241 return "Op_CEIL";
242 case Op_CLZ:
243 return "Op_CLZ";
244 case Op_EXP:
245 return "Op_EXP";
246 case Op_FLOOR:
247 return "Op_FLOOR";
248 case Op_LOG:
249 return "Op_LOG";
250 case Op_LOGICAL_NOT:
251 return "Op_LOGICAL_NOT";
252 case Op_NEGATE:
253 return "Op_NEGATE";
254 case Op_RECIPROCAL:
255 return "Op_RECIPROCAL";
256 case Op_RSQRT:
257 return "Op_RSQRT";
258 case Op_SELECT:
259 return "Op_SELECT";
260 case Op_EQUAL:
261 return "Op_EQUAL";
262 case Op_GREATER:
263 return "Op_GREATER";
264 case Op_GREATER_EQUAL:
265 return "Op_GREATER_EQUAL";
266 case Op_REDUCE_ANY:
267 return "Op_REDUCE_ANY";
268 case Op_REDUCE_ALL:
269 return "Op_REDUCE_ALL";
270 case Op_REDUCE_MAX:
271 return "Op_REDUCE_MAX";
272 case Op_REDUCE_MIN:
273 return "Op_REDUCE_MIN";
274 case Op_REDUCE_PRODUCT:
275 return "Op_REDUCE_PRODUCT";
276 case Op_REDUCE_SUM:
277 return "Op_REDUCE_SUM";
278 case Op_CONCAT:
279 return "Op_CONCAT";
280 case Op_RESHAPE:
281 return "Op_RESHAPE";
282 case Op_REVERSE:
283 return "Op_REVERSE";
284 case Op_SLICE:
285 return "Op_SLICE";
286 case Op_TILE:
287 return "Op_TILE";
288 case Op_TRANSPOSE:
289 return "Op_TRANSPOSE";
290 case Op_GATHER:
291 return "Op_GATHER";
292 case Op_SCATTER:
293 return "Op_SCATTER";
294 case Op_RESIZE:
295 return "Op_RESIZE";
296 case Op_CAST:
297 return "Op_CAST";
298 case Op_RESCALE:
299 return "Op_RESCALE";
300 case Op_CONST:
301 return "Op_CONST";
302 case Op_IDENTITY:
303 return "Op_IDENTITY";
304 case Op_CUSTOM:
305 return "Op_CUSTOM";
306 case Op_COND_IF:
307 return "Op_COND_IF";
308 case Op_WHILE_LOOP:
309 return "Op_WHILE_LOOP";
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100310 case Op_FFT2D:
311 return "Op_FFT2D";
312 case Op_RFFT2D:
313 return "Op_RFFT2D";
Teresa Charlin571a4f72024-03-26 11:18:42 +0000314 case Op_ERF:
315 return "Op_ERF";
316 case Op_DIM: // = Op_MAX
317 return "Op_DIM";
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000318 }
319 return "";
320}
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000321
322inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle>& tensorHandle)
323{
Teresa Charlin3fbad942022-12-15 10:35:37 +0000324 tosa_err_t error = tosa_err_t::TOSA_OK;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000325 std::vector<uint8_t> uint8Data;
326 auto tensorInfo = tensorHandle->GetTensorInfo();
327
328 switch (tensorInfo.GetDataType())
329 {
330 case DataType::Float32:
331 {
332 std::vector<float> data(tensorInfo.GetNumElements());
333 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
334
335 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
336 break;
337 }
338 case DataType::Float16:
339 {
340 std::vector<float> data(tensorInfo.GetNumElements());
341 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
342
343 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
344 break;
345 }
346 case DataType::QSymmS8:
347 case DataType::QAsymmS8:
348 {
349 std::vector<int8_t> data(tensorInfo.GetNumElements());
350 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
351
352 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
353 break;
354 }
355 case DataType::QAsymmU8:
356 {
357 memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
358 break;
359 }
360 case DataType::QSymmS16:
361 {
362 std::vector<int16_t> data(tensorInfo.GetNumElements());
363 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
364
365 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
366 break;
367 }
368 case DataType::Signed32:
369 {
370 std::vector<int32_t> data(tensorInfo.GetNumElements());
371 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
372
373 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
374 break;
375 }
376 default:
377 {
378 throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
379 }
380 }
381
382 if(error != tosa_err_t::TOSA_OK)
383 {
384 throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
385 }
386
387 tensorHandle->Unmap();
388 return uint8Data;
389}
Teresa Charlinca5a23a2023-12-15 14:20:47 +0000390
Teresa Charlinca5a23a2023-12-15 14:20:47 +0000391inline std::vector<uint8_t> CreateConstTosaData(const void* value,
392 DType dtype,
393 const std::vector<int32_t>& shape)
394{
395 std::vector<uint8_t> uint8Data;
396 tosa_err_t error = tosa_err_t::TOSA_OK;
397
398 unsigned int numElements = 1;
399 for (auto s : shape)
400 {
401 if (s < 0)
402 {
403 throw armnn::Exception("CreateConstTosaData: negative shape elements unhandled.");
404 }
405 numElements = numElements * static_cast<unsigned int>(s);
406 }
407
408 switch (dtype)
409 {
410 case DType::DType_FP32:
411 {
412 std::vector<float> data(numElements, *static_cast<const float*>(value));
413 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
414 break;
415 }
416 case DType::DType_FP16:
417 {
418 std::vector<float> data(numElements, *static_cast<const float*>(value));
419 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
420 break;
421 }
422 case DType::DType_INT48:
423 {
424 std::vector<int64_t> data(numElements, *static_cast<const int64_t*>(value));
425 error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data);
426 break;
427 }
428 case DType::DType_INT32:
429 {
430 std::vector<int32_t> data(numElements, *static_cast<const int32_t*>(value));
431 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
432 break;
433 }
434 case DType::DType_INT16:
435 {
436 std::vector<int16_t> data(numElements, *static_cast<const int16_t*>(value));
437 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
438 break;
439 }
440 case DType::DType_INT8:
441 {
442 std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
443 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
444 break;
445 }
446 case DType::DType_INT4:
447 {
448 std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
449 error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data);
450 break;
451 }
452 case DType::DType_BOOL:
453 {
454 std::vector<bool> data(numElements, *static_cast<const bool*>(value));
455 error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data);
456 break;
457 }
458 default:
459 {
460 throw armnn::Exception("CreateConstTosaData: An unsupported data type was encountered.");
461 }
462 }
463
464 if(error != tosa_err_t::TOSA_OK)
465 {
466 throw armnn::Exception("CreateConstTosaData: An error occurred when converting constant data");
467 }
468
469 return uint8Data;
470}
471
472template<typename T>
473inline void CreateConstTosaOperator(const std::string& outputName,
474 const T value,
475 DType dtype,
476 const std::vector<int32_t>& shape,
477 TosaSerializationOperator*& op,
478 TosaSerializationTensor*& tensor)
479{
480 std::vector<uint8_t> uint8Data = CreateConstTosaData(static_cast<const void *>(&value), dtype, shape);
481
482 op = new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {outputName});
483 ARMNN_THROW_MSG_IF_FALSE(op, armnn::Exception, "CreateConstTosaOperator: failed to created operator");
484
485 tensor = new TosaSerializationTensor(outputName, shape, dtype, uint8Data);
486 ARMNN_THROW_MSG_IF_FALSE(tensor, armnn::Exception, "CreateConstTosaOperator: failed to created tensor");
Tracy Narine10403ec2023-11-28 11:55:08 +0000487}
488
// Macro to preserve usage of a code block as the TOSA library version advances. Parameters
// specify the minimum version required by the code block.
//
// Evaluates to true when TOSA_VERSION_MAJOR.TOSA_VERSION_MINOR.TOSA_VERSION_PATCH is greater than or
// equal to _major._minor._patch. The components are compared in order (major, then minor, then patch);
// a lower-order component only matters when all higher-order ones are equal. The arguments and the whole
// expansion are parenthesised so the macro can be negated or combined with other conditions safely.
#define TOSA_COMPAT_VERSION(_major, _minor, _patch)                              \
    ((TOSA_VERSION_MAJOR > (_major)) ||                                          \
     ((TOSA_VERSION_MAJOR == (_major)) &&                                        \
      ((TOSA_VERSION_MINOR > (_minor)) ||                                        \
       ((TOSA_VERSION_MINOR == (_minor)) && (TOSA_VERSION_PATCH >= (_patch))))))