blob: ca6d3cbc60fb0115111a72a209faa8725d7d4368 [file] [log] [blame]
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <armnn/ArmNN.hpp>

#include <armnnUtils/FloatingPointConverter.hpp>

#include <ResolveType.hpp>

#include <boost/assert.hpp>
#include <boost/core/ignore_unused.hpp>

#include <cmath>
#include <limits>
#include <vector>
Matthew Benthamc394a6d2019-06-24 12:51:25 +010016
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010017namespace armnn
18{
19
/// Abstract, element-type-agnostic cursor over a buffer.
///
/// Only positioning is declared here; reading and writing element values is added by the
/// Decoder/Encoder interfaces that derive from this class.
class BaseIterator
{
public:
    BaseIterator() = default;

    virtual ~BaseIterator() = default;

    /// Moves to the element at flat position @p index. @p axisIndex is only meaningful for
    /// per-axis implementations; non-per-axis implementations may ignore it.
    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;

    /// Steps forward by one element.
    virtual BaseIterator& operator++() = 0;

    /// Steps forward by @p increment elements.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    /// Steps backward by @p increment elements.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;

    /// Repositions the iterator to flat position @p index.
    /// NOTE: this returns the iterator itself, not the element at that position.
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
37
Derek Lambertif30f7d32019-04-09 10:25:02 +010038template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010039class Decoder : public BaseIterator
40{
41public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010042 Decoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010043
44 virtual ~Decoder() {}
45
Matthew Benthamc394a6d2019-06-24 12:51:25 +010046 virtual void Reset(void*) = 0;
47
Derek Lambertif30f7d32019-04-09 10:25:02 +010048 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010049};
50
Derek Lambertif30f7d32019-04-09 10:25:02 +010051template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010052class Encoder : public BaseIterator
53{
54public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010055 Encoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010056
57 virtual ~Encoder() {}
58
Matthew Benthamc394a6d2019-06-24 12:51:25 +010059 virtual void Reset(void*) = 0;
60
Derek Lambertif30f7d32019-04-09 10:25:02 +010061 virtual void Set(IType right) = 0;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010062
63 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010064};
65
66template<typename T, typename Base>
67class TypedIterator : public Base
68{
69public:
Matthew Benthamc394a6d2019-06-24 12:51:25 +010070 TypedIterator(T* data = nullptr)
Francis Murtagh43aec582019-05-27 12:14:10 +010071 : m_Iterator(data), m_Start(data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010072 {}
73
Matthew Benthamc394a6d2019-06-24 12:51:25 +010074 void Reset(void* data) override
75 {
76 m_Iterator = reinterpret_cast<T*>(data);
77 m_Start = m_Iterator;
78 }
79
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010080 TypedIterator& operator++() override
81 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010082 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010083 ++m_Iterator;
84 return *this;
85 }
86
87 TypedIterator& operator+=(const unsigned int increment) override
88 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010089 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010090 m_Iterator += increment;
91 return *this;
92 }
93
94 TypedIterator& operator-=(const unsigned int increment) override
95 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010096 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010097 m_Iterator -= increment;
98 return *this;
99 }
100
Francis Murtagh43aec582019-05-27 12:14:10 +0100101 TypedIterator& operator[](const unsigned int index) override
102 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100103 BOOST_ASSERT(m_Iterator);
Francis Murtagh43aec582019-05-27 12:14:10 +0100104 m_Iterator = m_Start + index;
105 return *this;
106 }
107
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000108 TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
109 {
110 boost::ignore_unused(axisIndex);
111 BOOST_ASSERT(m_Iterator);
112 m_Iterator = m_Start + index;
113 return *this;
114 }
115
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100116protected:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100117 T* m_Iterator;
Francis Murtagh43aec582019-05-27 12:14:10 +0100118 T* m_Start;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100119};
120
Derek Lambertif30f7d32019-04-09 10:25:02 +0100121class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100122{
123public:
124 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
125 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
126
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100127 QASymm8Decoder(const float scale, const int32_t offset)
128 : QASymm8Decoder(nullptr, scale, offset) {}
129
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100130 float Get() const override
131 {
132 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
133 }
134
135private:
136 const float m_Scale;
137 const int32_t m_Offset;
138};
139
Finn Williamsfd271062019-12-04 14:27:27 +0000140class QSymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
141{
142public:
143 QSymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
144 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
145
146 QSymmS8Decoder(const float scale, const int32_t offset)
147 : QSymmS8Decoder(nullptr, scale, offset) {}
148
149 float Get() const override
150 {
151 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
152 }
153
154private:
155 const float m_Scale;
156 const int32_t m_Offset;
157};
158
Derek Lambertif30f7d32019-04-09 10:25:02 +0100159class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
Sadik Armagan2999a022019-04-09 14:20:12 +0100160{
161public:
162 QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
163 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
164
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100165 QSymm16Decoder(const float scale, const int32_t offset)
166 : QSymm16Decoder(nullptr, scale, offset) {}
167
Sadik Armagan2999a022019-04-09 14:20:12 +0100168 float Get() const override
169 {
170 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
171 }
172
173private:
174 const float m_Scale;
175 const int32_t m_Offset;
176};
177
Matthew Jacksone69c3992019-09-09 14:31:21 +0100178class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100179{
180public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100181 Float16Decoder(const Half* data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100182 : TypedIterator(data) {}
183
Matthew Jacksone69c3992019-09-09 14:31:21 +0100184 Float16Decoder()
185 : Float16Decoder(nullptr) {}
186
187 float Get() const override
188 {
189 float val = 0.f;
190 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
191 return val;
192 }
193};
194
195class Float32Decoder : public TypedIterator<const float, Decoder<float>>
196{
197public:
198 Float32Decoder(const float* data)
199 : TypedIterator(data) {}
200
201 Float32Decoder()
202 : Float32Decoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100203
Derek Lambertif30f7d32019-04-09 10:25:02 +0100204 float Get() const override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100205 {
Derek Lambertif30f7d32019-04-09 10:25:02 +0100206 return *m_Iterator;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100207 }
208};
209
Mike Kelly9b398322019-05-22 17:21:49 +0100210class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
211{
212public:
213 ScaledInt32Decoder(const int32_t* data, const float scale)
214 : TypedIterator(data), m_Scale(scale) {}
215
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100216 ScaledInt32Decoder(const float scale)
217 : ScaledInt32Decoder(nullptr, scale) {}
218
Mike Kelly9b398322019-05-22 17:21:49 +0100219 float Get() const override
220 {
221 return static_cast<float>(*m_Iterator) * m_Scale;
222 }
223
224private:
225 const float m_Scale;
226};
227
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100228class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
229{
230public:
231 Int32Decoder(const int32_t* data)
232 : TypedIterator(data) {}
233
234 Int32Decoder()
235 : Int32Decoder(nullptr) {}
236
237 float Get() const override
238 {
239 return static_cast<float>(*m_Iterator);
240 }
241};
242
Derek Lambertif30f7d32019-04-09 10:25:02 +0100243class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100244{
245public:
246 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
247 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
248
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100249 QASymm8Encoder(const float scale, const int32_t offset)
250 : QASymm8Encoder(nullptr, scale, offset) {}
251
Derek Lambertif30f7d32019-04-09 10:25:02 +0100252 void Set(float right) override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100253 {
254 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
255 }
256
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100257 float Get() const override
258 {
259 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
260 }
261
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100262private:
263 const float m_Scale;
264 const int32_t m_Offset;
265};
266
Finn Williamsfd271062019-12-04 14:27:27 +0000267class QSymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
268{
269public:
270 QSymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
271 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
272
273 QSymmS8Encoder(const float scale, const int32_t offset)
274 : QSymmS8Encoder(nullptr, scale, offset) {}
275
276 void Set(float right) override
277 {
278 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
279 }
280
281 float Get() const override
282 {
283 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
284 }
285
286private:
287 const float m_Scale;
288 const int32_t m_Offset;
289};
290
Derek Lambertif30f7d32019-04-09 10:25:02 +0100291class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
292{
293public:
294 QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
295 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
296
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100297 QSymm16Encoder(const float scale, const int32_t offset)
298 : QSymm16Encoder(nullptr, scale, offset) {}
299
Derek Lambertif30f7d32019-04-09 10:25:02 +0100300 void Set(float right) override
301 {
302 *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
303 }
304
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100305 float Get() const override
306 {
307 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
308 }
309
Derek Lambertif30f7d32019-04-09 10:25:02 +0100310private:
311 const float m_Scale;
312 const int32_t m_Offset;
313};
314
Matthew Jacksone69c3992019-09-09 14:31:21 +0100315class Float16Encoder : public TypedIterator<Half, Encoder<float>>
Derek Lambertif30f7d32019-04-09 10:25:02 +0100316{
317public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100318 Float16Encoder(Half* data)
Derek Lambertif30f7d32019-04-09 10:25:02 +0100319 : TypedIterator(data) {}
320
Matthew Jacksone69c3992019-09-09 14:31:21 +0100321 Float16Encoder()
322 : Float16Encoder(nullptr) {}
323
324 void Set(float right) override
325 {
326 armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
327 }
328
329 float Get() const override
330 {
331 float val = 0.f;
332 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
333 return val;
334 }
335};
336
337class Float32Encoder : public TypedIterator<float, Encoder<float>>
338{
339public:
340 Float32Encoder(float* data)
341 : TypedIterator(data) {}
342
343 Float32Encoder()
344 : Float32Encoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100345
Derek Lambertif30f7d32019-04-09 10:25:02 +0100346 void Set(float right) override
347 {
348 *m_Iterator = right;
349 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100350
351 float Get() const override
352 {
353 return *m_Iterator;
354 }
Derek Lambertif30f7d32019-04-09 10:25:02 +0100355};
356
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100357class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
358{
359public:
360 Int32Encoder(int32_t* data)
361 : TypedIterator(data) {}
362
363 Int32Encoder()
364 : Int32Encoder(nullptr) {}
365
366 void Set(float right) override
367 {
368 *m_Iterator = static_cast<int32_t>(right);
369 }
370
371 float Get() const override
372 {
373 return static_cast<float>(*m_Iterator);
374 }
375};
376
Derek Lambertif30f7d32019-04-09 10:25:02 +0100377class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100378{
379public:
380 BooleanEncoder(uint8_t* data)
381 : TypedIterator(data) {}
382
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100383 BooleanEncoder()
384 : BooleanEncoder(nullptr) {}
385
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100386 void Set(bool right) override
387 {
388 *m_Iterator = right;
389 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100390
391 bool Get() const override
392 {
393 return *m_Iterator;
394 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100395};
396
// PerAxisIterator for per-axis quantization.
//
// Iterates over a contiguous array of T and additionally tracks m_AxisIndex, which derived
// encoders/decoders use to select their per-axis quantization parameters.
template<typename T, typename Base>
class PerAxisIterator : public Base
{
public:
    // axisFactor is used to calculate axisIndex: after operator++/+=/-=/[] the element at flat
    // offset i gets m_AxisIndex = i / axisFactor.
    // NOTE(review): this assumes axisFactor is the number of elements spanned by one step along
    // the quantization axis (the product of the dimensions after it); confirm against the code
    // that constructs these iterators.
    // An axisFactor of 0 (the default) keeps m_AxisIndex at 0 instead of dividing by zero.
    PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
        : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
    {}

    // This should be called to set index for per-axis Encoder/Decoder.
    // Unlike the other positioning operators, the axis index is taken from the caller as given.
    PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        m_AxisIndex = axisIndex;
        return *this;
    }

    // Attaches a new buffer and rewinds to its first element (axis index 0).
    void Reset(void* data) override
    {
        m_Iterator = static_cast<T*>(data);
        m_Start = m_Iterator;
        m_AxisIndex = 0;
    }

    PerAxisIterator& operator++() override
    {
        BOOST_ASSERT(m_Iterator);
        ++m_Iterator;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator+=(const unsigned int increment) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator += increment;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator-=(const unsigned int decrement) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator -= decrement;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator[](const unsigned int index) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        UpdateAxisIndex();
        return *this;
    }

protected:
    T* m_Iterator;
    T* m_Start;
    unsigned int m_AxisIndex;
    unsigned int m_AxisFactor;

private:
    // Recomputes m_AxisIndex from the current position only. The element value is never
    // read, so stepping one past the last element is safe.
    // The result is not reduced modulo the length of the quantization axis, which this class
    // does not know; whoever uses it to index per-axis parameters must keep it in range.
    void UpdateAxisIndex()
    {
        BOOST_ASSERT(m_Iterator >= m_Start);
        const auto offset = static_cast<unsigned int>(m_Iterator - m_Start);
        m_AxisIndex = (m_AxisFactor == 0) ? 0u : offset / m_AxisFactor;
    }
};
461
462class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
463{
464public:
465 QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
466 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
467
468 float Get() const override
469 {
470 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
471 }
472
473 // Get scale of the current value
474 float GetScale() const
475 {
476 return m_Scale[m_AxisIndex];
477 }
478
479private:
480 std::vector<float> m_Scale;
481};
482
483class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
484{
485public:
486 QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
487 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
488
489 void Set(float right)
490 {
491 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
492 }
493
494 float Get() const
495 {
496 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
497 }
498
499 // Get scale of the current value
500 float GetScale() const
501 {
502 return m_Scale[m_AxisIndex];
503 }
504
505private:
506 std::vector<float> m_Scale;
507};
508
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000509class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
510{
511public:
512 ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
513 : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
514
515 float Get() const override
516 {
517 return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
518 }
519
520 // Get scale of the current value
521 float GetScale() const
522 {
523 return m_Scales[m_AxisIndex];
524 }
525
526private:
527 std::vector<float> m_Scales;
528};
529
530} // namespace armnn