blob: 047e0a1f42d243e208838c17329ae81e4929532f [file] [log] [blame]
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01001//
Tracy Narine10403ec2023-11-28 11:55:08 +00002// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +00008#include <Layer.hpp>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01009#include <armnn/Tensor.hpp>
10#include <armnn/Types.hpp>
11
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000012#include "common/include/ProfilingGuid.hpp"
13
14#include <tosa_serialization_handler.h>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010015
16using namespace armnn;
17using namespace tosa;
18
// The string "main", kept as a named constant so it is spelled in one place only.
const std::string mainName{"main"};
20
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010021// Function to return Tosa datatype from input ArmNN datatype.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010022inline DType ArmNNToDType(const DataType& type)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010023{
24 switch (type)
25 {
26 case DataType::Float16:
Matthew Sloyanda824cc2022-10-10 12:43:20 +010027 return DType_FP16;
Narumol Prangnawaratad323af2023-09-29 17:00:38 +010028 case DataType::BFloat16:
29 return DType_BF16;
Matthew Sloyanda824cc2022-10-10 12:43:20 +010030 case DataType::Float32:
31 return DType_FP32;
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010032 case DataType::QAsymmU8:
33 return DType_UINT8;
34 case DataType::QSymmS8:
35 case DataType::QAsymmS8:
36 return DType_INT8;
37 case DataType::QSymmS16:
38 return DType_INT16;
39 case DataType::Signed32:
40 return DType_INT32;
41 case DataType::Signed64:
42 // No signed 64, only DType_INT48.
43 return DType_UNKNOWN;
44 case DataType::Boolean:
45 return DType_BOOL;
46 default:
47 return DType_UNKNOWN;
48 }
49}
50
51// Function to return Tosa tensor shape from input ArmNN tensor shape.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010052inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010053{
54 std::vector<int32_t> returnShape;
55 for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
56 {
57 returnShape.push_back(static_cast<int32_t>(shape[i]));
58 }
59 return returnShape;
60}
61
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000062// Function that generates unique name using the layer type, input slot and layer guid.
63inline std::string GenerateUniqueName(const Layer& layer, uint32_t layerSlot)
64{
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000065 std::string guid = std::to_string(layer.GetGuid());
66 std::string slotAndGuid = std::to_string(layerSlot) + "_" + guid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000067
Cathal Corbettb30e6552022-12-07 11:50:50 +000068 switch (layer.GetType())
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000069 {
Cathal Corbettb30e6552022-12-07 11:50:50 +000070 case LayerType::Input:
71 return "input" + slotAndGuid;
72 case LayerType::Output:
73 return "output" + slotAndGuid;
74 case LayerType::Constant:
75 return "constant_" + guid;
76 default:
77 return "intermediate" + slotAndGuid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000078 }
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000079}
80
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +000081// Function that generates unique output name using the layer type, input slot and layer guid.
82inline std::string GenerateUniqueOutputName(const Layer& layer, uint32_t layerSlot)
83{
84 Layer& connectedLayer = layer.GetOutputSlot().GetConnection(0)->GetOwningLayer();
85
86 // Get the layer connected to the output slot, if output use that layer and id,
87 // otherwise use current layer and id.
88 if(connectedLayer.GetType() == LayerType::Output)
89 {
90 return GenerateUniqueName(connectedLayer, layerSlot);
91 }
92 else
93 {
94 return GenerateUniqueName(layer, layerSlot);
95 }
96}
97
// Counter backing GetUniqueTosaMappingID().
// NOTE: being `static` at namespace scope in a header, every translation unit that
// includes this file gets its own copy of the counter.
static int uniqueTosaMappingID = 0;

// Advances the counter and returns its new value as a decimal string ("1", "2", ...),
// for use in keeping input, output and block names distinct.
inline std::string GetUniqueTosaMappingID()
{
    uniqueTosaMappingID += 1;
    return std::to_string(uniqueTosaMappingID);
}
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000104
Cathal Corbettb30e6552022-12-07 11:50:50 +0000105// Function to return Tosa DType as string.
106inline std::string TosaDTypeToString(DType tosaDType)
107{
108 switch (tosaDType)
109 {
110 case DType_UNKNOWN:
111 return "DType_UNKNOWN";
112 case DType_BOOL:
113 return "DType_BOOL";
114 case DType_UINT8:
115 return "DType_UINT8";
116 case DType_INT4:
117 return "DType_INT4";
118 case DType_INT8:
119 return "DType_INT8";
120 case DType_INT16:
121 return "DType_INT16";
122 case DType_INT32:
123 return "DType_INT32";
124 case DType_INT48:
125 return "DType_INT48";
126 case DType_FP32:
127 return "DType_FP32";
128 case DType_UINT16:
129 return "DType_UINT16";
130 case DType_FP16:
131 return "DType_FP16";
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100132 case DType_BF16:
133 return "DType_BF16";
Teresa Charlin571a4f72024-03-26 11:18:42 +0000134 case DType_SHAPE:
135 return "DType_SHAPE";
Cathal Corbettb30e6552022-12-07 11:50:50 +0000136 }
137 return "";
138}
139
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000140// Function to return Tosa Op as string.
141inline std::string TosaOpToString(Op tosaOp)
142{
143 switch (tosaOp)
144 {
145 case Op_ADD:
146 return "Op_ADD";
147 case Op_AVG_POOL2D:
148 return "Op_AVG_POOL2D";
149 case Op_MAX_POOL2D:
150 return "Op_MAX_POOL2D";
151 case Op_PAD:
152 return "Op_PAD";
153 case Op_UNKNOWN:
154 return "Op_UNKNOWN";
155 case Op_ARGMAX:
156 return "Op_ARGMAX";
157 case Op_CONV2D:
158 return "Op_CONV2D";
159 case Op_CONV3D:
160 return "Op_CONV3D";
161 case Op_DEPTHWISE_CONV2D:
162 return "Op_DEPTHWISE_CONV2D";
163 case Op_FULLY_CONNECTED:
164 return "Op_FULLY_CONNECTED";
165 case Op_MATMUL:
166 return "Op_MATMUL";
167 case Op_TRANSPOSE_CONV2D:
168 return "Op_TRANSPOSE_CONV2D";
169 case Op_CLAMP:
170 return "Op_CLAMP";
171 case Op_RESERVED:
172 return "Op_RESERVED";
173 case Op_SIGMOID:
174 return "Op_SIGMOID";
175 case Op_TANH:
176 return "Op_TANH";
177 case Op_ARITHMETIC_RIGHT_SHIFT:
178 return "Op_ARITHMETIC_RIGHT_SHIFT";
179 case Op_BITWISE_AND:
180 return "Op_BITWISE_AND";
181 case Op_BITWISE_OR:
182 return "Op_BITWISE_OR";
183 case Op_BITWISE_XOR:
184 return "Op_BITWISE_XOR";
185 case Op_INTDIV:
186 return "Op_INTDIV";
187 case Op_LOGICAL_AND:
188 return "Op_LOGICAL_AND";
189 case Op_LOGICAL_LEFT_SHIFT:
190 return "Op_LOGICAL_LEFT_SHIFT";
191 case Op_LOGICAL_RIGHT_SHIFT:
192 return "Op_LOGICAL_RIGHT_SHIFT";
193 case Op_LOGICAL_OR:
194 return "Op_LOGICAL_OR";
195 case Op_LOGICAL_XOR:
196 return "Op_LOGICAL_XOR";
197 case Op_MAXIMUM:
198 return "Op_MAXIMUM";
199 case Op_MINIMUM:
200 return "Op_MINIMUM";
201 case Op_MUL:
202 return "Op_MUL";
203 case Op_POW:
204 return "Op_POW";
205 case Op_SUB:
206 return "Op_SUB";
207 case Op_TABLE:
208 return "Op_TABLE";
209 case Op_ABS:
210 return "Op_ABS";
211 case Op_BITWISE_NOT:
212 return "Op_BITWISE_NOT";
213 case Op_CEIL:
214 return "Op_CEIL";
215 case Op_CLZ:
216 return "Op_CLZ";
217 case Op_EXP:
218 return "Op_EXP";
219 case Op_FLOOR:
220 return "Op_FLOOR";
221 case Op_LOG:
222 return "Op_LOG";
223 case Op_LOGICAL_NOT:
224 return "Op_LOGICAL_NOT";
225 case Op_NEGATE:
226 return "Op_NEGATE";
227 case Op_RECIPROCAL:
228 return "Op_RECIPROCAL";
229 case Op_RSQRT:
230 return "Op_RSQRT";
231 case Op_SELECT:
232 return "Op_SELECT";
233 case Op_EQUAL:
234 return "Op_EQUAL";
235 case Op_GREATER:
236 return "Op_GREATER";
237 case Op_GREATER_EQUAL:
238 return "Op_GREATER_EQUAL";
239 case Op_REDUCE_ANY:
240 return "Op_REDUCE_ANY";
241 case Op_REDUCE_ALL:
242 return "Op_REDUCE_ALL";
243 case Op_REDUCE_MAX:
244 return "Op_REDUCE_MAX";
245 case Op_REDUCE_MIN:
246 return "Op_REDUCE_MIN";
247 case Op_REDUCE_PRODUCT:
248 return "Op_REDUCE_PRODUCT";
249 case Op_REDUCE_SUM:
250 return "Op_REDUCE_SUM";
251 case Op_CONCAT:
252 return "Op_CONCAT";
253 case Op_RESHAPE:
254 return "Op_RESHAPE";
255 case Op_REVERSE:
256 return "Op_REVERSE";
257 case Op_SLICE:
258 return "Op_SLICE";
259 case Op_TILE:
260 return "Op_TILE";
261 case Op_TRANSPOSE:
262 return "Op_TRANSPOSE";
263 case Op_GATHER:
264 return "Op_GATHER";
265 case Op_SCATTER:
266 return "Op_SCATTER";
267 case Op_RESIZE:
268 return "Op_RESIZE";
269 case Op_CAST:
270 return "Op_CAST";
271 case Op_RESCALE:
272 return "Op_RESCALE";
273 case Op_CONST:
274 return "Op_CONST";
275 case Op_IDENTITY:
276 return "Op_IDENTITY";
277 case Op_CUSTOM:
278 return "Op_CUSTOM";
279 case Op_COND_IF:
280 return "Op_COND_IF";
281 case Op_WHILE_LOOP:
282 return "Op_WHILE_LOOP";
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100283 case Op_FFT2D:
284 return "Op_FFT2D";
285 case Op_RFFT2D:
286 return "Op_RFFT2D";
Teresa Charlin571a4f72024-03-26 11:18:42 +0000287 case Op_ERF:
288 return "Op_ERF";
289 case Op_DIM: // = Op_MAX
290 return "Op_DIM";
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000291 }
292 return "";
293}
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000294
295inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle>& tensorHandle)
296{
Teresa Charlin3fbad942022-12-15 10:35:37 +0000297 tosa_err_t error = tosa_err_t::TOSA_OK;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000298 std::vector<uint8_t> uint8Data;
299 auto tensorInfo = tensorHandle->GetTensorInfo();
300
301 switch (tensorInfo.GetDataType())
302 {
303 case DataType::Float32:
304 {
305 std::vector<float> data(tensorInfo.GetNumElements());
306 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
307
308 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
309 break;
310 }
311 case DataType::Float16:
312 {
313 std::vector<float> data(tensorInfo.GetNumElements());
314 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
315
316 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
317 break;
318 }
319 case DataType::QSymmS8:
320 case DataType::QAsymmS8:
321 {
322 std::vector<int8_t> data(tensorInfo.GetNumElements());
323 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
324
325 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
326 break;
327 }
328 case DataType::QAsymmU8:
329 {
330 memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
331 break;
332 }
333 case DataType::QSymmS16:
334 {
335 std::vector<int16_t> data(tensorInfo.GetNumElements());
336 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
337
338 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
339 break;
340 }
341 case DataType::Signed32:
342 {
343 std::vector<int32_t> data(tensorInfo.GetNumElements());
344 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
345
346 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
347 break;
348 }
349 default:
350 {
351 throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
352 }
353 }
354
355 if(error != tosa_err_t::TOSA_OK)
356 {
357 throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
358 }
359
360 tensorHandle->Unmap();
361 return uint8Data;
362}
Teresa Charlinca5a23a2023-12-15 14:20:47 +0000363
Teresa Charlinca5a23a2023-12-15 14:20:47 +0000364inline std::vector<uint8_t> CreateConstTosaData(const void* value,
365 DType dtype,
366 const std::vector<int32_t>& shape)
367{
368 std::vector<uint8_t> uint8Data;
369 tosa_err_t error = tosa_err_t::TOSA_OK;
370
371 unsigned int numElements = 1;
372 for (auto s : shape)
373 {
374 if (s < 0)
375 {
376 throw armnn::Exception("CreateConstTosaData: negative shape elements unhandled.");
377 }
378 numElements = numElements * static_cast<unsigned int>(s);
379 }
380
381 switch (dtype)
382 {
383 case DType::DType_FP32:
384 {
385 std::vector<float> data(numElements, *static_cast<const float*>(value));
386 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
387 break;
388 }
389 case DType::DType_FP16:
390 {
391 std::vector<float> data(numElements, *static_cast<const float*>(value));
392 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
393 break;
394 }
395 case DType::DType_INT48:
396 {
397 std::vector<int64_t> data(numElements, *static_cast<const int64_t*>(value));
398 error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data);
399 break;
400 }
401 case DType::DType_INT32:
402 {
403 std::vector<int32_t> data(numElements, *static_cast<const int32_t*>(value));
404 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
405 break;
406 }
407 case DType::DType_INT16:
408 {
409 std::vector<int16_t> data(numElements, *static_cast<const int16_t*>(value));
410 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
411 break;
412 }
413 case DType::DType_INT8:
414 {
415 std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
416 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
417 break;
418 }
419 case DType::DType_INT4:
420 {
421 std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
422 error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data);
423 break;
424 }
425 case DType::DType_BOOL:
426 {
427 std::vector<bool> data(numElements, *static_cast<const bool*>(value));
428 error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data);
429 break;
430 }
431 default:
432 {
433 throw armnn::Exception("CreateConstTosaData: An unsupported data type was encountered.");
434 }
435 }
436
437 if(error != tosa_err_t::TOSA_OK)
438 {
439 throw armnn::Exception("CreateConstTosaData: An error occurred when converting constant data");
440 }
441
442 return uint8Data;
443}
444
445template<typename T>
446inline void CreateConstTosaOperator(const std::string& outputName,
447 const T value,
448 DType dtype,
449 const std::vector<int32_t>& shape,
450 TosaSerializationOperator*& op,
451 TosaSerializationTensor*& tensor)
452{
453 std::vector<uint8_t> uint8Data = CreateConstTosaData(static_cast<const void *>(&value), dtype, shape);
454
455 op = new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {outputName});
456 ARMNN_THROW_MSG_IF_FALSE(op, armnn::Exception, "CreateConstTosaOperator: failed to created operator");
457
458 tensor = new TosaSerializationTensor(outputName, shape, dtype, uint8Data);
459 ARMNN_THROW_MSG_IF_FALSE(tensor, armnn::Exception, "CreateConstTosaOperator: failed to created tensor");
Tracy Narine10403ec2023-11-28 11:55:08 +0000460}
461
// Macro to preserve usage of a code block as the TOSA library version advances. Parameters
// specify the minimum version required by the code block.
//
// Evaluates to true when (TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH) is
// greater than or equal to (_major, _minor, _patch), compared component by component with
// the major number most significant. The expansion is fully parenthesised so that it can be
// negated or combined with other conditions, and it is usable in #if directives.
#define TOSA_COMPAT_VERSION(_major, _minor, _patch) \
    ((TOSA_VERSION_MAJOR > (_major)) || \
     ((TOSA_VERSION_MAJOR == (_major)) && \
      ((TOSA_VERSION_MINOR > (_minor)) || \
       ((TOSA_VERSION_MINOR == (_minor)) && (TOSA_VERSION_PATCH >= (_patch))))))