blob: 95a31fbdd6582448169be82d10028c983e6a0a20 [file] [log] [blame]
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Jacksone69c3992019-09-09 14:31:21 +01008#include "FloatingPointConverter.hpp"
9
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010010#include <armnn/ArmNN.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010011#include <ResolveType.hpp>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010012
Matthew Benthamc394a6d2019-06-24 12:51:25 +010013#include <boost/assert.hpp>
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +000014#include <boost/core/ignore_unused.hpp>
Matthew Benthamc394a6d2019-06-24 12:51:25 +010015
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010016namespace armnn
17{
18
// Abstract positioning interface shared by the Encoder/Decoder class templates in this file.
// All positioning operations return a BaseIterator reference so that calls can be chained.
class BaseIterator
{
public:
    BaseIterator() = default;

    virtual ~BaseIterator() = default;

    // Positions the iterator from an element index plus an axis index (0 when omitted
    // by a caller going through this base interface).
    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;

    // Pre-increment.
    virtual BaseIterator& operator++() = 0;

    // Moves by `increment`.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    // Moves back by `increment`.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;

    // Positions the iterator at `index`; note that this returns the iterator itself,
    // not the element stored at that position.
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
36
Derek Lambertif30f7d32019-04-09 10:25:02 +010037template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010038class Decoder : public BaseIterator
39{
40public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010041 Decoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010042
43 virtual ~Decoder() {}
44
Matthew Benthamc394a6d2019-06-24 12:51:25 +010045 virtual void Reset(void*) = 0;
46
Derek Lambertif30f7d32019-04-09 10:25:02 +010047 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010048};
49
Derek Lambertif30f7d32019-04-09 10:25:02 +010050template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010051class Encoder : public BaseIterator
52{
53public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010054 Encoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010055
56 virtual ~Encoder() {}
57
Matthew Benthamc394a6d2019-06-24 12:51:25 +010058 virtual void Reset(void*) = 0;
59
Derek Lambertif30f7d32019-04-09 10:25:02 +010060 virtual void Set(IType right) = 0;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010061
62 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010063};
64
65template<typename T, typename Base>
66class TypedIterator : public Base
67{
68public:
Matthew Benthamc394a6d2019-06-24 12:51:25 +010069 TypedIterator(T* data = nullptr)
Francis Murtagh43aec582019-05-27 12:14:10 +010070 : m_Iterator(data), m_Start(data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010071 {}
72
Matthew Benthamc394a6d2019-06-24 12:51:25 +010073 void Reset(void* data) override
74 {
75 m_Iterator = reinterpret_cast<T*>(data);
76 m_Start = m_Iterator;
77 }
78
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010079 TypedIterator& operator++() override
80 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010081 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010082 ++m_Iterator;
83 return *this;
84 }
85
86 TypedIterator& operator+=(const unsigned int increment) override
87 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010088 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010089 m_Iterator += increment;
90 return *this;
91 }
92
93 TypedIterator& operator-=(const unsigned int increment) override
94 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010095 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010096 m_Iterator -= increment;
97 return *this;
98 }
99
Francis Murtagh43aec582019-05-27 12:14:10 +0100100 TypedIterator& operator[](const unsigned int index) override
101 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100102 BOOST_ASSERT(m_Iterator);
Francis Murtagh43aec582019-05-27 12:14:10 +0100103 m_Iterator = m_Start + index;
104 return *this;
105 }
106
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000107 TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
108 {
109 boost::ignore_unused(axisIndex);
110 BOOST_ASSERT(m_Iterator);
111 m_Iterator = m_Start + index;
112 return *this;
113 }
114
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100115protected:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100116 T* m_Iterator;
Francis Murtagh43aec582019-05-27 12:14:10 +0100117 T* m_Start;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100118};
119
Derek Lambertif30f7d32019-04-09 10:25:02 +0100120class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100121{
122public:
123 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
124 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
125
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100126 QASymm8Decoder(const float scale, const int32_t offset)
127 : QASymm8Decoder(nullptr, scale, offset) {}
128
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100129 float Get() const override
130 {
131 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
132 }
133
134private:
135 const float m_Scale;
136 const int32_t m_Offset;
137};
138
Derek Lambertif30f7d32019-04-09 10:25:02 +0100139class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
Sadik Armagan2999a022019-04-09 14:20:12 +0100140{
141public:
142 QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
143 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
144
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100145 QSymm16Decoder(const float scale, const int32_t offset)
146 : QSymm16Decoder(nullptr, scale, offset) {}
147
Sadik Armagan2999a022019-04-09 14:20:12 +0100148 float Get() const override
149 {
150 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
151 }
152
153private:
154 const float m_Scale;
155 const int32_t m_Offset;
156};
157
Matthew Jacksone69c3992019-09-09 14:31:21 +0100158class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100159{
160public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100161 Float16Decoder(const Half* data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100162 : TypedIterator(data) {}
163
Matthew Jacksone69c3992019-09-09 14:31:21 +0100164 Float16Decoder()
165 : Float16Decoder(nullptr) {}
166
167 float Get() const override
168 {
169 float val = 0.f;
170 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
171 return val;
172 }
173};
174
175class Float32Decoder : public TypedIterator<const float, Decoder<float>>
176{
177public:
178 Float32Decoder(const float* data)
179 : TypedIterator(data) {}
180
181 Float32Decoder()
182 : Float32Decoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100183
Derek Lambertif30f7d32019-04-09 10:25:02 +0100184 float Get() const override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100185 {
Derek Lambertif30f7d32019-04-09 10:25:02 +0100186 return *m_Iterator;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100187 }
188};
189
Mike Kelly9b398322019-05-22 17:21:49 +0100190class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
191{
192public:
193 ScaledInt32Decoder(const int32_t* data, const float scale)
194 : TypedIterator(data), m_Scale(scale) {}
195
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100196 ScaledInt32Decoder(const float scale)
197 : ScaledInt32Decoder(nullptr, scale) {}
198
Mike Kelly9b398322019-05-22 17:21:49 +0100199 float Get() const override
200 {
201 return static_cast<float>(*m_Iterator) * m_Scale;
202 }
203
204private:
205 const float m_Scale;
206};
207
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100208class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
209{
210public:
211 Int32Decoder(const int32_t* data)
212 : TypedIterator(data) {}
213
214 Int32Decoder()
215 : Int32Decoder(nullptr) {}
216
217 float Get() const override
218 {
219 return static_cast<float>(*m_Iterator);
220 }
221};
222
Derek Lambertif30f7d32019-04-09 10:25:02 +0100223class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100224{
225public:
226 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
227 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
228
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100229 QASymm8Encoder(const float scale, const int32_t offset)
230 : QASymm8Encoder(nullptr, scale, offset) {}
231
Derek Lambertif30f7d32019-04-09 10:25:02 +0100232 void Set(float right) override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100233 {
234 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
235 }
236
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100237 float Get() const override
238 {
239 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
240 }
241
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100242private:
243 const float m_Scale;
244 const int32_t m_Offset;
245};
246
Derek Lambertif30f7d32019-04-09 10:25:02 +0100247class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
248{
249public:
250 QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
251 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
252
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100253 QSymm16Encoder(const float scale, const int32_t offset)
254 : QSymm16Encoder(nullptr, scale, offset) {}
255
Derek Lambertif30f7d32019-04-09 10:25:02 +0100256 void Set(float right) override
257 {
258 *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
259 }
260
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100261 float Get() const override
262 {
263 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
264 }
265
Derek Lambertif30f7d32019-04-09 10:25:02 +0100266private:
267 const float m_Scale;
268 const int32_t m_Offset;
269};
270
Matthew Jacksone69c3992019-09-09 14:31:21 +0100271class Float16Encoder : public TypedIterator<Half, Encoder<float>>
Derek Lambertif30f7d32019-04-09 10:25:02 +0100272{
273public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100274 Float16Encoder(Half* data)
Derek Lambertif30f7d32019-04-09 10:25:02 +0100275 : TypedIterator(data) {}
276
Matthew Jacksone69c3992019-09-09 14:31:21 +0100277 Float16Encoder()
278 : Float16Encoder(nullptr) {}
279
280 void Set(float right) override
281 {
282 armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
283 }
284
285 float Get() const override
286 {
287 float val = 0.f;
288 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
289 return val;
290 }
291};
292
293class Float32Encoder : public TypedIterator<float, Encoder<float>>
294{
295public:
296 Float32Encoder(float* data)
297 : TypedIterator(data) {}
298
299 Float32Encoder()
300 : Float32Encoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100301
Derek Lambertif30f7d32019-04-09 10:25:02 +0100302 void Set(float right) override
303 {
304 *m_Iterator = right;
305 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100306
307 float Get() const override
308 {
309 return *m_Iterator;
310 }
Derek Lambertif30f7d32019-04-09 10:25:02 +0100311};
312
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100313class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
314{
315public:
316 Int32Encoder(int32_t* data)
317 : TypedIterator(data) {}
318
319 Int32Encoder()
320 : Int32Encoder(nullptr) {}
321
322 void Set(float right) override
323 {
324 *m_Iterator = static_cast<int32_t>(right);
325 }
326
327 float Get() const override
328 {
329 return static_cast<float>(*m_Iterator);
330 }
331};
332
Derek Lambertif30f7d32019-04-09 10:25:02 +0100333class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100334{
335public:
336 BooleanEncoder(uint8_t* data)
337 : TypedIterator(data) {}
338
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100339 BooleanEncoder()
340 : BooleanEncoder(nullptr) {}
341
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100342 void Set(bool right) override
343 {
344 *m_Iterator = right;
345 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100346
347 bool Get() const override
348 {
349 return *m_Iterator;
350 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100351};
352
// PerAxisIterator for per-axis quantization.
// In addition to the data cursor (m_Iterator) and buffer start (m_Start) it keeps
// m_AxisIndex, which the derived encoders/decoders in this file use to select a scale.
template<typename T, typename Base>
class PerAxisIterator : public Base
{
public:
    // axisFactor is used to calculate axisIndex
    PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
        : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
    {}

    // This should be called to set index for per-axis Encoder/Decoder.
    // Places the cursor `index` elements after the start and stores axisIndex as given.
    PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        m_AxisIndex = axisIndex;
        return *this;
    }

    // Re-targets the iterator at a new buffer and resets the axis index to 0.
    void Reset(void* data) override
    {
        // void* -> T* is a plain static_cast conversion; reinterpret_cast is not required.
        m_Iterator = static_cast<T*>(data);
        m_Start = m_Iterator;
        m_AxisIndex = 0;
    }

    // Advances the cursor by one element and recomputes the axis index.
    PerAxisIterator& operator++() override
    {
        BOOST_ASSERT(m_Iterator);
        ++m_Iterator;
        UpdateAxisIndex();
        return *this;
    }

    // Advances the cursor by `increment` elements and recomputes the axis index.
    PerAxisIterator& operator+=(const unsigned int increment) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator += increment;
        UpdateAxisIndex();
        return *this;
    }

    // Moves the cursor back by `decrement` elements and recomputes the axis index.
    PerAxisIterator& operator-=(const unsigned int decrement) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator -= decrement;
        UpdateAxisIndex();
        return *this;
    }

    // Places the cursor `index` elements after the start, recomputes the axis index and
    // returns the iterator itself (not the element).
    PerAxisIterator& operator[](const unsigned int index) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        UpdateAxisIndex();
        return *this;
    }

protected:
    T* m_Iterator;
    T* m_Start;
    unsigned int m_AxisIndex;
    unsigned int m_AxisFactor;

private:
    // Recomputes m_AxisIndex after the cursor has moved.
    // A zero m_AxisFactor (the constructor default) would make the modulo undefined
    // behaviour, so in that case the axis index falls back to 0 and nothing is read.
    // NOTE(review): the axis index is derived from the VALUE stored at the cursor rather
    // than from the cursor's position, which also means the element is read on every move
    // (including a move to one-past-the-end). This looks unintended, but the correct
    // position-based formula cannot be determined from this file -- confirm with the
    // callers; SetIndex() is the only operation that sets the axis index explicitly.
    void UpdateAxisIndex()
    {
        if (m_AxisFactor == 0)
        {
            m_AxisIndex = 0;
            return;
        }
        m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
    }
};
417
418class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
419{
420public:
421 QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
422 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
423
424 float Get() const override
425 {
426 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
427 }
428
429 // Get scale of the current value
430 float GetScale() const
431 {
432 return m_Scale[m_AxisIndex];
433 }
434
435private:
436 std::vector<float> m_Scale;
437};
438
439class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
440{
441public:
442 QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
443 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
444
445 void Set(float right)
446 {
447 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
448 }
449
450 float Get() const
451 {
452 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
453 }
454
455 // Get scale of the current value
456 float GetScale() const
457 {
458 return m_Scale[m_AxisIndex];
459 }
460
461private:
462 std::vector<float> m_Scale;
463};
464
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000465class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
466{
467public:
468 ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
469 : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
470
471 float Get() const override
472 {
473 return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
474 }
475
476 // Get scale of the current value
477 float GetScale() const
478 {
479 return m_Scales[m_AxisIndex];
480 }
481
482private:
483 std::vector<float> m_Scales;
484};
485
486} // namespace armnn