blob: 176e4e1cfbd11d9fc890694e2e21191d6aad8d1e [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <Layer.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include "common/include/ProfilingGuid.hpp"

#include <tosa_serialization_handler.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <vector>
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010015
16using namespace armnn;
17using namespace tosa;
18
19// Function to return Tosa datatype from input ArmNN datatype.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010020inline DType ArmNNToDType(const DataType& type)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010021{
22 switch (type)
23 {
24 case DataType::Float16:
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010025 case DataType::BFloat16:
Matthew Sloyanda824cc2022-10-10 12:43:20 +010026 return DType_FP16;
27 case DataType::Float32:
28 return DType_FP32;
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010029 case DataType::QAsymmU8:
30 return DType_UINT8;
31 case DataType::QSymmS8:
32 case DataType::QAsymmS8:
33 return DType_INT8;
34 case DataType::QSymmS16:
35 return DType_INT16;
36 case DataType::Signed32:
37 return DType_INT32;
38 case DataType::Signed64:
39 // No signed 64, only DType_INT48.
40 return DType_UNKNOWN;
41 case DataType::Boolean:
42 return DType_BOOL;
43 default:
44 return DType_UNKNOWN;
45 }
46}
47
48// Function to return Tosa tensor shape from input ArmNN tensor shape.
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010049inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010050{
51 std::vector<int32_t> returnShape;
52 for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
53 {
54 returnShape.push_back(static_cast<int32_t>(shape[i]));
55 }
56 return returnShape;
57}
58
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000059// Function that generates unique name using the layer type, input slot and layer guid.
60inline std::string GenerateUniqueName(const Layer& layer, uint32_t layerSlot)
61{
62 std::string name;
63 std::string guid = std::to_string(layer.GetGuid());
64 std::string slotAndGuid = std::to_string(layerSlot) + "_" + guid;
65 LayerType layerType = layer.GetType();
66
67 if (layerType == LayerType::Input)
68 {
69 name = "input" + slotAndGuid;
70 }
71 else if (layerType == LayerType::Output)
72 {
73 name = "output" + slotAndGuid;
74 }
75 else if (layerType == LayerType::Constant)
76 {
77 name = "constant_" + guid;
78 }
79 else
80 {
81 name = "intermediate" + slotAndGuid;
82 }
83 return name;
84}
85
// Returns an increasing integer as a string ("1", "2", ...), used to keep input, output and block
// names unique.
//
// The counter is a function-local static of this inline (external-linkage) function, so all
// translation units share a single counter. A namespace-scope `static` variable in this header would
// instead give every translation unit its own copy, and make the inline function's definitions
// differ between translation units (an ODR violation).
// NOTE(review): the increment is not synchronised; concurrent callers would race - confirm callers
// are single-threaded.
inline std::string GetUniqueTosaMappingID()
{
    static int uniqueTosaMappingID = 0;
    return std::to_string(++uniqueTosaMappingID);
}
Cathal Corbettbd18eab2022-11-15 12:56:16 +000092
93// Function to return Tosa Op as string.
94inline std::string TosaOpToString(Op tosaOp)
95{
96 switch (tosaOp)
97 {
98 case Op_ADD:
99 return "Op_ADD";
100 case Op_AVG_POOL2D:
101 return "Op_AVG_POOL2D";
102 case Op_MAX_POOL2D:
103 return "Op_MAX_POOL2D";
104 case Op_PAD:
105 return "Op_PAD";
106 case Op_UNKNOWN:
107 return "Op_UNKNOWN";
108 case Op_ARGMAX:
109 return "Op_ARGMAX";
110 case Op_CONV2D:
111 return "Op_CONV2D";
112 case Op_CONV3D:
113 return "Op_CONV3D";
114 case Op_DEPTHWISE_CONV2D:
115 return "Op_DEPTHWISE_CONV2D";
116 case Op_FULLY_CONNECTED:
117 return "Op_FULLY_CONNECTED";
118 case Op_MATMUL:
119 return "Op_MATMUL";
120 case Op_TRANSPOSE_CONV2D:
121 return "Op_TRANSPOSE_CONV2D";
122 case Op_CLAMP:
123 return "Op_CLAMP";
124 case Op_RESERVED:
125 return "Op_RESERVED";
126 case Op_SIGMOID:
127 return "Op_SIGMOID";
128 case Op_TANH:
129 return "Op_TANH";
130 case Op_ARITHMETIC_RIGHT_SHIFT:
131 return "Op_ARITHMETIC_RIGHT_SHIFT";
132 case Op_BITWISE_AND:
133 return "Op_BITWISE_AND";
134 case Op_BITWISE_OR:
135 return "Op_BITWISE_OR";
136 case Op_BITWISE_XOR:
137 return "Op_BITWISE_XOR";
138 case Op_INTDIV:
139 return "Op_INTDIV";
140 case Op_LOGICAL_AND:
141 return "Op_LOGICAL_AND";
142 case Op_LOGICAL_LEFT_SHIFT:
143 return "Op_LOGICAL_LEFT_SHIFT";
144 case Op_LOGICAL_RIGHT_SHIFT:
145 return "Op_LOGICAL_RIGHT_SHIFT";
146 case Op_LOGICAL_OR:
147 return "Op_LOGICAL_OR";
148 case Op_LOGICAL_XOR:
149 return "Op_LOGICAL_XOR";
150 case Op_MAXIMUM:
151 return "Op_MAXIMUM";
152 case Op_MINIMUM:
153 return "Op_MINIMUM";
154 case Op_MUL:
155 return "Op_MUL";
156 case Op_POW:
157 return "Op_POW";
158 case Op_SUB:
159 return "Op_SUB";
160 case Op_TABLE:
161 return "Op_TABLE";
162 case Op_ABS:
163 return "Op_ABS";
164 case Op_BITWISE_NOT:
165 return "Op_BITWISE_NOT";
166 case Op_CEIL:
167 return "Op_CEIL";
168 case Op_CLZ:
169 return "Op_CLZ";
170 case Op_EXP:
171 return "Op_EXP";
172 case Op_FLOOR:
173 return "Op_FLOOR";
174 case Op_LOG:
175 return "Op_LOG";
176 case Op_LOGICAL_NOT:
177 return "Op_LOGICAL_NOT";
178 case Op_NEGATE:
179 return "Op_NEGATE";
180 case Op_RECIPROCAL:
181 return "Op_RECIPROCAL";
182 case Op_RSQRT:
183 return "Op_RSQRT";
184 case Op_SELECT:
185 return "Op_SELECT";
186 case Op_EQUAL:
187 return "Op_EQUAL";
188 case Op_GREATER:
189 return "Op_GREATER";
190 case Op_GREATER_EQUAL:
191 return "Op_GREATER_EQUAL";
192 case Op_REDUCE_ANY:
193 return "Op_REDUCE_ANY";
194 case Op_REDUCE_ALL:
195 return "Op_REDUCE_ALL";
196 case Op_REDUCE_MAX:
197 return "Op_REDUCE_MAX";
198 case Op_REDUCE_MIN:
199 return "Op_REDUCE_MIN";
200 case Op_REDUCE_PRODUCT:
201 return "Op_REDUCE_PRODUCT";
202 case Op_REDUCE_SUM:
203 return "Op_REDUCE_SUM";
204 case Op_CONCAT:
205 return "Op_CONCAT";
206 case Op_RESHAPE:
207 return "Op_RESHAPE";
208 case Op_REVERSE:
209 return "Op_REVERSE";
210 case Op_SLICE:
211 return "Op_SLICE";
212 case Op_TILE:
213 return "Op_TILE";
214 case Op_TRANSPOSE:
215 return "Op_TRANSPOSE";
216 case Op_GATHER:
217 return "Op_GATHER";
218 case Op_SCATTER:
219 return "Op_SCATTER";
220 case Op_RESIZE:
221 return "Op_RESIZE";
222 case Op_CAST:
223 return "Op_CAST";
224 case Op_RESCALE:
225 return "Op_RESCALE";
226 case Op_CONST:
227 return "Op_CONST";
228 case Op_IDENTITY:
229 return "Op_IDENTITY";
230 case Op_CUSTOM:
231 return "Op_CUSTOM";
232 case Op_COND_IF:
233 return "Op_COND_IF";
234 case Op_WHILE_LOOP:
235 return "Op_WHILE_LOOP";
236 }
237 return "";
238}
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000239
240inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle>& tensorHandle)
241{
242 tosa_err_t error;
243 std::vector<uint8_t> uint8Data;
244 auto tensorInfo = tensorHandle->GetTensorInfo();
245
246 switch (tensorInfo.GetDataType())
247 {
248 case DataType::Float32:
249 {
250 std::vector<float> data(tensorInfo.GetNumElements());
251 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
252
253 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
254 break;
255 }
256 case DataType::Float16:
257 {
258 std::vector<float> data(tensorInfo.GetNumElements());
259 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
260
261 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
262 break;
263 }
264 case DataType::QSymmS8:
265 case DataType::QAsymmS8:
266 {
267 std::vector<int8_t> data(tensorInfo.GetNumElements());
268 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
269
270 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
271 break;
272 }
273 case DataType::QAsymmU8:
274 {
275 memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
276 break;
277 }
278 case DataType::QSymmS16:
279 {
280 std::vector<int16_t> data(tensorInfo.GetNumElements());
281 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
282
283 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
284 break;
285 }
286 case DataType::Signed32:
287 {
288 std::vector<int32_t> data(tensorInfo.GetNumElements());
289 memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
290
291 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
292 break;
293 }
294 default:
295 {
296 throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
297 }
298 }
299
300 if(error != tosa_err_t::TOSA_OK)
301 {
302 throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
303 }
304
305 tensorHandle->Unmap();
306 return uint8Data;
307}