//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnn/ArmNN.hpp>
9#include <TypeUtils.hpp>
10
11namespace armnn
12{
13
/// Abstract interface for moving a cursor over a sequence.
///
/// Only the movement operators are declared here; the class stores no state and
/// offers no element access. Derived interfaces add reading or writing.
class BaseIterator
{
public:
    BaseIterator() = default;

    /// Virtual so that objects can be destroyed through a BaseIterator pointer.
    virtual ~BaseIterator() = default;

    /// Moves the cursor forward by one position.
    /// @return A BaseIterator reference (implementations are expected to return *this).
    virtual BaseIterator& operator++() = 0;

    /// Moves the cursor forward by @p increment positions.
    /// @param increment Number of positions to advance.
    /// @return A BaseIterator reference (implementations are expected to return *this).
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    /// Moves the cursor backward by @p increment positions.
    /// @param increment Number of positions to step back.
    /// @return A BaseIterator reference (implementations are expected to return *this).
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;
};
27
28class Decoder : public BaseIterator
29{
30public:
31 Decoder() : BaseIterator() {}
32
33 virtual ~Decoder() {}
34
35 virtual float Get() const = 0;
36};
37
38class Encoder : public BaseIterator
39{
40public:
41 Encoder() : BaseIterator() {}
42
43 virtual ~Encoder() {}
44
45 virtual void Set(const float& right) = 0;
46};
47
48class ComparisonEncoder : public BaseIterator
49{
50public:
51 ComparisonEncoder() : BaseIterator() {}
52
53 virtual ~ComparisonEncoder() {}
54
55 virtual void Set(bool right) = 0;
56};
57
/// Implements the movement operators of an iterator interface on top of a raw pointer.
///
/// @tparam T    Element type pointed to (may be const-qualified for read-only use).
/// @tparam Base Interface to derive from; it must declare virtual operator++,
///              operator+= and operator-= for the overrides below to apply.
///
/// The pointer is moved with plain pointer arithmetic and no bounds checking is
/// performed here.
template<typename T, typename Base>
class TypedIterator : public Base
{
public:
    /// @param data Pointer to the element the iterator starts at.
    TypedIterator(T* data)
        : m_Iterator(data)
    {}

    /// Advances the pointer by one element.
    TypedIterator& operator++() override
    {
        ++m_Iterator;
        return *this;
    }

    /// Advances the pointer by @p increment elements.
    TypedIterator& operator+=(const unsigned int increment) override
    {
        m_Iterator += increment;
        return *this;
    }

    /// Moves the pointer back by @p increment elements.
    TypedIterator& operator-=(const unsigned int increment) override
    {
        m_Iterator -= increment;
        return *this;
    }

    /// Pointer to the current element. Public, so it is directly readable and
    /// writable by derived classes and by callers.
    T* m_Iterator;
};
86
87class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder>
88{
89public:
90 QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
91 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
92
93 float Get() const override
94 {
95 return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
96 }
97
98private:
99 const float m_Scale;
100 const int32_t m_Offset;
101};
102
103class FloatDecoder : public TypedIterator<const float, Decoder>
104{
105public:
106 FloatDecoder(const float* data)
107 : TypedIterator(data) {}
108
109 float Get() const override
110 {
111 return *m_Iterator;
112 }
113};
114
115class FloatEncoder : public TypedIterator<float, Encoder>
116{
117public:
118 FloatEncoder(float* data)
119 : TypedIterator(data) {}
120
121 void Set(const float& right) override
122 {
123 *m_Iterator = right;
124 }
125};
126
127class QASymm8Encoder : public TypedIterator<uint8_t, Encoder>
128{
129public:
130 QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
131 : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
132
133 void Set(const float& right) override
134 {
135 *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
136 }
137
138private:
139 const float m_Scale;
140 const int32_t m_Offset;
141};
142
143class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder>
144{
145public:
146 BooleanEncoder(uint8_t* data)
147 : TypedIterator(data) {}
148
149 void Set(bool right) override
150 {
151 *m_Iterator = right;
152 }
153};
154
155} //namespace armnn