blob: 50475312a559fc8728bdbb96ab699743b8118214 [file] [log] [blame]
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Jacksone69c3992019-09-09 14:31:21 +01008#include "FloatingPointConverter.hpp"
9
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010010#include <armnn/ArmNN.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010011#include <ResolveType.hpp>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010012
Matthew Benthamc394a6d2019-06-24 12:51:25 +010013#include <boost/assert.hpp>
14
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010015namespace armnn
16{
17
/// Abstract cursor over a tensor buffer.
/// Movement is expressed through four operators:
///  - single step forward (++),
///  - relative jumps (+=, -=),
///  - indexed repositioning ([]).
/// Every operator returns the iterator itself so calls can be chained.
/// In TypedIterator and PerAxisIterator, [] repositions relative to the buffer start.
class BaseIterator
{
public:
    BaseIterator() = default;
    virtual ~BaseIterator() = default;

    virtual BaseIterator& operator++() = 0;
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
33
Derek Lambertif30f7d32019-04-09 10:25:02 +010034template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010035class Decoder : public BaseIterator
36{
37public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010038 Decoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010039
40 virtual ~Decoder() {}
41
Matthew Benthamc394a6d2019-06-24 12:51:25 +010042 virtual void Reset(void*) = 0;
43
Derek Lambertif30f7d32019-04-09 10:25:02 +010044 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010045};
46
Derek Lambertif30f7d32019-04-09 10:25:02 +010047template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010048class Encoder : public BaseIterator
49{
50public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010051 Encoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010052
53 virtual ~Encoder() {}
54
Matthew Benthamc394a6d2019-06-24 12:51:25 +010055 virtual void Reset(void*) = 0;
56
Derek Lambertif30f7d32019-04-09 10:25:02 +010057 virtual void Set(IType right) = 0;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010058
59 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010060};
61
/// Pointer-backed implementation of the BaseIterator movement operations for elements of type T.
/// It remembers the buffer origin (m_Start) so that operator[] can reposition relative to it.
/// Base is expected to be a Decoder/Encoder that declares Reset(void*).
template<typename T, typename Base>
class TypedIterator : public Base
{
public:
    /// Wraps @p data; it may be null, in which case a buffer must be attached later via Reset().
    TypedIterator(T* data = nullptr)
        : m_Iterator(data)
        , m_Start(data)
    {}

    /// Moves both the cursor and the origin to @p data.
    void Reset(void* data) override
    {
        m_Start = static_cast<T*>(data);
        m_Iterator = m_Start;
    }

    TypedIterator& operator++() override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator += 1;
        return *this;
    }

    TypedIterator& operator+=(const unsigned int increment) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Iterator + increment;
        return *this;
    }

    TypedIterator& operator-=(const unsigned int increment) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Iterator - increment;
        return *this;
    }

    /// Absolute positioning: the cursor moves to element @p index counted from the buffer origin.
    TypedIterator& operator[](const unsigned int index) override
    {
        BOOST_ASSERT(m_Iterator);
        T* const target = m_Start + index;
        m_Iterator = target;
        return *this;
    }

protected:
    T* m_Iterator; // current element
    T* m_Start;    // buffer origin used by operator[]
};
108
Derek Lambertif30f7d32019-04-09 10:25:02 +0100109class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100110{
111public:
112 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
113 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
114
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100115 QASymm8Decoder(const float scale, const int32_t offset)
116 : QASymm8Decoder(nullptr, scale, offset) {}
117
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100118 float Get() const override
119 {
120 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
121 }
122
123private:
124 const float m_Scale;
125 const int32_t m_Offset;
126};
127
Derek Lambertif30f7d32019-04-09 10:25:02 +0100128class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
Sadik Armagan2999a022019-04-09 14:20:12 +0100129{
130public:
131 QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
132 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
133
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100134 QSymm16Decoder(const float scale, const int32_t offset)
135 : QSymm16Decoder(nullptr, scale, offset) {}
136
Sadik Armagan2999a022019-04-09 14:20:12 +0100137 float Get() const override
138 {
139 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
140 }
141
142private:
143 const float m_Scale;
144 const int32_t m_Offset;
145};
146
Matthew Jacksone69c3992019-09-09 14:31:21 +0100147class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100148{
149public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100150 Float16Decoder(const Half* data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100151 : TypedIterator(data) {}
152
Matthew Jacksone69c3992019-09-09 14:31:21 +0100153 Float16Decoder()
154 : Float16Decoder(nullptr) {}
155
156 float Get() const override
157 {
158 float val = 0.f;
159 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
160 return val;
161 }
162};
163
164class Float32Decoder : public TypedIterator<const float, Decoder<float>>
165{
166public:
167 Float32Decoder(const float* data)
168 : TypedIterator(data) {}
169
170 Float32Decoder()
171 : Float32Decoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100172
Derek Lambertif30f7d32019-04-09 10:25:02 +0100173 float Get() const override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100174 {
Derek Lambertif30f7d32019-04-09 10:25:02 +0100175 return *m_Iterator;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100176 }
177};
178
Mike Kelly9b398322019-05-22 17:21:49 +0100179class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
180{
181public:
182 ScaledInt32Decoder(const int32_t* data, const float scale)
183 : TypedIterator(data), m_Scale(scale) {}
184
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100185 ScaledInt32Decoder(const float scale)
186 : ScaledInt32Decoder(nullptr, scale) {}
187
Mike Kelly9b398322019-05-22 17:21:49 +0100188 float Get() const override
189 {
190 return static_cast<float>(*m_Iterator) * m_Scale;
191 }
192
193private:
194 const float m_Scale;
195};
196
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100197class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
198{
199public:
200 Int32Decoder(const int32_t* data)
201 : TypedIterator(data) {}
202
203 Int32Decoder()
204 : Int32Decoder(nullptr) {}
205
206 float Get() const override
207 {
208 return static_cast<float>(*m_Iterator);
209 }
210};
211
Derek Lambertif30f7d32019-04-09 10:25:02 +0100212class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100213{
214public:
215 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
216 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
217
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100218 QASymm8Encoder(const float scale, const int32_t offset)
219 : QASymm8Encoder(nullptr, scale, offset) {}
220
Derek Lambertif30f7d32019-04-09 10:25:02 +0100221 void Set(float right) override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100222 {
223 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
224 }
225
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100226 float Get() const override
227 {
228 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
229 }
230
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100231private:
232 const float m_Scale;
233 const int32_t m_Offset;
234};
235
Derek Lambertif30f7d32019-04-09 10:25:02 +0100236class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
237{
238public:
239 QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
240 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
241
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100242 QSymm16Encoder(const float scale, const int32_t offset)
243 : QSymm16Encoder(nullptr, scale, offset) {}
244
Derek Lambertif30f7d32019-04-09 10:25:02 +0100245 void Set(float right) override
246 {
247 *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
248 }
249
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100250 float Get() const override
251 {
252 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
253 }
254
Derek Lambertif30f7d32019-04-09 10:25:02 +0100255private:
256 const float m_Scale;
257 const int32_t m_Offset;
258};
259
Matthew Jacksone69c3992019-09-09 14:31:21 +0100260class Float16Encoder : public TypedIterator<Half, Encoder<float>>
Derek Lambertif30f7d32019-04-09 10:25:02 +0100261{
262public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100263 Float16Encoder(Half* data)
Derek Lambertif30f7d32019-04-09 10:25:02 +0100264 : TypedIterator(data) {}
265
Matthew Jacksone69c3992019-09-09 14:31:21 +0100266 Float16Encoder()
267 : Float16Encoder(nullptr) {}
268
269 void Set(float right) override
270 {
271 armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
272 }
273
274 float Get() const override
275 {
276 float val = 0.f;
277 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
278 return val;
279 }
280};
281
282class Float32Encoder : public TypedIterator<float, Encoder<float>>
283{
284public:
285 Float32Encoder(float* data)
286 : TypedIterator(data) {}
287
288 Float32Encoder()
289 : Float32Encoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100290
Derek Lambertif30f7d32019-04-09 10:25:02 +0100291 void Set(float right) override
292 {
293 *m_Iterator = right;
294 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100295
296 float Get() const override
297 {
298 return *m_Iterator;
299 }
Derek Lambertif30f7d32019-04-09 10:25:02 +0100300};
301
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100302class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
303{
304public:
305 Int32Encoder(int32_t* data)
306 : TypedIterator(data) {}
307
308 Int32Encoder()
309 : Int32Encoder(nullptr) {}
310
311 void Set(float right) override
312 {
313 *m_Iterator = static_cast<int32_t>(right);
314 }
315
316 float Get() const override
317 {
318 return static_cast<float>(*m_Iterator);
319 }
320};
321
Derek Lambertif30f7d32019-04-09 10:25:02 +0100322class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100323{
324public:
325 BooleanEncoder(uint8_t* data)
326 : TypedIterator(data) {}
327
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100328 BooleanEncoder()
329 : BooleanEncoder(nullptr) {}
330
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100331 void Set(bool right) override
332 {
333 *m_Iterator = right;
334 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100335
336 bool Get() const override
337 {
338 return *m_Iterator;
339 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100340};
341
Keith Davis5236e1d2019-11-04 08:58:33 +0000342// PerAxisIterator for per-axis quantization
343template<typename T, typename Base>
344class PerAxisIterator : public Base
345{
346public:
347 // axisFactor is used to calculate axisIndex
348 PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
349 : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
350 {}
351
352 // This should be called to set index for per-axis Encoder/Decoder
353 PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex)
354 {
355 BOOST_ASSERT(m_Iterator);
356 m_Iterator = m_Start + index;
357 m_AxisIndex = axisIndex;
358 return *this;
359 }
360
361 void Reset(void* data) override
362 {
363 m_Iterator = reinterpret_cast<T*>(data);
364 m_Start = m_Iterator;
365 m_AxisIndex = 0;
366 }
367
368 PerAxisIterator& operator++() override
369 {
370 BOOST_ASSERT(m_Iterator);
371 ++m_Iterator;
372 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
373 return *this;
374 }
375
376 PerAxisIterator& operator+=(const unsigned int increment) override
377 {
378 BOOST_ASSERT(m_Iterator);
379 m_Iterator += increment;
380 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
381 return *this;
382 }
383
384 PerAxisIterator& operator-=(const unsigned int decrement) override
385 {
386 BOOST_ASSERT(m_Iterator);
387 m_Iterator -= decrement;
388 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
389 return *this;
390 }
391
392 PerAxisIterator& operator[](const unsigned int index) override
393 {
394 BOOST_ASSERT(m_Iterator);
395 m_Iterator = m_Start + index;
396 m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
397 return *this;
398 }
399
400 protected:
401 T* m_Iterator;
402 T* m_Start;
403 unsigned int m_AxisIndex;
404 unsigned int m_AxisFactor;
405};
406
407class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
408{
409public:
410 QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
411 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
412
413 float Get() const override
414 {
415 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
416 }
417
418 // Get scale of the current value
419 float GetScale() const
420 {
421 return m_Scale[m_AxisIndex];
422 }
423
424private:
425 std::vector<float> m_Scale;
426};
427
428class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
429{
430public:
431 QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
432 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
433
434 void Set(float right)
435 {
436 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
437 }
438
439 float Get() const
440 {
441 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
442 }
443
444 // Get scale of the current value
445 float GetScale() const
446 {
447 return m_Scale[m_AxisIndex];
448 }
449
450private:
451 std::vector<float> m_Scale;
452};
453
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000454class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
455{
456public:
457 ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
458 : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
459
460 float Get() const override
461 {
462 return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
463 }
464
465 // Get scale of the current value
466 float GetScale() const
467 {
468 return m_Scales[m_AxisIndex];
469 }
470
471private:
472 std::vector<float> m_Scales;
473};
474
475} // namespace armnn