blob: f43e8b67a9740e1c4589765f7acbdd98d9cd98b0 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#pragma once

#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnnUtils/FloatingPointConverter.hpp>

#include <ResolveType.hpp>

#include <boost/assert.hpp>

#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
15
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010016namespace armnn
17{
18
// Type-erased cursor interface shared by Decoder<> and Encoder<>. It only describes how the
// cursor is moved; reading and writing elements is added by the derived interfaces.
class BaseIterator
{
public:
    BaseIterator() = default;

    virtual ~BaseIterator() = default;

    // Repositions the cursor at element 'index'. 'axisIndex' is an extra coordinate that some
    // implementations record and others ignore.
    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;

    // Steps the cursor forward by one element.
    virtual BaseIterator& operator++() = 0;

    // Steps the cursor forward by 'increment' elements.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    // Steps the cursor backward by 'increment' elements.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;

    // Repositions the cursor at 'index'. Note: returns the iterator itself, not an element.
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
36
Derek Lambertif30f7d32019-04-09 10:25:02 +010037template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010038class Decoder : public BaseIterator
39{
40public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010041 Decoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010042
43 virtual ~Decoder() {}
44
Matthew Benthamc394a6d2019-06-24 12:51:25 +010045 virtual void Reset(void*) = 0;
46
Derek Lambertif30f7d32019-04-09 10:25:02 +010047 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010048};
49
Derek Lambertif30f7d32019-04-09 10:25:02 +010050template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010051class Encoder : public BaseIterator
52{
53public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010054 Encoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010055
56 virtual ~Encoder() {}
57
Matthew Benthamc394a6d2019-06-24 12:51:25 +010058 virtual void Reset(void*) = 0;
59
Derek Lambertif30f7d32019-04-09 10:25:02 +010060 virtual void Set(IType right) = 0;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010061
62 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010063};
64
65template<typename T, typename Base>
66class TypedIterator : public Base
67{
68public:
Matthew Benthamc394a6d2019-06-24 12:51:25 +010069 TypedIterator(T* data = nullptr)
Francis Murtagh43aec582019-05-27 12:14:10 +010070 : m_Iterator(data), m_Start(data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010071 {}
72
Matthew Benthamc394a6d2019-06-24 12:51:25 +010073 void Reset(void* data) override
74 {
75 m_Iterator = reinterpret_cast<T*>(data);
76 m_Start = m_Iterator;
77 }
78
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010079 TypedIterator& operator++() override
80 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010081 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010082 ++m_Iterator;
83 return *this;
84 }
85
86 TypedIterator& operator+=(const unsigned int increment) override
87 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010088 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010089 m_Iterator += increment;
90 return *this;
91 }
92
93 TypedIterator& operator-=(const unsigned int increment) override
94 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010095 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010096 m_Iterator -= increment;
97 return *this;
98 }
99
Francis Murtagh43aec582019-05-27 12:14:10 +0100100 TypedIterator& operator[](const unsigned int index) override
101 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100102 BOOST_ASSERT(m_Iterator);
Francis Murtagh43aec582019-05-27 12:14:10 +0100103 m_Iterator = m_Start + index;
104 return *this;
105 }
106
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000107 TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
108 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000109 IgnoreUnused(axisIndex);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000110 BOOST_ASSERT(m_Iterator);
111 m_Iterator = m_Start + index;
112 return *this;
113 }
114
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100115protected:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100116 T* m_Iterator;
Francis Murtagh43aec582019-05-27 12:14:10 +0100117 T* m_Start;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100118};
119
Derek Lambertif30f7d32019-04-09 10:25:02 +0100120class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100121{
122public:
123 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
124 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
125
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100126 QASymm8Decoder(const float scale, const int32_t offset)
127 : QASymm8Decoder(nullptr, scale, offset) {}
128
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100129 float Get() const override
130 {
131 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
132 }
133
134private:
135 const float m_Scale;
136 const int32_t m_Offset;
137};
138
Ryan OShea9add1202020-02-07 10:06:33 +0000139class QASymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
140{
141public:
142 QASymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
143 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
144
145 QASymmS8Decoder(const float scale, const int32_t offset)
146 : QASymmS8Decoder(nullptr, scale, offset) {}
147
148 float Get() const override
149 {
150 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
151 }
152
153private:
154 const float m_Scale;
155 const int32_t m_Offset;
156};
157
Finn Williamsfd271062019-12-04 14:27:27 +0000158class QSymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
159{
160public:
161 QSymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
162 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
163
164 QSymmS8Decoder(const float scale, const int32_t offset)
165 : QSymmS8Decoder(nullptr, scale, offset) {}
166
167 float Get() const override
168 {
169 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
170 }
171
172private:
173 const float m_Scale;
174 const int32_t m_Offset;
175};
176
Derek Lambertif30f7d32019-04-09 10:25:02 +0100177class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
Sadik Armagan2999a022019-04-09 14:20:12 +0100178{
179public:
180 QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
181 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
182
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100183 QSymm16Decoder(const float scale, const int32_t offset)
184 : QSymm16Decoder(nullptr, scale, offset) {}
185
Sadik Armagan2999a022019-04-09 14:20:12 +0100186 float Get() const override
187 {
188 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
189 }
190
191private:
192 const float m_Scale;
193 const int32_t m_Offset;
194};
195
Narumol Prangnawarat88325222020-03-06 14:45:57 +0000196class BFloat16Decoder : public TypedIterator<const BFloat16, Decoder<float>>
197{
198public:
199 BFloat16Decoder(const BFloat16* data)
200 : TypedIterator(data) {}
201
202 BFloat16Decoder()
203 : BFloat16Decoder(nullptr) {}
204
205 float Get() const override
206 {
207 float val = 0.f;
208 armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
209 return val;
210 }
211};
212
Matthew Jacksone69c3992019-09-09 14:31:21 +0100213class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100214{
215public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100216 Float16Decoder(const Half* data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100217 : TypedIterator(data) {}
218
Matthew Jacksone69c3992019-09-09 14:31:21 +0100219 Float16Decoder()
220 : Float16Decoder(nullptr) {}
221
222 float Get() const override
223 {
224 float val = 0.f;
225 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
226 return val;
227 }
228};
229
230class Float32Decoder : public TypedIterator<const float, Decoder<float>>
231{
232public:
233 Float32Decoder(const float* data)
234 : TypedIterator(data) {}
235
236 Float32Decoder()
237 : Float32Decoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100238
Derek Lambertif30f7d32019-04-09 10:25:02 +0100239 float Get() const override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100240 {
Derek Lambertif30f7d32019-04-09 10:25:02 +0100241 return *m_Iterator;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100242 }
243};
244
Mike Kelly9b398322019-05-22 17:21:49 +0100245class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
246{
247public:
248 ScaledInt32Decoder(const int32_t* data, const float scale)
249 : TypedIterator(data), m_Scale(scale) {}
250
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100251 ScaledInt32Decoder(const float scale)
252 : ScaledInt32Decoder(nullptr, scale) {}
253
Mike Kelly9b398322019-05-22 17:21:49 +0100254 float Get() const override
255 {
256 return static_cast<float>(*m_Iterator) * m_Scale;
257 }
258
259private:
260 const float m_Scale;
261};
262
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100263class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
264{
265public:
266 Int32Decoder(const int32_t* data)
267 : TypedIterator(data) {}
268
269 Int32Decoder()
270 : Int32Decoder(nullptr) {}
271
272 float Get() const override
273 {
274 return static_cast<float>(*m_Iterator);
275 }
276};
277
Sadik Armaganb60dd242020-03-19 13:53:16 +0000278class BooleanDecoder : public TypedIterator<const uint8_t, Decoder<float>>
279{
280public:
281 BooleanDecoder(const uint8_t* data)
282 : TypedIterator(data) {}
283
284 BooleanDecoder()
285 : BooleanDecoder(nullptr) {}
286
287 float Get() const override
288 {
289 return *m_Iterator;
290 }
291
292};
293
Derek Lambertif30f7d32019-04-09 10:25:02 +0100294class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100295{
296public:
297 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
298 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
299
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100300 QASymm8Encoder(const float scale, const int32_t offset)
301 : QASymm8Encoder(nullptr, scale, offset) {}
302
Derek Lambertif30f7d32019-04-09 10:25:02 +0100303 void Set(float right) override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100304 {
305 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
306 }
307
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100308 float Get() const override
309 {
310 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
311 }
312
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100313private:
314 const float m_Scale;
315 const int32_t m_Offset;
316};
317
Ryan OShea9add1202020-02-07 10:06:33 +0000318class QASymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
319{
320public:
321 QASymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
322 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
323
324 QASymmS8Encoder(const float scale, const int32_t offset)
325 : QASymmS8Encoder(nullptr, scale, offset) {}
326
327 void Set(float right) override
328 {
329 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
330 }
331
332 float Get() const override
333 {
334 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
335 }
336
337private:
338 const float m_Scale;
339 const int32_t m_Offset;
340};
341
Finn Williamsfd271062019-12-04 14:27:27 +0000342class QSymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
343{
344public:
345 QSymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
346 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
347
348 QSymmS8Encoder(const float scale, const int32_t offset)
349 : QSymmS8Encoder(nullptr, scale, offset) {}
350
351 void Set(float right) override
352 {
353 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
354 }
355
356 float Get() const override
357 {
358 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
359 }
360
361private:
362 const float m_Scale;
363 const int32_t m_Offset;
364};
365
Derek Lambertif30f7d32019-04-09 10:25:02 +0100366class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
367{
368public:
369 QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
370 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
371
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100372 QSymm16Encoder(const float scale, const int32_t offset)
373 : QSymm16Encoder(nullptr, scale, offset) {}
374
Derek Lambertif30f7d32019-04-09 10:25:02 +0100375 void Set(float right) override
376 {
377 *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
378 }
379
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100380 float Get() const override
381 {
382 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
383 }
384
Derek Lambertif30f7d32019-04-09 10:25:02 +0100385private:
386 const float m_Scale;
387 const int32_t m_Offset;
388};
389
Narumol Prangnawarat88325222020-03-06 14:45:57 +0000390class BFloat16Encoder : public TypedIterator<armnn::BFloat16, Encoder<float>>
391{
392public:
393 BFloat16Encoder(armnn::BFloat16* data)
394 : TypedIterator(data) {}
395
396 BFloat16Encoder()
397 : BFloat16Encoder(nullptr) {}
398
399 void Set(float right) override
400 {
401 armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(&right, 1, m_Iterator);
402 }
403
404 float Get() const override
405 {
406 float val = 0.f;
407 armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
408 return val;
409 }
410};
411
Matthew Jacksone69c3992019-09-09 14:31:21 +0100412class Float16Encoder : public TypedIterator<Half, Encoder<float>>
Derek Lambertif30f7d32019-04-09 10:25:02 +0100413{
414public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100415 Float16Encoder(Half* data)
Derek Lambertif30f7d32019-04-09 10:25:02 +0100416 : TypedIterator(data) {}
417
Matthew Jacksone69c3992019-09-09 14:31:21 +0100418 Float16Encoder()
419 : Float16Encoder(nullptr) {}
420
421 void Set(float right) override
422 {
423 armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
424 }
425
426 float Get() const override
427 {
428 float val = 0.f;
429 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
430 return val;
431 }
432};
433
434class Float32Encoder : public TypedIterator<float, Encoder<float>>
435{
436public:
437 Float32Encoder(float* data)
438 : TypedIterator(data) {}
439
440 Float32Encoder()
441 : Float32Encoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100442
Derek Lambertif30f7d32019-04-09 10:25:02 +0100443 void Set(float right) override
444 {
445 *m_Iterator = right;
446 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100447
448 float Get() const override
449 {
450 return *m_Iterator;
451 }
Derek Lambertif30f7d32019-04-09 10:25:02 +0100452};
453
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100454class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
455{
456public:
457 Int32Encoder(int32_t* data)
458 : TypedIterator(data) {}
459
460 Int32Encoder()
461 : Int32Encoder(nullptr) {}
462
463 void Set(float right) override
464 {
465 *m_Iterator = static_cast<int32_t>(right);
466 }
467
468 float Get() const override
469 {
470 return static_cast<float>(*m_Iterator);
471 }
472};
473
Derek Lambertif30f7d32019-04-09 10:25:02 +0100474class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100475{
476public:
477 BooleanEncoder(uint8_t* data)
478 : TypedIterator(data) {}
479
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100480 BooleanEncoder()
481 : BooleanEncoder(nullptr) {}
482
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100483 void Set(bool right) override
484 {
485 *m_Iterator = right;
486 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100487
488 bool Get() const override
489 {
490 return *m_Iterator;
491 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100492};
493
Keith Davis5236e1d2019-11-04 08:58:33 +0000494// PerAxisIterator for per-axis quantization
495template<typename T, typename Base>
496class PerAxisIterator : public Base
497{
498public:
499 // axisFactor is used to calculate axisIndex
500 PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
501 : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
502 {}
503
504 // This should be called to set index for per-axis Encoder/Decoder
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000505 PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
Keith Davis5236e1d2019-11-04 08:58:33 +0000506 {
507 BOOST_ASSERT(m_Iterator);
508 m_Iterator = m_Start + index;
509 m_AxisIndex = axisIndex;
510 return *this;
511 }
512
513 void Reset(void* data) override
514 {
515 m_Iterator = reinterpret_cast<T*>(data);
516 m_Start = m_Iterator;
517 m_AxisIndex = 0;
518 }
519
520 PerAxisIterator& operator++() override
521 {
522 BOOST_ASSERT(m_Iterator);
523 ++m_Iterator;
524 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
525 return *this;
526 }
527
528 PerAxisIterator& operator+=(const unsigned int increment) override
529 {
530 BOOST_ASSERT(m_Iterator);
531 m_Iterator += increment;
532 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
533 return *this;
534 }
535
536 PerAxisIterator& operator-=(const unsigned int decrement) override
537 {
538 BOOST_ASSERT(m_Iterator);
539 m_Iterator -= decrement;
540 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
541 return *this;
542 }
543
544 PerAxisIterator& operator[](const unsigned int index) override
545 {
546 BOOST_ASSERT(m_Iterator);
547 m_Iterator = m_Start + index;
548 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
549 return *this;
550 }
551
552 protected:
553 T* m_Iterator;
554 T* m_Start;
555 unsigned int m_AxisIndex;
556 unsigned int m_AxisFactor;
557};
558
559class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
560{
561public:
562 QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
563 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
564
565 float Get() const override
566 {
567 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
568 }
569
570 // Get scale of the current value
571 float GetScale() const
572 {
573 return m_Scale[m_AxisIndex];
574 }
575
576private:
577 std::vector<float> m_Scale;
578};
579
580class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
581{
582public:
583 QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
584 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
585
586 void Set(float right)
587 {
588 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
589 }
590
591 float Get() const
592 {
593 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
594 }
595
596 // Get scale of the current value
597 float GetScale() const
598 {
599 return m_Scale[m_AxisIndex];
600 }
601
602private:
603 std::vector<float> m_Scale;
604};
605
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000606class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
607{
608public:
609 ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
610 : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
611
612 float Get() const override
613 {
614 return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
615 }
616
617 // Get scale of the current value
618 float GetScale() const
619 {
620 return m_Scales[m_AxisIndex];
621 }
622
623private:
624 std::vector<float> m_Scales;
625};
626
627} // namespace armnn