blob: 1f4f2da717147552b529fcf26bad787b919354e0 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Matthew Bentham246bd462020-01-20 16:16:06 +00008#include <armnn/TypesUtils.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01009#include <armnn/utility/Assert.hpp>
10#include <armnn/utility/IgnoreUnused.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000011#include <armnnUtils/FloatingPointConverter.hpp>
12
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010013#include <ResolveType.hpp>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010014
15namespace armnn
16{
17
// Abstract navigation interface shared by every Decoder and Encoder.
// It only describes how the current position is moved; reading and writing
// of elements is declared by the derived interfaces.
class BaseIterator
{
public:
    BaseIterator() = default;

    virtual ~BaseIterator() = default;

    // Moves to the element at 'index'. 'axisIndex' is additional information that
    // implementations may use or ignore; it defaults to 0.
    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;

    // Moves forward by one element.
    virtual BaseIterator& operator++() = 0;

    // Moves forward by 'increment' elements.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    // Moves backward by 'increment' elements.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;

    // Moves to the element at 'index'.
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
35
Derek Lambertif30f7d32019-04-09 10:25:02 +010036template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010037class Decoder : public BaseIterator
38{
39public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010040 Decoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010041
42 virtual ~Decoder() {}
43
Matthew Benthamc394a6d2019-06-24 12:51:25 +010044 virtual void Reset(void*) = 0;
45
Derek Lambertif30f7d32019-04-09 10:25:02 +010046 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010047};
48
Derek Lambertif30f7d32019-04-09 10:25:02 +010049template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010050class Encoder : public BaseIterator
51{
52public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010053 Encoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010054
55 virtual ~Encoder() {}
56
Matthew Benthamc394a6d2019-06-24 12:51:25 +010057 virtual void Reset(void*) = 0;
58
Derek Lambertif30f7d32019-04-09 10:25:02 +010059 virtual void Set(IType right) = 0;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010060
61 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010062};
63
64template<typename T, typename Base>
65class TypedIterator : public Base
66{
67public:
Matthew Benthamc394a6d2019-06-24 12:51:25 +010068 TypedIterator(T* data = nullptr)
Francis Murtagh43aec582019-05-27 12:14:10 +010069 : m_Iterator(data), m_Start(data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010070 {}
71
Matthew Benthamc394a6d2019-06-24 12:51:25 +010072 void Reset(void* data) override
73 {
74 m_Iterator = reinterpret_cast<T*>(data);
75 m_Start = m_Iterator;
76 }
77
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010078 TypedIterator& operator++() override
79 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010080 ARMNN_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010081 ++m_Iterator;
82 return *this;
83 }
84
85 TypedIterator& operator+=(const unsigned int increment) override
86 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010087 ARMNN_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010088 m_Iterator += increment;
89 return *this;
90 }
91
92 TypedIterator& operator-=(const unsigned int increment) override
93 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010094 ARMNN_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010095 m_Iterator -= increment;
96 return *this;
97 }
98
Francis Murtagh43aec582019-05-27 12:14:10 +010099 TypedIterator& operator[](const unsigned int index) override
100 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100101 ARMNN_ASSERT(m_Iterator);
Francis Murtagh43aec582019-05-27 12:14:10 +0100102 m_Iterator = m_Start + index;
103 return *this;
104 }
105
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000106 TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
107 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000108 IgnoreUnused(axisIndex);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100109 ARMNN_ASSERT(m_Iterator);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000110 m_Iterator = m_Start + index;
111 return *this;
112 }
113
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100114protected:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100115 T* m_Iterator;
Francis Murtagh43aec582019-05-27 12:14:10 +0100116 T* m_Start;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100117};
118
Derek Lambertif30f7d32019-04-09 10:25:02 +0100119class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100120{
121public:
122 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
123 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
124
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100125 QASymm8Decoder(const float scale, const int32_t offset)
126 : QASymm8Decoder(nullptr, scale, offset) {}
127
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100128 float Get() const override
129 {
130 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
131 }
132
133private:
134 const float m_Scale;
135 const int32_t m_Offset;
136};
137
Ryan OShea9add1202020-02-07 10:06:33 +0000138class QASymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
139{
140public:
141 QASymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
142 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
143
144 QASymmS8Decoder(const float scale, const int32_t offset)
145 : QASymmS8Decoder(nullptr, scale, offset) {}
146
147 float Get() const override
148 {
149 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
150 }
151
152private:
153 const float m_Scale;
154 const int32_t m_Offset;
155};
156
Finn Williamsfd271062019-12-04 14:27:27 +0000157class QSymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
158{
159public:
160 QSymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
161 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
162
163 QSymmS8Decoder(const float scale, const int32_t offset)
164 : QSymmS8Decoder(nullptr, scale, offset) {}
165
166 float Get() const override
167 {
168 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
169 }
170
171private:
172 const float m_Scale;
173 const int32_t m_Offset;
174};
175
Derek Lambertif30f7d32019-04-09 10:25:02 +0100176class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
Sadik Armagan2999a022019-04-09 14:20:12 +0100177{
178public:
179 QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
180 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
181
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100182 QSymm16Decoder(const float scale, const int32_t offset)
183 : QSymm16Decoder(nullptr, scale, offset) {}
184
Sadik Armagan2999a022019-04-09 14:20:12 +0100185 float Get() const override
186 {
187 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
188 }
189
190private:
191 const float m_Scale;
192 const int32_t m_Offset;
193};
194
Narumol Prangnawarat88325222020-03-06 14:45:57 +0000195class BFloat16Decoder : public TypedIterator<const BFloat16, Decoder<float>>
196{
197public:
198 BFloat16Decoder(const BFloat16* data)
199 : TypedIterator(data) {}
200
201 BFloat16Decoder()
202 : BFloat16Decoder(nullptr) {}
203
204 float Get() const override
205 {
206 float val = 0.f;
207 armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
208 return val;
209 }
210};
211
Matthew Jacksone69c3992019-09-09 14:31:21 +0100212class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100213{
214public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100215 Float16Decoder(const Half* data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100216 : TypedIterator(data) {}
217
Matthew Jacksone69c3992019-09-09 14:31:21 +0100218 Float16Decoder()
219 : Float16Decoder(nullptr) {}
220
221 float Get() const override
222 {
223 float val = 0.f;
224 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
225 return val;
226 }
227};
228
229class Float32Decoder : public TypedIterator<const float, Decoder<float>>
230{
231public:
232 Float32Decoder(const float* data)
233 : TypedIterator(data) {}
234
235 Float32Decoder()
236 : Float32Decoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100237
Derek Lambertif30f7d32019-04-09 10:25:02 +0100238 float Get() const override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100239 {
Derek Lambertif30f7d32019-04-09 10:25:02 +0100240 return *m_Iterator;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100241 }
242};
243
Mike Kelly9b398322019-05-22 17:21:49 +0100244class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
245{
246public:
247 ScaledInt32Decoder(const int32_t* data, const float scale)
248 : TypedIterator(data), m_Scale(scale) {}
249
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100250 ScaledInt32Decoder(const float scale)
251 : ScaledInt32Decoder(nullptr, scale) {}
252
Mike Kelly9b398322019-05-22 17:21:49 +0100253 float Get() const override
254 {
255 return static_cast<float>(*m_Iterator) * m_Scale;
256 }
257
258private:
259 const float m_Scale;
260};
261
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100262class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
263{
264public:
265 Int32Decoder(const int32_t* data)
266 : TypedIterator(data) {}
267
268 Int32Decoder()
269 : Int32Decoder(nullptr) {}
270
271 float Get() const override
272 {
273 return static_cast<float>(*m_Iterator);
274 }
275};
276
Finn Williamscbd2c232020-06-22 15:58:32 +0100277class Int32ToInt32tDecoder : public TypedIterator<const int32_t, Decoder<int32_t>>
278{
279public:
280 Int32ToInt32tDecoder(const int32_t* data)
281 : TypedIterator(data){}
282
283 Int32ToInt32tDecoder()
284 : Int32ToInt32tDecoder(nullptr) {}
285
286 int32_t Get() const override
287 {
288 return *m_Iterator;
289 }
290};
291
Sadik Armaganb60dd242020-03-19 13:53:16 +0000292class BooleanDecoder : public TypedIterator<const uint8_t, Decoder<float>>
293{
294public:
295 BooleanDecoder(const uint8_t* data)
296 : TypedIterator(data) {}
297
298 BooleanDecoder()
299 : BooleanDecoder(nullptr) {}
300
301 float Get() const override
302 {
303 return *m_Iterator;
304 }
305
306};
307
Derek Lambertif30f7d32019-04-09 10:25:02 +0100308class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100309{
310public:
311 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
312 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
313
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100314 QASymm8Encoder(const float scale, const int32_t offset)
315 : QASymm8Encoder(nullptr, scale, offset) {}
316
Derek Lambertif30f7d32019-04-09 10:25:02 +0100317 void Set(float right) override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100318 {
319 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
320 }
321
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100322 float Get() const override
323 {
324 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
325 }
326
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100327private:
328 const float m_Scale;
329 const int32_t m_Offset;
330};
331
Ryan OShea9add1202020-02-07 10:06:33 +0000332class QASymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
333{
334public:
335 QASymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
336 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
337
338 QASymmS8Encoder(const float scale, const int32_t offset)
339 : QASymmS8Encoder(nullptr, scale, offset) {}
340
341 void Set(float right) override
342 {
343 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
344 }
345
346 float Get() const override
347 {
348 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
349 }
350
351private:
352 const float m_Scale;
353 const int32_t m_Offset;
354};
355
Finn Williamsfd271062019-12-04 14:27:27 +0000356class QSymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
357{
358public:
359 QSymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
360 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
361
362 QSymmS8Encoder(const float scale, const int32_t offset)
363 : QSymmS8Encoder(nullptr, scale, offset) {}
364
365 void Set(float right) override
366 {
367 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
368 }
369
370 float Get() const override
371 {
372 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
373 }
374
375private:
376 const float m_Scale;
377 const int32_t m_Offset;
378};
379
Derek Lambertif30f7d32019-04-09 10:25:02 +0100380class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
381{
382public:
383 QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
384 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
385
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100386 QSymm16Encoder(const float scale, const int32_t offset)
387 : QSymm16Encoder(nullptr, scale, offset) {}
388
Derek Lambertif30f7d32019-04-09 10:25:02 +0100389 void Set(float right) override
390 {
391 *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
392 }
393
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100394 float Get() const override
395 {
396 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
397 }
398
Derek Lambertif30f7d32019-04-09 10:25:02 +0100399private:
400 const float m_Scale;
401 const int32_t m_Offset;
402};
403
Narumol Prangnawarat88325222020-03-06 14:45:57 +0000404class BFloat16Encoder : public TypedIterator<armnn::BFloat16, Encoder<float>>
405{
406public:
407 BFloat16Encoder(armnn::BFloat16* data)
408 : TypedIterator(data) {}
409
410 BFloat16Encoder()
411 : BFloat16Encoder(nullptr) {}
412
413 void Set(float right) override
414 {
415 armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(&right, 1, m_Iterator);
416 }
417
418 float Get() const override
419 {
420 float val = 0.f;
421 armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
422 return val;
423 }
424};
425
Matthew Jacksone69c3992019-09-09 14:31:21 +0100426class Float16Encoder : public TypedIterator<Half, Encoder<float>>
Derek Lambertif30f7d32019-04-09 10:25:02 +0100427{
428public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100429 Float16Encoder(Half* data)
Derek Lambertif30f7d32019-04-09 10:25:02 +0100430 : TypedIterator(data) {}
431
Matthew Jacksone69c3992019-09-09 14:31:21 +0100432 Float16Encoder()
433 : Float16Encoder(nullptr) {}
434
435 void Set(float right) override
436 {
437 armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
438 }
439
440 float Get() const override
441 {
442 float val = 0.f;
443 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
444 return val;
445 }
446};
447
448class Float32Encoder : public TypedIterator<float, Encoder<float>>
449{
450public:
451 Float32Encoder(float* data)
452 : TypedIterator(data) {}
453
454 Float32Encoder()
455 : Float32Encoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100456
Derek Lambertif30f7d32019-04-09 10:25:02 +0100457 void Set(float right) override
458 {
459 *m_Iterator = right;
460 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100461
462 float Get() const override
463 {
464 return *m_Iterator;
465 }
Derek Lambertif30f7d32019-04-09 10:25:02 +0100466};
467
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100468class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
469{
470public:
471 Int32Encoder(int32_t* data)
472 : TypedIterator(data) {}
473
474 Int32Encoder()
475 : Int32Encoder(nullptr) {}
476
477 void Set(float right) override
478 {
479 *m_Iterator = static_cast<int32_t>(right);
480 }
481
482 float Get() const override
483 {
484 return static_cast<float>(*m_Iterator);
485 }
486};
487
Finn Williamscbd2c232020-06-22 15:58:32 +0100488class Int32ToInt32tEncoder : public TypedIterator<int32_t, Encoder<int32_t>>
489{
490public:
491 Int32ToInt32tEncoder(int32_t* data)
492 : TypedIterator(data){}
493
494 Int32ToInt32tEncoder()
495 : Int32ToInt32tEncoder(nullptr) {}
496
497 void Set(int32_t right) override
498 {
499 *m_Iterator = right;
500 }
501
502 int32_t Get() const override
503 {
504 return *m_Iterator;
505 }
506};
507
Derek Lambertif30f7d32019-04-09 10:25:02 +0100508class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100509{
510public:
511 BooleanEncoder(uint8_t* data)
512 : TypedIterator(data) {}
513
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100514 BooleanEncoder()
515 : BooleanEncoder(nullptr) {}
516
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100517 void Set(bool right) override
518 {
519 *m_Iterator = right;
520 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100521
522 bool Get() const override
523 {
524 return *m_Iterator;
525 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100526};
527
// PerAxisIterator for per-axis quantization.
// Implements the navigation part of 'Base' (a Decoder or Encoder interface) on top of a raw
// pointer to elements of type T, and additionally tracks an axis index that derived classes
// use to select the per-axis quantization parameters of the current element.
template<typename T, typename Base>
class PerAxisIterator : public Base
{
public:
    // axisFactor is used to calculate axisIndex
    PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
        : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
    {}

    // This should be called to set index for per-axis Encoder/Decoder.
    // Positions the iterator 'index' elements after the start and takes 'axisIndex' as given.
    PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
    {
        ARMNN_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        m_AxisIndex = axisIndex;
        return *this;
    }

    // Rebinds the iterator to a new buffer, rewinds to its first element and
    // resets the axis index to 0.
    void Reset(void* data) override
    {
        m_Iterator = reinterpret_cast<T*>(data);
        m_Start = m_Iterator;
        m_AxisIndex = 0;
    }

    // Advances by one element and recomputes the axis index.
    PerAxisIterator& operator++() override
    {
        ARMNN_ASSERT(m_Iterator);
        ++m_Iterator;
        UpdateAxisIndex();
        return *this;
    }

    // Advances by 'increment' elements and recomputes the axis index.
    PerAxisIterator& operator+=(const unsigned int increment) override
    {
        ARMNN_ASSERT(m_Iterator);
        m_Iterator += increment;
        UpdateAxisIndex();
        return *this;
    }

    // Moves back by 'decrement' elements and recomputes the axis index.
    PerAxisIterator& operator-=(const unsigned int decrement) override
    {
        ARMNN_ASSERT(m_Iterator);
        m_Iterator -= decrement;
        UpdateAxisIndex();
        return *this;
    }

    // Positions the iterator 'index' elements after the start and recomputes the axis index.
    PerAxisIterator& operator[](const unsigned int index) override
    {
        ARMNN_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        UpdateAxisIndex();
        return *this;
    }

protected:
    T* m_Iterator;
    T* m_Start;
    unsigned int m_AxisIndex;
    unsigned int m_AxisFactor;

private:
    // Recomputes m_AxisIndex after the position has changed.
    //
    // A zero axis factor (the constructor default) would make the modulo below undefined
    // behaviour, so in that case the axis index falls back to 0 without touching the data.
    //
    // NOTE(review): the axis index is derived from the *value* stored at the new position,
    // not from the position itself, and the element is read immediately after moving (so
    // moving to one-past-the-end reads outside the buffer). This looks unintended; callers
    // that need a reliable axis index should use SetIndex(index, axisIndex). Confirm the
    // intended formula before changing it.
    void UpdateAxisIndex()
    {
        if (m_AxisFactor == 0)
        {
            m_AxisIndex = 0;
            return;
        }
        m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
    }
};
592
593class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
594{
595public:
596 QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
597 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
598
599 float Get() const override
600 {
601 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
602 }
603
604 // Get scale of the current value
605 float GetScale() const
606 {
607 return m_Scale[m_AxisIndex];
608 }
609
610private:
611 std::vector<float> m_Scale;
612};
613
614class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
615{
616public:
617 QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
618 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
619
620 void Set(float right)
621 {
622 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
623 }
624
625 float Get() const
626 {
627 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
628 }
629
630 // Get scale of the current value
631 float GetScale() const
632 {
633 return m_Scale[m_AxisIndex];
634 }
635
636private:
637 std::vector<float> m_Scale;
638};
639
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000640class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
641{
642public:
643 ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
644 : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
645
646 float Get() const override
647 {
648 return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
649 }
650
651 // Get scale of the current value
652 float GetScale() const
653 {
654 return m_Scales[m_AxisIndex];
655 }
656
657private:
658 std::vector<float> m_Scales;
659};
660
661} // namespace armnn