blob: ca5110c2fdbf29c355cbf0df547e8eb2a790d0bc [file] [log] [blame]
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/ArmNN.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009
10#include <armnnUtils/FloatingPointConverter.hpp>
11
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010012#include <ResolveType.hpp>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010013
Matthew Benthamc394a6d2019-06-24 12:51:25 +010014#include <boost/assert.hpp>
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +000015#include <boost/core/ignore_unused.hpp>
Matthew Benthamc394a6d2019-06-24 12:51:25 +010016
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010017namespace armnn
18{
19
/// Abstract cursor over a flat buffer.
///
/// Every movement operation returns the iterator itself so calls can be chained.
/// Concrete subclasses decide what a "position" means and how it is stored.
class BaseIterator
{
public:
    BaseIterator() = default;

    virtual ~BaseIterator() = default;

    /// Moves to an absolute element index. `axisIndex` is an extra hint that only
    /// axis-aware implementations make use of.
    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;

    /// Advances by one element.
    virtual BaseIterator& operator++() = 0;

    /// Advances by `increment` elements.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    /// Steps back by `increment` elements.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;

    /// Repositions to absolute element `index` (it does not return an element).
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
37
Derek Lambertif30f7d32019-04-09 10:25:02 +010038template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010039class Decoder : public BaseIterator
40{
41public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010042 Decoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010043
44 virtual ~Decoder() {}
45
Matthew Benthamc394a6d2019-06-24 12:51:25 +010046 virtual void Reset(void*) = 0;
47
Derek Lambertif30f7d32019-04-09 10:25:02 +010048 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010049};
50
Derek Lambertif30f7d32019-04-09 10:25:02 +010051template<typename IType>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010052class Encoder : public BaseIterator
53{
54public:
Derek Lambertif30f7d32019-04-09 10:25:02 +010055 Encoder() {}
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010056
57 virtual ~Encoder() {}
58
Matthew Benthamc394a6d2019-06-24 12:51:25 +010059 virtual void Reset(void*) = 0;
60
Derek Lambertif30f7d32019-04-09 10:25:02 +010061 virtual void Set(IType right) = 0;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010062
63 virtual IType Get() const = 0;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010064};
65
66template<typename T, typename Base>
67class TypedIterator : public Base
68{
69public:
Matthew Benthamc394a6d2019-06-24 12:51:25 +010070 TypedIterator(T* data = nullptr)
Francis Murtagh43aec582019-05-27 12:14:10 +010071 : m_Iterator(data), m_Start(data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010072 {}
73
Matthew Benthamc394a6d2019-06-24 12:51:25 +010074 void Reset(void* data) override
75 {
76 m_Iterator = reinterpret_cast<T*>(data);
77 m_Start = m_Iterator;
78 }
79
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010080 TypedIterator& operator++() override
81 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010082 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010083 ++m_Iterator;
84 return *this;
85 }
86
87 TypedIterator& operator+=(const unsigned int increment) override
88 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010089 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010090 m_Iterator += increment;
91 return *this;
92 }
93
94 TypedIterator& operator-=(const unsigned int increment) override
95 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +010096 BOOST_ASSERT(m_Iterator);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010097 m_Iterator -= increment;
98 return *this;
99 }
100
Francis Murtagh43aec582019-05-27 12:14:10 +0100101 TypedIterator& operator[](const unsigned int index) override
102 {
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100103 BOOST_ASSERT(m_Iterator);
Francis Murtagh43aec582019-05-27 12:14:10 +0100104 m_Iterator = m_Start + index;
105 return *this;
106 }
107
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000108 TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
109 {
110 boost::ignore_unused(axisIndex);
111 BOOST_ASSERT(m_Iterator);
112 m_Iterator = m_Start + index;
113 return *this;
114 }
115
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100116protected:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100117 T* m_Iterator;
Francis Murtagh43aec582019-05-27 12:14:10 +0100118 T* m_Start;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100119};
120
Derek Lambertif30f7d32019-04-09 10:25:02 +0100121class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100122{
123public:
124 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
125 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
126
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100127 QASymm8Decoder(const float scale, const int32_t offset)
128 : QASymm8Decoder(nullptr, scale, offset) {}
129
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100130 float Get() const override
131 {
132 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
133 }
134
135private:
136 const float m_Scale;
137 const int32_t m_Offset;
138};
139
Derek Lambertif30f7d32019-04-09 10:25:02 +0100140class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
Sadik Armagan2999a022019-04-09 14:20:12 +0100141{
142public:
143 QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
144 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
145
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100146 QSymm16Decoder(const float scale, const int32_t offset)
147 : QSymm16Decoder(nullptr, scale, offset) {}
148
Sadik Armagan2999a022019-04-09 14:20:12 +0100149 float Get() const override
150 {
151 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
152 }
153
154private:
155 const float m_Scale;
156 const int32_t m_Offset;
157};
158
Matthew Jacksone69c3992019-09-09 14:31:21 +0100159class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100160{
161public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100162 Float16Decoder(const Half* data)
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100163 : TypedIterator(data) {}
164
Matthew Jacksone69c3992019-09-09 14:31:21 +0100165 Float16Decoder()
166 : Float16Decoder(nullptr) {}
167
168 float Get() const override
169 {
170 float val = 0.f;
171 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
172 return val;
173 }
174};
175
176class Float32Decoder : public TypedIterator<const float, Decoder<float>>
177{
178public:
179 Float32Decoder(const float* data)
180 : TypedIterator(data) {}
181
182 Float32Decoder()
183 : Float32Decoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100184
Derek Lambertif30f7d32019-04-09 10:25:02 +0100185 float Get() const override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100186 {
Derek Lambertif30f7d32019-04-09 10:25:02 +0100187 return *m_Iterator;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100188 }
189};
190
Mike Kelly9b398322019-05-22 17:21:49 +0100191class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
192{
193public:
194 ScaledInt32Decoder(const int32_t* data, const float scale)
195 : TypedIterator(data), m_Scale(scale) {}
196
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100197 ScaledInt32Decoder(const float scale)
198 : ScaledInt32Decoder(nullptr, scale) {}
199
Mike Kelly9b398322019-05-22 17:21:49 +0100200 float Get() const override
201 {
202 return static_cast<float>(*m_Iterator) * m_Scale;
203 }
204
205private:
206 const float m_Scale;
207};
208
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100209class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
210{
211public:
212 Int32Decoder(const int32_t* data)
213 : TypedIterator(data) {}
214
215 Int32Decoder()
216 : Int32Decoder(nullptr) {}
217
218 float Get() const override
219 {
220 return static_cast<float>(*m_Iterator);
221 }
222};
223
Derek Lambertif30f7d32019-04-09 10:25:02 +0100224class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100225{
226public:
227 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
228 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
229
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100230 QASymm8Encoder(const float scale, const int32_t offset)
231 : QASymm8Encoder(nullptr, scale, offset) {}
232
Derek Lambertif30f7d32019-04-09 10:25:02 +0100233 void Set(float right) override
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100234 {
235 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
236 }
237
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100238 float Get() const override
239 {
240 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
241 }
242
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100243private:
244 const float m_Scale;
245 const int32_t m_Offset;
246};
247
Derek Lambertif30f7d32019-04-09 10:25:02 +0100248class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
249{
250public:
251 QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
252 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
253
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100254 QSymm16Encoder(const float scale, const int32_t offset)
255 : QSymm16Encoder(nullptr, scale, offset) {}
256
Derek Lambertif30f7d32019-04-09 10:25:02 +0100257 void Set(float right) override
258 {
259 *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
260 }
261
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100262 float Get() const override
263 {
264 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
265 }
266
Derek Lambertif30f7d32019-04-09 10:25:02 +0100267private:
268 const float m_Scale;
269 const int32_t m_Offset;
270};
271
Matthew Jacksone69c3992019-09-09 14:31:21 +0100272class Float16Encoder : public TypedIterator<Half, Encoder<float>>
Derek Lambertif30f7d32019-04-09 10:25:02 +0100273{
274public:
Matthew Jacksone69c3992019-09-09 14:31:21 +0100275 Float16Encoder(Half* data)
Derek Lambertif30f7d32019-04-09 10:25:02 +0100276 : TypedIterator(data) {}
277
Matthew Jacksone69c3992019-09-09 14:31:21 +0100278 Float16Encoder()
279 : Float16Encoder(nullptr) {}
280
281 void Set(float right) override
282 {
283 armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
284 }
285
286 float Get() const override
287 {
288 float val = 0.f;
289 armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
290 return val;
291 }
292};
293
294class Float32Encoder : public TypedIterator<float, Encoder<float>>
295{
296public:
297 Float32Encoder(float* data)
298 : TypedIterator(data) {}
299
300 Float32Encoder()
301 : Float32Encoder(nullptr) {}
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100302
Derek Lambertif30f7d32019-04-09 10:25:02 +0100303 void Set(float right) override
304 {
305 *m_Iterator = right;
306 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100307
308 float Get() const override
309 {
310 return *m_Iterator;
311 }
Derek Lambertif30f7d32019-04-09 10:25:02 +0100312};
313
Aron Virginas-Tar198ee402019-08-02 18:54:28 +0100314class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
315{
316public:
317 Int32Encoder(int32_t* data)
318 : TypedIterator(data) {}
319
320 Int32Encoder()
321 : Int32Encoder(nullptr) {}
322
323 void Set(float right) override
324 {
325 *m_Iterator = static_cast<int32_t>(right);
326 }
327
328 float Get() const override
329 {
330 return static_cast<float>(*m_Iterator);
331 }
332};
333
Derek Lambertif30f7d32019-04-09 10:25:02 +0100334class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100335{
336public:
337 BooleanEncoder(uint8_t* data)
338 : TypedIterator(data) {}
339
Matthew Benthamc394a6d2019-06-24 12:51:25 +0100340 BooleanEncoder()
341 : BooleanEncoder(nullptr) {}
342
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100343 void Set(bool right) override
344 {
345 *m_Iterator = right;
346 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100347
348 bool Get() const override
349 {
350 return *m_Iterator;
351 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100352};
353
// PerAxisIterator for per-axis quantization.
// Walks a flat buffer of T and tracks m_AxisIndex, which derived Encoders/Decoders use to pick the
// scale that applies to the current element.
template<typename T, typename Base>
class PerAxisIterator : public Base
{
public:
    // axisFactor is used to calculate axisIndex from the element's flat position:
    //     axisIndex = position / axisFactor
    // NOTE(review): this assumes axisFactor is the number of elements that follow the quantization
    // dimension (i.e. that dimension's stride in the flat buffer) - confirm against the code that
    // builds per-axis encoders/decoders. The extent of the quantization dimension is not known here,
    // so the result is not wrapped; consumers that need an in-range index must reduce it themselves.
    PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
        : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
    {}

    // This should be called to set index for per-axis Encoder/Decoder.
    // The caller supplies the axis index explicitly; it is stored as-is.
    PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        m_AxisIndex = axisIndex;
        return *this;
    }

    // Rebinds to a new buffer and rewinds to its first element (axis index 0).
    void Reset(void* data) override
    {
        m_Iterator = reinterpret_cast<T*>(data);
        m_Start = m_Iterator;
        m_AxisIndex = 0;
    }

    PerAxisIterator& operator++() override
    {
        BOOST_ASSERT(m_Iterator);
        ++m_Iterator;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator+=(const unsigned int increment) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator += increment;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator-=(const unsigned int decrement) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator -= decrement;
        UpdateAxisIndex();
        return *this;
    }

    PerAxisIterator& operator[](const unsigned int index) override
    {
        BOOST_ASSERT(m_Iterator);
        m_Iterator = m_Start + index;
        UpdateAxisIndex();
        return *this;
    }

protected:
    T* m_Iterator;
    T* m_Start;
    unsigned int m_AxisIndex;
    unsigned int m_AxisFactor;

private:
    // Recomputes m_AxisIndex from the current *position* in the buffer. The previous code derived
    // it from the element *value* (`*m_Iterator % m_AxisFactor`). That selected an unrelated scale,
    // divided by zero for the default axisFactor of 0, and dereferenced the past-the-end element
    // after the final increment. This version never dereferences m_Iterator.
    void UpdateAxisIndex()
    {
        if (m_AxisFactor == 0)
        {
            // No stride information: avoid division by zero and use the first axis entry.
            m_AxisIndex = 0;
            return;
        }
        const auto position = static_cast<unsigned int>(m_Iterator - m_Start);
        m_AxisIndex = position / m_AxisFactor;
    }
};
418
419class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
420{
421public:
422 QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
423 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
424
425 float Get() const override
426 {
427 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
428 }
429
430 // Get scale of the current value
431 float GetScale() const
432 {
433 return m_Scale[m_AxisIndex];
434 }
435
436private:
437 std::vector<float> m_Scale;
438};
439
440class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
441{
442public:
443 QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
444 : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
445
446 void Set(float right)
447 {
448 *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
449 }
450
451 float Get() const
452 {
453 return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
454 }
455
456 // Get scale of the current value
457 float GetScale() const
458 {
459 return m_Scale[m_AxisIndex];
460 }
461
462private:
463 std::vector<float> m_Scale;
464};
465
Aron Virginas-Tarb67f9572019-11-04 15:00:19 +0000466class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
467{
468public:
469 ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
470 : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
471
472 float Get() const override
473 {
474 return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
475 }
476
477 // Get scale of the current value
478 float GetScale() const
479 {
480 return m_Scales[m_AxisIndex];
481 }
482
483private:
484 std::vector<float> m_Scales;
485};
486
487} // namespace armnn