blob: cbb409c8a10c7ca10caf28a935fbcba9ce85936c [file] [log] [blame]
Matthew Bentham314d3e22023-06-23 10:53:52 +00001/*
2 * Copyright (c) 2016-2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H
25#define ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H
26
27#include "arm_compute/core/PixelValue.h"
28#include "arm_compute/core/Types.h"
29
30namespace arm_compute
31{
32/** The size in bytes of the data type
33 *
34 * @param[in] data_type Input data type
35 *
36 * @return The size in bytes of the data type
37 */
38inline size_t data_size_from_type(DataType data_type)
39{
40 switch(data_type)
41 {
42 case DataType::U8:
43 case DataType::S8:
44 case DataType::QSYMM8:
45 case DataType::QASYMM8:
46 case DataType::QASYMM8_SIGNED:
47 case DataType::QSYMM8_PER_CHANNEL:
48 return 1;
49 case DataType::U16:
50 case DataType::S16:
51 case DataType::QSYMM16:
52 case DataType::QASYMM16:
53 case DataType::BFLOAT16:
54 case DataType::F16:
55 return 2;
56 case DataType::F32:
57 case DataType::U32:
58 case DataType::S32:
59 return 4;
60 case DataType::F64:
61 case DataType::U64:
62 case DataType::S64:
63 return 8;
64 case DataType::SIZET:
65 return sizeof(size_t);
66 default:
67 ARM_COMPUTE_ERROR("Invalid data type");
68 return 0;
69 }
70}
71
72/** The size in bytes of the data type
73 *
74 * @param[in] dt Input data type
75 *
76 * @return The size in bytes of the data type
77 */
78inline size_t element_size_from_data_type(DataType dt)
79{
80 switch(dt)
81 {
82 case DataType::S8:
83 case DataType::U8:
84 case DataType::QSYMM8:
85 case DataType::QASYMM8:
86 case DataType::QASYMM8_SIGNED:
87 case DataType::QSYMM8_PER_CHANNEL:
88 return 1;
89 case DataType::U16:
90 case DataType::S16:
91 case DataType::QSYMM16:
92 case DataType::QASYMM16:
93 case DataType::BFLOAT16:
94 case DataType::F16:
95 return 2;
96 case DataType::U32:
97 case DataType::S32:
98 case DataType::F32:
99 return 4;
100 case DataType::U64:
101 case DataType::S64:
102 return 8;
103 default:
104 ARM_COMPUTE_ERROR("Undefined element size for given data type");
105 return 0;
106 }
107}
108
109/** Return the data type used by a given single-planar pixel format
110 *
111 * @param[in] format Input format
112 *
113 * @return The size in bytes of the pixel format
114 */
115inline DataType data_type_from_format(Format format)
116{
117 switch(format)
118 {
119 case Format::U8:
120 case Format::UV88:
121 case Format::RGB888:
122 case Format::RGBA8888:
123 case Format::YUYV422:
124 case Format::UYVY422:
125 return DataType::U8;
126 case Format::U16:
127 return DataType::U16;
128 case Format::S16:
129 return DataType::S16;
130 case Format::U32:
131 return DataType::U32;
132 case Format::S32:
133 return DataType::S32;
134 case Format::BFLOAT16:
135 return DataType::BFLOAT16;
136 case Format::F16:
137 return DataType::F16;
138 case Format::F32:
139 return DataType::F32;
140 //Doesn't make sense for planar formats:
141 case Format::NV12:
142 case Format::NV21:
143 case Format::IYUV:
144 case Format::YUV444:
145 default:
146 ARM_COMPUTE_ERROR("Not supported data_type for given format");
147 return DataType::UNKNOWN;
148 }
149}
150
151/** Return the promoted data type of a given data type.
152 *
153 * @note If promoted data type is not supported an error will be thrown
154 *
155 * @param[in] dt Data type to get the promoted type of.
156 *
157 * @return Promoted data type
158 */
159inline DataType get_promoted_data_type(DataType dt)
160{
161 switch(dt)
162 {
163 case DataType::U8:
164 return DataType::U16;
165 case DataType::S8:
166 return DataType::S16;
167 case DataType::U16:
168 return DataType::U32;
169 case DataType::S16:
170 return DataType::S32;
171 case DataType::QSYMM8:
172 case DataType::QASYMM8:
173 case DataType::QASYMM8_SIGNED:
174 case DataType::QSYMM8_PER_CHANNEL:
175 case DataType::QSYMM16:
176 case DataType::QASYMM16:
177 case DataType::BFLOAT16:
178 case DataType::F16:
179 case DataType::U32:
180 case DataType::S32:
181 case DataType::F32:
182 ARM_COMPUTE_ERROR("Unsupported data type promotions!");
183 default:
184 ARM_COMPUTE_ERROR("Undefined data type!");
185 }
186 return DataType::UNKNOWN;
187}
188
189/** Compute the mininum and maximum values a data type can take
190 *
191 * @param[in] dt Data type to get the min/max bounds of
192 *
193 * @return A tuple (min,max) with the minimum and maximum values respectively wrapped in PixelValue.
194 */
195inline std::tuple<PixelValue, PixelValue> get_min_max(DataType dt)
196{
197 PixelValue min{};
198 PixelValue max{};
199 switch(dt)
200 {
201 case DataType::U8:
202 case DataType::QASYMM8:
203 {
204 min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest()));
205 max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::max()));
206 break;
207 }
208 case DataType::S8:
209 case DataType::QSYMM8:
210 case DataType::QASYMM8_SIGNED:
211 case DataType::QSYMM8_PER_CHANNEL:
212 {
213 min = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::lowest()));
214 max = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::max()));
215 break;
216 }
217 case DataType::U16:
218 case DataType::QASYMM16:
219 {
220 min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest()));
221 max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::max()));
222 break;
223 }
224 case DataType::S16:
225 case DataType::QSYMM16:
226 {
227 min = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::lowest()));
228 max = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::max()));
229 break;
230 }
231 case DataType::U32:
232 {
233 min = PixelValue(std::numeric_limits<uint32_t>::lowest());
234 max = PixelValue(std::numeric_limits<uint32_t>::max());
235 break;
236 }
237 case DataType::S32:
238 {
239 min = PixelValue(std::numeric_limits<int32_t>::lowest());
240 max = PixelValue(std::numeric_limits<int32_t>::max());
241 break;
242 }
243 case DataType::BFLOAT16:
244 {
245 min = PixelValue(bfloat16::lowest());
246 max = PixelValue(bfloat16::max());
247 break;
248 }
249 case DataType::F16:
250 {
251 min = PixelValue(std::numeric_limits<half>::lowest());
252 max = PixelValue(std::numeric_limits<half>::max());
253 break;
254 }
255 case DataType::F32:
256 {
257 min = PixelValue(std::numeric_limits<float>::lowest());
258 max = PixelValue(std::numeric_limits<float>::max());
259 break;
260 }
261 default:
262 ARM_COMPUTE_ERROR("Undefined data type!");
263 }
264 return std::make_tuple(min, max);
265}
266
267/** Convert a data type identity into a string.
268 *
269 * @param[in] dt @ref DataType to be translated to string.
270 *
271 * @return The string describing the data type.
272 */
273const std::string &string_from_data_type(DataType dt);
274
275/** Convert a string to DataType
276 *
277 * @param[in] name The name of the data type
278 *
279 * @return DataType
280 */
281DataType data_type_from_name(const std::string &name);
282
283/** Input Stream operator for @ref DataType
284 *
285 * @param[in] stream Stream to parse
286 * @param[out] data_type Output data type
287 *
288 * @return Updated stream
289 */
290inline ::std::istream &operator>>(::std::istream &stream, DataType &data_type)
291{
292 std::string value;
293 stream >> value;
294 data_type = data_type_from_name(value);
295 return stream;
296}
297
298/** Check if a given data type is of floating point type
299 *
300 * @param[in] dt Input data type.
301 *
302 * @return True if data type is of floating point type, else false.
303 */
304inline bool is_data_type_float(DataType dt)
305{
306 switch(dt)
307 {
308 case DataType::F16:
309 case DataType::F32:
310 return true;
311 default:
312 return false;
313 }
314}
315
316/** Check if a given data type is of quantized type
317 *
318 * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
319 *
320 * @param[in] dt Input data type.
321 *
322 * @return True if data type is of quantized type, else false.
323 */
324inline bool is_data_type_quantized(DataType dt)
325{
326 switch(dt)
327 {
328 case DataType::QSYMM8:
329 case DataType::QASYMM8:
330 case DataType::QASYMM8_SIGNED:
331 case DataType::QSYMM8_PER_CHANNEL:
332 case DataType::QSYMM16:
333 case DataType::QASYMM16:
334 return true;
335 default:
336 return false;
337 }
338}
339
340/** Check if a given data type is of asymmetric quantized type
341 *
342 * @param[in] dt Input data type.
343 *
344 * @return True if data type is of asymmetric quantized type, else false.
345 */
346inline bool is_data_type_quantized_asymmetric(DataType dt)
347{
348 switch(dt)
349 {
350 case DataType::QASYMM8:
351 case DataType::QASYMM8_SIGNED:
352 case DataType::QASYMM16:
353 return true;
354 default:
355 return false;
356 }
357}
358
359/** Check if a given data type is of asymmetric quantized signed type
360 *
361 * @param[in] dt Input data type.
362 *
363 * @return True if data type is of asymmetric quantized signed type, else false.
364 */
365inline bool is_data_type_quantized_asymmetric_signed(DataType dt)
366{
367 switch(dt)
368 {
369 case DataType::QASYMM8_SIGNED:
370 return true;
371 default:
372 return false;
373 }
374}
375
376/** Check if a given data type is of symmetric quantized type
377 *
378 * @param[in] dt Input data type.
379 *
380 * @return True if data type is of symmetric quantized type, else false.
381 */
382inline bool is_data_type_quantized_symmetric(DataType dt)
383{
384 switch(dt)
385 {
386 case DataType::QSYMM8:
387 case DataType::QSYMM8_PER_CHANNEL:
388 case DataType::QSYMM16:
389 return true;
390 default:
391 return false;
392 }
393}
394
395/** Check if a given data type is of per channel type
396 *
397 * @param[in] dt Input data type.
398 *
399 * @return True if data type is of per channel type, else false.
400 */
401inline bool is_data_type_quantized_per_channel(DataType dt)
402{
403 switch(dt)
404 {
405 case DataType::QSYMM8_PER_CHANNEL:
406 return true;
407 default:
408 return false;
409 }
410}
411
412/** Returns true if the value can be represented by the given data type
413 *
414 * @param[in] val value to be checked
415 * @param[in] dt data type that is checked
416 * @param[in] qinfo (Optional) quantization info if the data type is QASYMM8
417 *
418 * @return true if the data type can hold the value.
419 */
420template <typename T>
421bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = QuantizationInfo())
422{
423 switch(dt)
424 {
425 case DataType::U8:
426 {
427 const auto val_u8 = static_cast<uint8_t>(val);
428 return ((val_u8 == val) && val >= std::numeric_limits<uint8_t>::lowest() && val <= std::numeric_limits<uint8_t>::max());
429 }
430 case DataType::QASYMM8:
431 {
432 double min = static_cast<double>(dequantize_qasymm8(0, qinfo));
433 double max = static_cast<double>(dequantize_qasymm8(std::numeric_limits<uint8_t>::max(), qinfo));
434 return ((double)val >= min && (double)val <= max);
435 }
436 case DataType::S8:
437 {
438 const auto val_s8 = static_cast<int8_t>(val);
439 return ((val_s8 == val) && val >= std::numeric_limits<int8_t>::lowest() && val <= std::numeric_limits<int8_t>::max());
440 }
441 case DataType::U16:
442 {
443 const auto val_u16 = static_cast<uint16_t>(val);
444 return ((val_u16 == val) && val >= std::numeric_limits<uint16_t>::lowest() && val <= std::numeric_limits<uint16_t>::max());
445 }
446 case DataType::S16:
447 {
448 const auto val_s16 = static_cast<int16_t>(val);
449 return ((val_s16 == val) && val >= std::numeric_limits<int16_t>::lowest() && val <= std::numeric_limits<int16_t>::max());
450 }
451 case DataType::U32:
452 {
453 const auto val_d64 = static_cast<double>(val);
454 const auto val_u32 = static_cast<uint32_t>(val);
455 return ((val_u32 == val_d64) && val_d64 >= std::numeric_limits<uint32_t>::lowest() && val_d64 <= std::numeric_limits<uint32_t>::max());
456 }
457 case DataType::S32:
458 {
459 const auto val_d64 = static_cast<double>(val);
460 const auto val_s32 = static_cast<int32_t>(val);
461 return ((val_s32 == val_d64) && val_d64 >= std::numeric_limits<int32_t>::lowest() && val_d64 <= std::numeric_limits<int32_t>::max());
462 }
463 case DataType::BFLOAT16:
464 return (val >= bfloat16::lowest() && val <= bfloat16::max());
465 case DataType::F16:
466 return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
467 case DataType::F32:
468 return (val >= std::numeric_limits<float>::lowest() && val <= std::numeric_limits<float>::max());
469 default:
470 ARM_COMPUTE_ERROR("Data type not supported");
471 return false;
472 }
473}
474
475/** Returns the suffix string of CPU kernel implementation names based on the given data type
476 *
477 * @param[in] data_type The data type the CPU kernel implemetation uses
478 *
479 * @return the suffix string of CPU kernel implementations
480 */
481inline std::string cpu_impl_dt(const DataType &data_type)
482{
483 std::string ret = "";
484
485 switch(data_type)
486 {
487 case DataType::F32:
488 ret = "fp32";
489 break;
490 case DataType::F16:
491 ret = "fp16";
492 break;
493 case DataType::U8:
494 ret = "u8";
495 break;
496 case DataType::S16:
497 ret = "s16";
498 break;
499 case DataType::S32:
500 ret = "s32";
501 break;
502 case DataType::QASYMM8:
503 ret = "qu8";
504 break;
505 case DataType::QASYMM8_SIGNED:
506 ret = "qs8";
507 break;
508 case DataType::QSYMM16:
509 ret = "qs16";
510 break;
511 case DataType::QSYMM8_PER_CHANNEL:
512 ret = "qp8";
513 break;
514 case DataType::BFLOAT16:
515 ret = "bf16";
516 break;
517 default:
518 ARM_COMPUTE_ERROR("Unsupported.");
519 }
520
521 return ret;
522}
523
524}
525#endif /*ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H */