blob: c48201837b2459186fa9fc57a9b7961c886f6af0 [file] [log] [blame]
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008
Matthew Bentham246bd462020-01-20 16:16:06 +00009#include <armnn/TypesUtils.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000010#include <armnnUtils/FloatingPointConverter.hpp>
11
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010012#include <ResolveType.hpp>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010013
Matthew Benthamc394a6d2019-06-24 12:51:25 +010014#include <boost/assert.hpp>
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +000015#include <boost/core/ignore_unused.hpp>
Matthew Benthamc394a6d2019-06-24 12:51:25 +010016
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010017namespace armnn
18{
19
/// Abstract positioning interface shared by every Decoder and Encoder in this header.
///
/// It only moves a cursor around a buffer; reading and writing element values is added by the
/// Decoder/Encoder interfaces that derive from it.
class BaseIterator
{
public:
    BaseIterator() = default;

    virtual ~BaseIterator() = default;

    /// Moves to element `index`. `axisIndex` is an extra coordinate used by the per-axis iterators
    /// in this file; through this interface it defaults to 0.
    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;

    /// Advances by one element.
    virtual BaseIterator& operator++() = 0;

    /// Advances by `increment` elements.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    /// Steps back by `increment` elements.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;

    /// Repositions the iterator (the implementations in this file seek to offset `index` from the
    /// start of the buffer). Returns the iterator itself, not an element.
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
37
Derek Lambertif30f7d32019-04-09 10:25:02 +010038template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010039class Decoder : public BaseIterator
40{
41public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010042 Decoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010043
44 virtual ~Decoder() {}
45
Matthew Benthamc394a6d2019-06-24 12:51:25 +010046 virtual void Reset(void*) = 0;
47
Derek Lambertif30f7d32019-04-09 10:25:02 +010048 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010049};
50
Derek Lambertif30f7d32019-04-09 10:25:02 +010051template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010052class Encoder : public BaseIterator
53{
54public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010055 Encoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010056
57 virtual ~Encoder() {}
58
Matthew Benthamc394a6d2019-06-24 12:51:25 +010059 virtual void Reset(void*) = 0;
60
Derek Lambertif30f7d32019-04-09 10:25:02 +010061 virtual void Set(IType right) = 0;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010062
63 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010064};
65
66template<typename T, typename Base>
67class TypedIterator : public Base
68{
69public:
Matthew Benthamc394a6d2019-06-24 12:51:25 +010070 TypedIterator(T* data = nullptr)
Francis Murtagh43aec582019-05-27 12:14:10 +010071 : m_Iterator(data), m_Start(data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010072 {}
73
Matthew Benthamc394a6d2019-06-24 12:51:25 +010074 void Reset(void* data) override
75 {
76 m_Iterator = reinterpret_cast<T*>(data);
77 m_Start = m_Iterator;
78 }
79
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010080 TypedIterator& operator++() override
81 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010082 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010083 ++m_Iterator;
84 return *this;
85 }
86
87 TypedIterator& operator+=(const unsigned int increment) override
88 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010089 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010090 m_Iterator += increment;
91 return *this;
92 }
93
94 TypedIterator& operator-=(const unsigned int increment) override
95 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010096 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010097 m_Iterator -= increment;
98 return *this;
99 }
100
Francis Murtagh43aec582019-05-27 12:14:10 +0100101 TypedIterator& operator[](const unsigned int index) override
102 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100103 BOOST_ASSERT(m_Iterator);
Francis Murtagh43aec582019-05-27 12:14:10 +0100104 m_Iterator = m_Start + index;
105 return *this;
106 }
107
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000108 TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
109 {
110 boost::ignore_unused(axisIndex);
111 BOOST_ASSERT(m_Iterator);
112 m_Iterator = m_Start + index;
113 return *this;
114 }
115
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100116protected:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100117 T* m_Iterator;
Francis Murtagh43aec582019-05-27 12:14:10 +0100118 T* m_Start;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100119};
120
Derek Lambertif30f7d32019-04-09 10:25:02 +0100121class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100122{
123public:
124 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
125 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
126
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100127 QASymm8Decoder(const float scale, const int32_t offset)
128 : QASymm8Decoder(nullptr, scale, offset) {}
129
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100130 float Get() const override
131 {
132 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
133 }
134
135private:
136 const float m_Scale;
137 const int32_t m_Offset;
138};
139
Ryan OShea9add1202020-02-07 10:06:33 +0000140class QASymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
141{
142public:
143 QASymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
144 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
145
146 QASymmS8Decoder(const float scale, const int32_t offset)
147 : QASymmS8Decoder(nullptr, scale, offset) {}
148
149 float Get() const override
150 {
151 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
152 }
153
154private:
155 const float m_Scale;
156 const int32_t m_Offset;
157};
158
Finn Williamsfd271062019-12-04 14:27:27 +0000159class QSymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
160{
161public:
162 QSymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
163 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
164
165 QSymmS8Decoder(const float scale, const int32_t offset)
166 : QSymmS8Decoder(nullptr, scale, offset) {}
167
168 float Get() const override
169 {
170 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
171 }
172
173private:
174 const float m_Scale;
175 const int32_t m_Offset;
176};
177
Derek Lambertif30f7d32019-04-09 10:25:02 +0100178class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
Sadik Armagan2999a022019-04-09 14:20:12 +0100179{
180public:
181 QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
182 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
183
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100184 QSymm16Decoder(const float scale, const int32_t offset)
185 : QSymm16Decoder(nullptr, scale, offset) {}
186
Sadik Armagan2999a022019-04-09 14:20:12 +0100187 float Get() const override
188 {
189 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
190 }
191
192private:
193 const float m_Scale;
194 const int32_t m_Offset;
195};
196
Matthew Jacksone69c3992019-09-09 14:31:21 +0100197class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100198{
199public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100200 Float16Decoder(const Half* data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100201 : TypedIterator(data) {}
202
Matthew Jacksone69c3992019-09-09 14:31:21 +0100203 Float16Decoder()
204 : Float16Decoder(nullptr) {}
205
206 float Get() const override
207 {
208 float val = 0.f;
209 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
210 return val;
211 }
212};
213
214class Float32Decoder : public TypedIterator<const float, Decoder<float>>
215{
216public:
217 Float32Decoder(const float* data)
218 : TypedIterator(data) {}
219
220 Float32Decoder()
221 : Float32Decoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100222
Derek Lambertif30f7d32019-04-09 10:25:02 +0100223 float Get() const override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100224 {
Derek Lambertif30f7d32019-04-09 10:25:02 +0100225 return *m_Iterator;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100226 }
227};
228
Mike Kelly9b398322019-05-22 17:21:49 +0100229class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
230{
231public:
232 ScaledInt32Decoder(const int32_t* data, const float scale)
233 : TypedIterator(data), m_Scale(scale) {}
234
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100235 ScaledInt32Decoder(const float scale)
236 : ScaledInt32Decoder(nullptr, scale) {}
237
Mike Kelly9b398322019-05-22 17:21:49 +0100238 float Get() const override
239 {
240 return static_cast<float>(*m_Iterator) * m_Scale;
241 }
242
243private:
244 const float m_Scale;
245};
246
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100247class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
248{
249public:
250 Int32Decoder(const int32_t* data)
251 : TypedIterator(data) {}
252
253 Int32Decoder()
254 : Int32Decoder(nullptr) {}
255
256 float Get() const override
257 {
258 return static_cast<float>(*m_Iterator);
259 }
260};
261
Derek Lambertif30f7d32019-04-09 10:25:02 +0100262class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100263{
264public:
265 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
266 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
267
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100268 QASymm8Encoder(const float scale, const int32_t offset)
269 : QASymm8Encoder(nullptr, scale, offset) {}
270
Derek Lambertif30f7d32019-04-09 10:25:02 +0100271 void Set(float right) override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100272 {
273 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
274 }
275
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100276 float Get() const override
277 {
278 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
279 }
280
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100281private:
282 const float m_Scale;
283 const int32_t m_Offset;
284};
285
Ryan OShea9add1202020-02-07 10:06:33 +0000286class QASymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
287{
288public:
289 QASymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
290 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
291
292 QASymmS8Encoder(const float scale, const int32_t offset)
293 : QASymmS8Encoder(nullptr, scale, offset) {}
294
295 void Set(float right) override
296 {
297 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
298 }
299
300 float Get() const override
301 {
302 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
303 }
304
305private:
306 const float m_Scale;
307 const int32_t m_Offset;
308};
309
Finn Williamsfd271062019-12-04 14:27:27 +0000310class QSymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
311{
312public:
313 QSymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
314 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
315
316 QSymmS8Encoder(const float scale, const int32_t offset)
317 : QSymmS8Encoder(nullptr, scale, offset) {}
318
319 void Set(float right) override
320 {
321 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
322 }
323
324 float Get() const override
325 {
326 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
327 }
328
329private:
330 const float m_Scale;
331 const int32_t m_Offset;
332};
333
Derek Lambertif30f7d32019-04-09 10:25:02 +0100334class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
335{
336public:
337 QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
338 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
339
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100340 QSymm16Encoder(const float scale, const int32_t offset)
341 : QSymm16Encoder(nullptr, scale, offset) {}
342
Derek Lambertif30f7d32019-04-09 10:25:02 +0100343 void Set(float right) override
344 {
345 *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
346 }
347
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100348 float Get() const override
349 {
350 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
351 }
352
Derek Lambertif30f7d32019-04-09 10:25:02 +0100353private:
354 const float m_Scale;
355 const int32_t m_Offset;
356};
357
Matthew Jacksone69c3992019-09-09 14:31:21 +0100358class Float16Encoder : public TypedIterator<Half, Encoder<float>>
Derek Lambertif30f7d32019-04-09 10:25:02 +0100359{
360public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100361 Float16Encoder(Half* data)
Derek Lambertif30f7d32019-04-09 10:25:02 +0100362 : TypedIterator(data) {}
363
Matthew Jacksone69c3992019-09-09 14:31:21 +0100364 Float16Encoder()
365 : Float16Encoder(nullptr) {}
366
367 void Set(float right) override
368 {
369 armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
370 }
371
372 float Get() const override
373 {
374 float val = 0.f;
375 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
376 return val;
377 }
378};
379
380class Float32Encoder : public TypedIterator<float, Encoder<float>>
381{
382public:
383 Float32Encoder(float* data)
384 : TypedIterator(data) {}
385
386 Float32Encoder()
387 : Float32Encoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100388
Derek Lambertif30f7d32019-04-09 10:25:02 +0100389 void Set(float right) override
390 {
391 *m_Iterator = right;
392 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100393
394 float Get() const override
395 {
396 return *m_Iterator;
397 }
Derek Lambertif30f7d32019-04-09 10:25:02 +0100398};
399
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100400class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
401{
402public:
403 Int32Encoder(int32_t* data)
404 : TypedIterator(data) {}
405
406 Int32Encoder()
407 : Int32Encoder(nullptr) {}
408
409 void Set(float right) override
410 {
411 *m_Iterator = static_cast<int32_t>(right);
412 }
413
414 float Get() const override
415 {
416 return static_cast<float>(*m_Iterator);
417 }
418};
419
Derek Lambertif30f7d32019-04-09 10:25:02 +0100420class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100421{
422public:
423 BooleanEncoder(uint8_t* data)
424 : TypedIterator(data) {}
425
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100426 BooleanEncoder()
427 : BooleanEncoder(nullptr) {}
428
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100429 void Set(bool right) override
430 {
431 *m_Iterator = right;
432 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100433
434 bool Get() const override
435 {
436 return *m_Iterator;
437 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100438};
439
// PerAxisIterator for per-axis quantization.
//
// Walks a contiguous buffer of T like TypedIterator, but also tracks m_AxisIndex. This is the index
// along the quantization axis of the current element, and derived per-axis Encoders/Decoders use it
// to select that element's scale.
template<typename T, typename Base>
class PerAxisIterator : public Base
{
public:
    // axisFactor is used to calculate axisIndex. After ++, +=, -= or operator[], the axis index of the
    // element at flat offset i is recomputed as (i / axisFactor). If axisDimensionality is non-zero,
    // the result is wrapped modulo axisDimensionality.
    // NOTE(review): this assumes axisFactor is the number of elements that follow the quantization
    // dimension (the product of the trailing dimensions). It also assumes axisDimensionality is the
    // size of that dimension, i.e. the number of scales. Confirm against the code that constructs
    // these iterators.
    // If axisFactor is 0, the axis index cannot be derived. It is then left as SetIndex()/Reset()
    // set it.
    PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0, unsigned int axisDimensionality = 0)
        : m_Iterator(data)
        , m_Start(data)
        , m_AxisIndex(0)
        , m_AxisFactor(axisFactor)
        , m_AxisDimensionality(axisDimensionality)
    {}

    // This should be called to set index for per-axis Encoder/Decoder: moves to flat offset `index`
    // and takes the caller-supplied axisIndex verbatim.
    PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        m_AxisIndex = axisIndex;
        return *this;
    }

    // Rebinds the iterator to `data`, which becomes the new origin, and resets the axis index.
    void Reset(void* data) override
    {
        m_Iterator = static_cast<T*>(data);
        m_Start = m_Iterator;
        m_AxisIndex = 0;
    }

    PerAxisIterator& operator++() override
    {
        BOOST_ASSERT(m_Iterator);
        ++m_Iterator;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator+=(const unsigned int increment) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator += increment;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator-=(const unsigned int decrement) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator -= decrement;
        UpdateAxisIndex();
        return *this;
    }

    // Seeks to flat offset `index` from the origin and recomputes the axis index for that offset.
    PerAxisIterator& operator[](const unsigned int index) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        UpdateAxisIndex();
        return *this;
    }

protected:
    T* m_Iterator;
    T* m_Start;
    unsigned int m_AxisIndex;
    unsigned int m_AxisFactor;
    unsigned int m_AxisDimensionality; // 0 = do not wrap the computed axis index

private:
    // Recomputes m_AxisIndex from the current *position*. The element's value is irrelevant to
    // which axis it belongs to. This never dereferences m_Iterator, because after the final
    // increment of a loop m_Iterator is legitimately one past the end of the buffer.
    void UpdateAxisIndex()
    {
        if (m_AxisFactor == 0)
        {
            // No factor supplied: the index cannot be derived (and dividing by it would be UB).
            return;
        }
        const auto offset = static_cast<unsigned int>(m_Iterator - m_Start);
        m_AxisIndex = offset / m_AxisFactor;
        if (m_AxisDimensionality != 0)
        {
            m_AxisIndex %= m_AxisDimensionality;
        }
    }
};
504
505class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
506{
507public:
508 QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
509 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
510
511 float Get() const override
512 {
513 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
514 }
515
516 // Get scale of the current value
517 float GetScale() const
518 {
519 return m_Scale[m_AxisIndex];
520 }
521
522private:
523 std::vector<float> m_Scale;
524};
525
526class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
527{
528public:
529 QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
530 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
531
532 void Set(float right)
533 {
534 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
535 }
536
537 float Get() const
538 {
539 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
540 }
541
542 // Get scale of the current value
543 float GetScale() const
544 {
545 return m_Scale[m_AxisIndex];
546 }
547
548private:
549 std::vector<float> m_Scale;
550};
551
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000552class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
553{
554public:
555 ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
556 : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
557
558 float Get() const override
559 {
560 return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
561 }
562
563 // Get scale of the current value
564 float GetScale() const
565 {
566 return m_Scales[m_AxisIndex];
567 }
568
569private:
570 std::vector<float> m_Scales;
571};
572
573} // namespace armnn