blob: 1859e03cce6095c46a447fa6ec6a7570c2f52bfd [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef OP_TEMPLATE_TYPES_H
17#define OP_TEMPLATE_TYPES_H
18
19#include "tosa_generated.h"
20#include <Eigen/CXX11/Tensor>
21
22using namespace tosa;
23
24namespace TosaReference
25{
26// Shorter aliase templates for common Eigen::Tensor types
27template <typename T>
28using ETensor0 = Eigen::Tensor<T, 0>;
29template <typename T>
30using ETensor1 = Eigen::Tensor<T, 1>;
31template <typename T>
32using ETensor2 = Eigen::Tensor<T, 2>;
33template <typename T>
34using ETensor3 = Eigen::Tensor<T, 3>;
35template <typename T>
36using ETensor4 = Eigen::Tensor<T, 4>;
37template <typename T>
38using ETensor5 = Eigen::Tensor<T, 5>;
39template <typename T>
40using ETensor6 = Eigen::Tensor<T, 6>;
41
42// Forward declaration
43template <class T>
44class TensorTemplate;
45
46// Shortcut to hide the TensorTemplate class.
47// For example, declare Tensor1<float> to get a TensorTemplate
48// with an Eigen::Tensor<float, 1>
49template <typename T>
50using Tensor0 = TensorTemplate<ETensor0<T>>;
51template <typename T>
52using Tensor1 = TensorTemplate<ETensor1<T>>;
53template <typename T>
54using Tensor2 = TensorTemplate<ETensor2<T>>;
55template <typename T>
56using Tensor3 = TensorTemplate<ETensor3<T>>;
57template <typename T>
58using Tensor4 = TensorTemplate<ETensor4<T>>;
59template <typename T>
60using Tensor5 = TensorTemplate<ETensor5<T>>;
61template <typename T>
62using Tensor6 = TensorTemplate<ETensor6<T>>;
63
64template <DType type>
65struct GetEigenType;
66template <>
67struct GetEigenType<DType_FLOAT>
68{
69 using type = float;
70};
71template <>
72struct GetEigenType<DType_INT32>
73{
74 using type = int32_t;
75};
76template <>
77struct GetEigenType<DType_INT48>
78{
79 using type = int64_t;
80};
81template <>
82struct GetEigenType<DType_BOOL>
83{
84 using type = bool;
85};
86template <>
87struct GetEigenType<DType_AINT8>
88{
89 using type = int32_t;
90};
91template <>
92struct GetEigenType<DType_UINT8>
93{
94 using type = int32_t;
95};
96template <>
97struct GetEigenType<DType_INT4>
98{
99 using type = int32_t;
100};
101template <>
102struct GetEigenType<DType_INT8>
103{
104 using type = int32_t;
105};
106template <>
107struct GetEigenType<DType_INT16>
108{
109 using type = int32_t;
110};
111
112// Meta function to get number of bits
113template <DType T>
114struct GetNumBits
115{
116 static constexpr int32_t value = 0;
117};
118template <>
119struct GetNumBits<DType_BOOL>
120{
121 static constexpr int32_t value = 1;
122};
123template <>
124struct GetNumBits<DType_AINT8>
125{
126 static constexpr int32_t value = 8;
127};
128template <>
129struct GetNumBits<DType_UINT8>
130{
131 static constexpr int32_t value = 8;
132};
133template <>
134struct GetNumBits<DType_INT4>
135{
136 static constexpr int32_t value = 4;
137};
138template <>
139struct GetNumBits<DType_INT8>
140{
141 static constexpr int32_t value = 8;
142};
143template <>
144struct GetNumBits<DType_INT16>
145{
146 static constexpr int32_t value = 16;
147};
148template <>
149struct GetNumBits<DType_INT32>
150{
151 static constexpr int32_t value = 32;
152};
153template <>
154struct GetNumBits<DType_INT48>
155{
156 static constexpr int32_t value = 48;
157};
158
159// Meta function to get quantized min/max in compile time
160template <DType T>
161struct GetQMin
162{
163 static constexpr int64_t value = 0L;
164};
165template <>
166struct GetQMin<DType_AINT8>
167{
168 static constexpr int64_t value = -128L;
169};
170template <>
171struct GetQMin<DType_UINT8>
172{
173 static constexpr int64_t value = 0L;
174};
175template <>
176struct GetQMin<DType_INT4>
177{
178 static constexpr int64_t value = -8L;
179};
180template <>
181struct GetQMin<DType_INT8>
182{
183 static constexpr int64_t value = -128L;
184};
185template <>
186struct GetQMin<DType_INT16>
187{
188 static constexpr int64_t value = -32768L;
189};
190template <>
191struct GetQMin<DType_INT32>
192{
193 static constexpr int64_t value = -(1L << 31);
194};
195template <>
196struct GetQMin<DType_INT48>
197{
198 static constexpr int64_t value = -(1L << 47);
199};
200
201template <DType T>
202struct GetQMax
203{
204 static constexpr int64_t value = 0L;
205};
206template <>
207struct GetQMax<DType_AINT8>
208{
209 static constexpr int64_t value = 127L;
210};
211template <>
212struct GetQMax<DType_UINT8>
213{
214 static constexpr int64_t value = 255L;
215};
216template <>
217struct GetQMax<DType_INT4>
218{
219 static constexpr int64_t value = 7L;
220};
221template <>
222struct GetQMax<DType_INT8>
223{
224 static constexpr int64_t value = 127L;
225};
226template <>
227struct GetQMax<DType_INT16>
228{
229 static constexpr int64_t value = 32767L;
230};
231template <>
232struct GetQMax<DType_INT32>
233{
234 static constexpr int64_t value = (1L << 31) - 1;
235};
236template <>
237struct GetQMax<DType_INT48>
238{
239 static constexpr int64_t value = (1L << 47) - 1;
240};
241
242template <DType TIn1, DType TIn2>
243struct GetAccDType;
244template <>
245struct GetAccDType<DType_AINT8, DType_AINT8>
246{
247 static constexpr DType value = DType_INT32;
248};
249template <>
250struct GetAccDType<DType_AINT8, DType_INT4>
251{
252 static constexpr DType value = DType_INT32;
253};
254template <>
255struct GetAccDType<DType_AINT8, DType_INT8>
256{
257 static constexpr DType value = DType_INT32;
258};
259template <>
260struct GetAccDType<DType_INT16, DType_INT8>
261{
262 static constexpr DType value = DType_INT48;
263};
264template <>
265struct GetAccDType<DType_INT16, DType_INT16>
266{
267 static constexpr DType value = DType_INT48;
268};
269template <>
270struct GetAccDType<DType_FLOAT, DType_FLOAT>
271{
272 static constexpr DType value = DType_FLOAT;
273};
274
275}; // namespace TosaReference
276
277#endif