blob: be2f53e413acfc749598336f6829c5545ac36389 [file] [log] [blame]
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +00008#include <Layer.hpp>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01009#include <armnn/Tensor.hpp>
10#include <armnn/Types.hpp>
11
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000012#include "common/include/ProfilingGuid.hpp"
13
14#include <tosa_serialization_handler.h>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010015
16using namespace armnn;
17using namespace tosa;
18
19// Function to return Tosa datatype from input ArmNN datatype.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010020inline DType ArmNNToDType(const DataType& type)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010021{
22 switch (type)
23 {
24 case DataType::Float16:
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010025 case DataType::BFloat16:
Matthew Sloyanda824cc2022-10-10 12:43:20 +010026 return DType_FP16;
27 case DataType::Float32:
28 return DType_FP32;
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010029 case DataType::QAsymmU8:
30 return DType_UINT8;
31 case DataType::QSymmS8:
32 case DataType::QAsymmS8:
33 return DType_INT8;
34 case DataType::QSymmS16:
35 return DType_INT16;
36 case DataType::Signed32:
37 return DType_INT32;
38 case DataType::Signed64:
39 // No signed 64, only DType_INT48.
40 return DType_UNKNOWN;
41 case DataType::Boolean:
42 return DType_BOOL;
43 default:
44 return DType_UNKNOWN;
45 }
46}
47
48// Function to return Tosa tensor shape from input ArmNN tensor shape.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010049inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010050{
51 std::vector<int32_t> returnShape;
52 for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
53 {
54 returnShape.push_back(static_cast<int32_t>(shape[i]));
55 }
56 return returnShape;
57}
58
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000059// Function that generates unique name using the layer type, input slot and layer guid.
60inline std::string GenerateUniqueName(const Layer& layer, uint32_t layerSlot)
61{
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000062 std::string guid = std::to_string(layer.GetGuid());
63 std::string slotAndGuid = std::to_string(layerSlot) + "_" + guid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000064
Cathal Corbettb30e6552022-12-07 11:50:50 +000065 switch (layer.GetType())
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000066 {
Cathal Corbettb30e6552022-12-07 11:50:50 +000067 case LayerType::Input:
68 return "input" + slotAndGuid;
69 case LayerType::Output:
70 return "output" + slotAndGuid;
71 case LayerType::Constant:
72 return "constant_" + guid;
73 default:
74 return "intermediate" + slotAndGuid;
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000075 }
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000076}
77
Matthew Sloyanfc9d5e72022-12-08 13:38:23 +000078// Function that generates unique output name using the layer type, input slot and layer guid.
79inline std::string GenerateUniqueOutputName(const Layer& layer, uint32_t layerSlot)
80{
81 Layer& connectedLayer = layer.GetOutputSlot().GetConnection(0)->GetOwningLayer();
82
83 // Get the layer connected to the output slot, if output use that layer and id,
84 // otherwise use current layer and id.
85 if(connectedLayer.GetType() == LayerType::Output)
86 {
87 return GenerateUniqueName(connectedLayer, layerSlot);
88 }
89 else
90 {
91 return GenerateUniqueName(layer, layerSlot);
92 }
93}
94
// Function to return unique int as a string to ensure uniqueness between all input, output and block names.
inline std::string GetUniqueTosaMappingID()
{
    // Fixed: the counter was previously a file-scope 'static int' in this
    // header, which gives every translation unit its own copy (so IDs were
    // only unique per-TU) and makes this external-linkage inline function an
    // ODR violation by referencing an internal-linkage variable. A local
    // static inside an inline function is guaranteed to be a single entity
    // shared across all TUs. First call returns "1".
    static int uniqueTosaMappingID = 0;
    return std::to_string(++uniqueTosaMappingID);
}
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000101
Cathal Corbettb30e6552022-12-07 11:50:50 +0000102// Function to return Tosa DType as string.
103inline std::string TosaDTypeToString(DType tosaDType)
104{
105 switch (tosaDType)
106 {
107 case DType_UNKNOWN:
108 return "DType_UNKNOWN";
109 case DType_BOOL:
110 return "DType_BOOL";
111 case DType_UINT8:
112 return "DType_UINT8";
113 case DType_INT4:
114 return "DType_INT4";
115 case DType_INT8:
116 return "DType_INT8";
117 case DType_INT16:
118 return "DType_INT16";
119 case DType_INT32:
120 return "DType_INT32";
121 case DType_INT48:
122 return "DType_INT48";
123 case DType_FP32:
124 return "DType_FP32";
125 case DType_UINT16:
126 return "DType_UINT16";
127 case DType_FP16:
128 return "DType_FP16";
129 }
130 return "";
131}
132
// Function to return Tosa Op as string.
// Spells out the TOSA serialization enumerator name (e.g. Op_ADD -> "Op_ADD")
// for logging and error messages. Any enumerator not listed below falls
// through the switch and yields an empty string.
inline std::string TosaOpToString(Op tosaOp)
{
    switch (tosaOp)
    {
        case Op_ADD:
            return "Op_ADD";
        case Op_AVG_POOL2D:
            return "Op_AVG_POOL2D";
        case Op_MAX_POOL2D:
            return "Op_MAX_POOL2D";
        case Op_PAD:
            return "Op_PAD";
        case Op_UNKNOWN:
            return "Op_UNKNOWN";
        case Op_ARGMAX:
            return "Op_ARGMAX";
        case Op_CONV2D:
            return "Op_CONV2D";
        case Op_CONV3D:
            return "Op_CONV3D";
        case Op_DEPTHWISE_CONV2D:
            return "Op_DEPTHWISE_CONV2D";
        case Op_FULLY_CONNECTED:
            return "Op_FULLY_CONNECTED";
        case Op_MATMUL:
            return "Op_MATMUL";
        case Op_TRANSPOSE_CONV2D:
            return "Op_TRANSPOSE_CONV2D";
        case Op_CLAMP:
            return "Op_CLAMP";
        case Op_RESERVED:
            return "Op_RESERVED";
        case Op_SIGMOID:
            return "Op_SIGMOID";
        case Op_TANH:
            return "Op_TANH";
        case Op_ARITHMETIC_RIGHT_SHIFT:
            return "Op_ARITHMETIC_RIGHT_SHIFT";
        case Op_BITWISE_AND:
            return "Op_BITWISE_AND";
        case Op_BITWISE_OR:
            return "Op_BITWISE_OR";
        case Op_BITWISE_XOR:
            return "Op_BITWISE_XOR";
        case Op_INTDIV:
            return "Op_INTDIV";
        case Op_LOGICAL_AND:
            return "Op_LOGICAL_AND";
        case Op_LOGICAL_LEFT_SHIFT:
            return "Op_LOGICAL_LEFT_SHIFT";
        case Op_LOGICAL_RIGHT_SHIFT:
            return "Op_LOGICAL_RIGHT_SHIFT";
        case Op_LOGICAL_OR:
            return "Op_LOGICAL_OR";
        case Op_LOGICAL_XOR:
            return "Op_LOGICAL_XOR";
        case Op_MAXIMUM:
            return "Op_MAXIMUM";
        case Op_MINIMUM:
            return "Op_MINIMUM";
        case Op_MUL:
            return "Op_MUL";
        case Op_POW:
            return "Op_POW";
        case Op_SUB:
            return "Op_SUB";
        case Op_TABLE:
            return "Op_TABLE";
        case Op_ABS:
            return "Op_ABS";
        case Op_BITWISE_NOT:
            return "Op_BITWISE_NOT";
        case Op_CEIL:
            return "Op_CEIL";
        case Op_CLZ:
            return "Op_CLZ";
        case Op_EXP:
            return "Op_EXP";
        case Op_FLOOR:
            return "Op_FLOOR";
        case Op_LOG:
            return "Op_LOG";
        case Op_LOGICAL_NOT:
            return "Op_LOGICAL_NOT";
        case Op_NEGATE:
            return "Op_NEGATE";
        case Op_RECIPROCAL:
            return "Op_RECIPROCAL";
        case Op_RSQRT:
            return "Op_RSQRT";
        case Op_SELECT:
            return "Op_SELECT";
        case Op_EQUAL:
            return "Op_EQUAL";
        case Op_GREATER:
            return "Op_GREATER";
        case Op_GREATER_EQUAL:
            return "Op_GREATER_EQUAL";
        case Op_REDUCE_ANY:
            return "Op_REDUCE_ANY";
        case Op_REDUCE_ALL:
            return "Op_REDUCE_ALL";
        case Op_REDUCE_MAX:
            return "Op_REDUCE_MAX";
        case Op_REDUCE_MIN:
            return "Op_REDUCE_MIN";
        case Op_REDUCE_PRODUCT:
            return "Op_REDUCE_PRODUCT";
        case Op_REDUCE_SUM:
            return "Op_REDUCE_SUM";
        case Op_CONCAT:
            return "Op_CONCAT";
        case Op_RESHAPE:
            return "Op_RESHAPE";
        case Op_REVERSE:
            return "Op_REVERSE";
        case Op_SLICE:
            return "Op_SLICE";
        case Op_TILE:
            return "Op_TILE";
        case Op_TRANSPOSE:
            return "Op_TRANSPOSE";
        case Op_GATHER:
            return "Op_GATHER";
        case Op_SCATTER:
            return "Op_SCATTER";
        case Op_RESIZE:
            return "Op_RESIZE";
        case Op_CAST:
            return "Op_CAST";
        case Op_RESCALE:
            return "Op_RESCALE";
        case Op_CONST:
            return "Op_CONST";
        case Op_IDENTITY:
            return "Op_IDENTITY";
        case Op_CUSTOM:
            return "Op_CUSTOM";
        case Op_COND_IF:
            return "Op_COND_IF";
        case Op_WHILE_LOOP:
            return "Op_WHILE_LOOP";
    }
    // Unrecognised enumerator value.
    return "";
}
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000279
280inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle>& tensorHandle)
281{
282 tosa_err_t error;
283 std::vector<uint8_t> uint8Data;
284 auto tensorInfo = tensorHandle->GetTensorInfo();
285
286 switch (tensorInfo.GetDataType())
287 {
288 case DataType::Float32:
289 {
290 std::vector<float> data(tensorInfo.GetNumElements());
291 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
292
293 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
294 break;
295 }
296 case DataType::Float16:
297 {
298 std::vector<float> data(tensorInfo.GetNumElements());
299 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
300
301 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
302 break;
303 }
304 case DataType::QSymmS8:
305 case DataType::QAsymmS8:
306 {
307 std::vector<int8_t> data(tensorInfo.GetNumElements());
308 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
309
310 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
311 break;
312 }
313 case DataType::QAsymmU8:
314 {
315 memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
316 break;
317 }
318 case DataType::QSymmS16:
319 {
320 std::vector<int16_t> data(tensorInfo.GetNumElements());
321 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
322
323 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
324 break;
325 }
326 case DataType::Signed32:
327 {
328 std::vector<int32_t> data(tensorInfo.GetNumElements());
329 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
330
331 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
332 break;
333 }
334 default:
335 {
336 throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
337 }
338 }
339
340 if(error != tosa_err_t::TOSA_OK)
341 {
342 throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
343 }
344
345 tensorHandle->Unmap();
346 return uint8Data;
347}