blob: 288966badd43caacb815dfffe44cdd2dbf715242 [file] [log] [blame]
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +00008#include <Layer.hpp>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01009#include <armnn/Tensor.hpp>
10#include <armnn/Types.hpp>
11
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000012#include "common/include/ProfilingGuid.hpp"
13
14#include <tosa_serialization_handler.h>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010015
16using namespace armnn;
17using namespace tosa;
18
19// Function to return Tosa datatype from input ArmNN datatype.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010020inline DType ArmNNToDType(const DataType& type)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010021{
22 switch (type)
23 {
24 case DataType::Float16:
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010025 case DataType::BFloat16:
Matthew Sloyanda824cc2022-10-10 12:43:20 +010026 return DType_FP16;
27 case DataType::Float32:
28 return DType_FP32;
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010029 case DataType::QAsymmU8:
30 return DType_UINT8;
31 case DataType::QSymmS8:
32 case DataType::QAsymmS8:
33 return DType_INT8;
34 case DataType::QSymmS16:
35 return DType_INT16;
36 case DataType::Signed32:
37 return DType_INT32;
38 case DataType::Signed64:
39 // No signed 64, only DType_INT48.
40 return DType_UNKNOWN;
41 case DataType::Boolean:
42 return DType_BOOL;
43 default:
44 return DType_UNKNOWN;
45 }
46}
47
48// Function to return Tosa tensor shape from input ArmNN tensor shape.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010049inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010050{
51 std::vector<int32_t> returnShape;
52 for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
53 {
54 returnShape.push_back(static_cast<int32_t>(shape[i]));
55 }
56 return returnShape;
57}
58
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000059// Function that generates unique name using the layer type, input slot and layer guid.
60inline std::string GenerateUniqueName(const Layer& layer, uint32_t layerSlot)
61{
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000062 std::string guid = std::to_string(layer.GetGuid());
63 std::string slotAndGuid = std::to_string(layerSlot) + "_" + guid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000064
Cathal Corbettb30e6552022-12-07 11:50:50 +000065 switch (layer.GetType())
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000066 {
Cathal Corbettb30e6552022-12-07 11:50:50 +000067 case LayerType::Input:
68 return "input" + slotAndGuid;
69 case LayerType::Output:
70 return "output" + slotAndGuid;
71 case LayerType::Constant:
72 return "constant_" + guid;
73 default:
74 return "intermediate" + slotAndGuid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000075 }
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000076}
77
// Function to return unique int as a string to ensure uniqueness between all input, output and block names.
// The first call returns "1" and every subsequent call returns the next integer.
// Note: the counter is a plain int and is not synchronised, so concurrent callers need external locking.
inline std::string GetUniqueTosaMappingID()
{
    // The counter lives inside the inline function so that every translation unit including this
    // header shares one object. A namespace-scope `static` would instead give each translation unit
    // its own counter, allowing the same ID to be handed out more than once.
    static int uniqueTosaMappingID = 0;
    return std::to_string(++uniqueTosaMappingID);
}
Cathal Corbettbd18eab2022-11-15 12:56:16 +000084
Cathal Corbettb30e6552022-12-07 11:50:50 +000085// Function to return Tosa DType as string.
86inline std::string TosaDTypeToString(DType tosaDType)
87{
88 switch (tosaDType)
89 {
90 case DType_UNKNOWN:
91 return "DType_UNKNOWN";
92 case DType_BOOL:
93 return "DType_BOOL";
94 case DType_UINT8:
95 return "DType_UINT8";
96 case DType_INT4:
97 return "DType_INT4";
98 case DType_INT8:
99 return "DType_INT8";
100 case DType_INT16:
101 return "DType_INT16";
102 case DType_INT32:
103 return "DType_INT32";
104 case DType_INT48:
105 return "DType_INT48";
106 case DType_FP32:
107 return "DType_FP32";
108 case DType_UINT16:
109 return "DType_UINT16";
110 case DType_FP16:
111 return "DType_FP16";
112 }
113 return "";
114}
115
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000116// Function to return Tosa Op as string.
117inline std::string TosaOpToString(Op tosaOp)
118{
119 switch (tosaOp)
120 {
121 case Op_ADD:
122 return "Op_ADD";
123 case Op_AVG_POOL2D:
124 return "Op_AVG_POOL2D";
125 case Op_MAX_POOL2D:
126 return "Op_MAX_POOL2D";
127 case Op_PAD:
128 return "Op_PAD";
129 case Op_UNKNOWN:
130 return "Op_UNKNOWN";
131 case Op_ARGMAX:
132 return "Op_ARGMAX";
133 case Op_CONV2D:
134 return "Op_CONV2D";
135 case Op_CONV3D:
136 return "Op_CONV3D";
137 case Op_DEPTHWISE_CONV2D:
138 return "Op_DEPTHWISE_CONV2D";
139 case Op_FULLY_CONNECTED:
140 return "Op_FULLY_CONNECTED";
141 case Op_MATMUL:
142 return "Op_MATMUL";
143 case Op_TRANSPOSE_CONV2D:
144 return "Op_TRANSPOSE_CONV2D";
145 case Op_CLAMP:
146 return "Op_CLAMP";
147 case Op_RESERVED:
148 return "Op_RESERVED";
149 case Op_SIGMOID:
150 return "Op_SIGMOID";
151 case Op_TANH:
152 return "Op_TANH";
153 case Op_ARITHMETIC_RIGHT_SHIFT:
154 return "Op_ARITHMETIC_RIGHT_SHIFT";
155 case Op_BITWISE_AND:
156 return "Op_BITWISE_AND";
157 case Op_BITWISE_OR:
158 return "Op_BITWISE_OR";
159 case Op_BITWISE_XOR:
160 return "Op_BITWISE_XOR";
161 case Op_INTDIV:
162 return "Op_INTDIV";
163 case Op_LOGICAL_AND:
164 return "Op_LOGICAL_AND";
165 case Op_LOGICAL_LEFT_SHIFT:
166 return "Op_LOGICAL_LEFT_SHIFT";
167 case Op_LOGICAL_RIGHT_SHIFT:
168 return "Op_LOGICAL_RIGHT_SHIFT";
169 case Op_LOGICAL_OR:
170 return "Op_LOGICAL_OR";
171 case Op_LOGICAL_XOR:
172 return "Op_LOGICAL_XOR";
173 case Op_MAXIMUM:
174 return "Op_MAXIMUM";
175 case Op_MINIMUM:
176 return "Op_MINIMUM";
177 case Op_MUL:
178 return "Op_MUL";
179 case Op_POW:
180 return "Op_POW";
181 case Op_SUB:
182 return "Op_SUB";
183 case Op_TABLE:
184 return "Op_TABLE";
185 case Op_ABS:
186 return "Op_ABS";
187 case Op_BITWISE_NOT:
188 return "Op_BITWISE_NOT";
189 case Op_CEIL:
190 return "Op_CEIL";
191 case Op_CLZ:
192 return "Op_CLZ";
193 case Op_EXP:
194 return "Op_EXP";
195 case Op_FLOOR:
196 return "Op_FLOOR";
197 case Op_LOG:
198 return "Op_LOG";
199 case Op_LOGICAL_NOT:
200 return "Op_LOGICAL_NOT";
201 case Op_NEGATE:
202 return "Op_NEGATE";
203 case Op_RECIPROCAL:
204 return "Op_RECIPROCAL";
205 case Op_RSQRT:
206 return "Op_RSQRT";
207 case Op_SELECT:
208 return "Op_SELECT";
209 case Op_EQUAL:
210 return "Op_EQUAL";
211 case Op_GREATER:
212 return "Op_GREATER";
213 case Op_GREATER_EQUAL:
214 return "Op_GREATER_EQUAL";
215 case Op_REDUCE_ANY:
216 return "Op_REDUCE_ANY";
217 case Op_REDUCE_ALL:
218 return "Op_REDUCE_ALL";
219 case Op_REDUCE_MAX:
220 return "Op_REDUCE_MAX";
221 case Op_REDUCE_MIN:
222 return "Op_REDUCE_MIN";
223 case Op_REDUCE_PRODUCT:
224 return "Op_REDUCE_PRODUCT";
225 case Op_REDUCE_SUM:
226 return "Op_REDUCE_SUM";
227 case Op_CONCAT:
228 return "Op_CONCAT";
229 case Op_RESHAPE:
230 return "Op_RESHAPE";
231 case Op_REVERSE:
232 return "Op_REVERSE";
233 case Op_SLICE:
234 return "Op_SLICE";
235 case Op_TILE:
236 return "Op_TILE";
237 case Op_TRANSPOSE:
238 return "Op_TRANSPOSE";
239 case Op_GATHER:
240 return "Op_GATHER";
241 case Op_SCATTER:
242 return "Op_SCATTER";
243 case Op_RESIZE:
244 return "Op_RESIZE";
245 case Op_CAST:
246 return "Op_CAST";
247 case Op_RESCALE:
248 return "Op_RESCALE";
249 case Op_CONST:
250 return "Op_CONST";
251 case Op_IDENTITY:
252 return "Op_IDENTITY";
253 case Op_CUSTOM:
254 return "Op_CUSTOM";
255 case Op_COND_IF:
256 return "Op_COND_IF";
257 case Op_WHILE_LOOP:
258 return "Op_WHILE_LOOP";
259 }
260 return "";
261}
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000262
263inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle>& tensorHandle)
264{
265 tosa_err_t error;
266 std::vector<uint8_t> uint8Data;
267 auto tensorInfo = tensorHandle->GetTensorInfo();
268
269 switch (tensorInfo.GetDataType())
270 {
271 case DataType::Float32:
272 {
273 std::vector<float> data(tensorInfo.GetNumElements());
274 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
275
276 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
277 break;
278 }
279 case DataType::Float16:
280 {
281 std::vector<float> data(tensorInfo.GetNumElements());
282 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
283
284 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
285 break;
286 }
287 case DataType::QSymmS8:
288 case DataType::QAsymmS8:
289 {
290 std::vector<int8_t> data(tensorInfo.GetNumElements());
291 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
292
293 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
294 break;
295 }
296 case DataType::QAsymmU8:
297 {
298 memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
299 break;
300 }
301 case DataType::QSymmS16:
302 {
303 std::vector<int16_t> data(tensorInfo.GetNumElements());
304 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
305
306 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
307 break;
308 }
309 case DataType::Signed32:
310 {
311 std::vector<int32_t> data(tensorInfo.GetNumElements());
312 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
313
314 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
315 break;
316 }
317 default:
318 {
319 throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
320 }
321 }
322
323 if(error != tosa_err_t::TOSA_OK)
324 {
325 throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
326 }
327
328 tensorHandle->Unmap();
329 return uint8Data;
330}