blob: 6b28502bed60ff5e8108ffb1192e59006030e2eb [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
James Ward8b390432022-08-12 20:48:56 +01002// Copyright (c) 2020-2022, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef OP_TEMPLATE_TYPES_H
17#define OP_TEMPLATE_TYPES_H
18
19#include "tosa_generated.h"
20#include <Eigen/CXX11/Tensor>
James Ward8b390432022-08-12 20:48:56 +010021#include "half.hpp"
James Ward24dbc422022-10-19 12:20:31 +010022#include <Eigen/Core>
23#include "arith_util.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070024
25using namespace tosa;
26
27namespace TosaReference
28{
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010029// Shorter alias templates for common Eigen::Tensor types
Eric Kunzee5e26762020-10-13 16:11:07 -070030template <typename T>
31using ETensor0 = Eigen::Tensor<T, 0>;
32template <typename T>
33using ETensor1 = Eigen::Tensor<T, 1>;
34template <typename T>
35using ETensor2 = Eigen::Tensor<T, 2>;
36template <typename T>
37using ETensor3 = Eigen::Tensor<T, 3>;
38template <typename T>
39using ETensor4 = Eigen::Tensor<T, 4>;
40template <typename T>
41using ETensor5 = Eigen::Tensor<T, 5>;
42template <typename T>
43using ETensor6 = Eigen::Tensor<T, 6>;
44
45// Forward declaration
46template <class T>
47class TensorTemplate;
48
49// Shortcut to hide the TensorTemplate class.
50// For example, declare Tensor1<float> to get a TensorTemplate
51// with an Eigen::Tensor<float, 1>
52template <typename T>
53using Tensor0 = TensorTemplate<ETensor0<T>>;
54template <typename T>
55using Tensor1 = TensorTemplate<ETensor1<T>>;
56template <typename T>
57using Tensor2 = TensorTemplate<ETensor2<T>>;
58template <typename T>
59using Tensor3 = TensorTemplate<ETensor3<T>>;
60template <typename T>
61using Tensor4 = TensorTemplate<ETensor4<T>>;
62template <typename T>
63using Tensor5 = TensorTemplate<ETensor5<T>>;
64template <typename T>
65using Tensor6 = TensorTemplate<ETensor6<T>>;
66
67template <DType type>
68struct GetEigenType;
69template <>
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010070struct GetEigenType<DType_FP32>
Eric Kunzee5e26762020-10-13 16:11:07 -070071{
72 using type = float;
73};
74template <>
James Ward8b390432022-08-12 20:48:56 +010075struct GetEigenType<DType_FP16>
76{
77 // NOTE: full precision used
78 using type = float;
79};
80template <>
James Ward24dbc422022-10-19 12:20:31 +010081struct GetEigenType<DType_BF16>
82{
83 // NOTE: full precision used
84 using type = float;
85};
86template <>
Eric Kunzee5e26762020-10-13 16:11:07 -070087struct GetEigenType<DType_INT32>
88{
89 using type = int32_t;
90};
91template <>
92struct GetEigenType<DType_INT48>
93{
94 using type = int64_t;
95};
96template <>
97struct GetEigenType<DType_BOOL>
98{
99 using type = bool;
100};
101template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700102struct GetEigenType<DType_UINT8>
103{
104 using type = int32_t;
105};
106template <>
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100107struct GetEigenType<DType_UINT16>
108{
109 using type = int32_t;
110};
111template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700112struct GetEigenType<DType_INT4>
113{
114 using type = int32_t;
115};
116template <>
117struct GetEigenType<DType_INT8>
118{
119 using type = int32_t;
120};
121template <>
122struct GetEigenType<DType_INT16>
123{
124 using type = int32_t;
125};
126
James Ward8b390432022-08-12 20:48:56 +0100127/* Get Accumulate Eigen Type:
128Same behaviour as GetEigenType for all DTypes except the
129single specialised case of DType_FP16. */
130template <DType Dtype>
131struct GetAccEigenType;
132template <>
133struct GetAccEigenType<DType_FP16>
134{
135 using type = half_float::half;
136};
137template <DType Dtype>
138struct GetAccEigenType
139{
140 using type = typename GetEigenType<Dtype>::type;
141};
142
Eric Kunzee5e26762020-10-13 16:11:07 -0700143// Meta function to get number of bits
144template <DType T>
145struct GetNumBits
146{
147 static constexpr int32_t value = 0;
148};
149template <>
150struct GetNumBits<DType_BOOL>
151{
152 static constexpr int32_t value = 1;
153};
154template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700155struct GetNumBits<DType_UINT8>
156{
157 static constexpr int32_t value = 8;
158};
159template <>
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100160struct GetNumBits<DType_UINT16>
161{
162 static constexpr int32_t value = 16;
163};
164template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700165struct GetNumBits<DType_INT4>
166{
167 static constexpr int32_t value = 4;
168};
169template <>
170struct GetNumBits<DType_INT8>
171{
172 static constexpr int32_t value = 8;
173};
174template <>
175struct GetNumBits<DType_INT16>
176{
177 static constexpr int32_t value = 16;
178};
179template <>
180struct GetNumBits<DType_INT32>
181{
182 static constexpr int32_t value = 32;
183};
184template <>
185struct GetNumBits<DType_INT48>
186{
187 static constexpr int32_t value = 48;
188};
James Ward8b390432022-08-12 20:48:56 +0100189template <>
190struct GetNumBits<DType_FP16>
191{
192 static constexpr int32_t value = 16;
193};
Eric Kunzee5e26762020-10-13 16:11:07 -0700194
195// Meta function to get quantized min/max in compile time
196template <DType T>
197struct GetQMin
198{
199 static constexpr int64_t value = 0L;
200};
201template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700202struct GetQMin<DType_UINT8>
203{
204 static constexpr int64_t value = 0L;
205};
206template <>
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100207struct GetQMin<DType_UINT16>
208{
209 static constexpr int64_t value = 0L;
210};
211template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700212struct GetQMin<DType_INT4>
213{
214 static constexpr int64_t value = -8L;
215};
216template <>
217struct GetQMin<DType_INT8>
218{
219 static constexpr int64_t value = -128L;
220};
221template <>
222struct GetQMin<DType_INT16>
223{
224 static constexpr int64_t value = -32768L;
225};
226template <>
227struct GetQMin<DType_INT32>
228{
229 static constexpr int64_t value = -(1L << 31);
230};
231template <>
232struct GetQMin<DType_INT48>
233{
234 static constexpr int64_t value = -(1L << 47);
235};
236
237template <DType T>
238struct GetQMax
239{
240 static constexpr int64_t value = 0L;
241};
242template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700243struct GetQMax<DType_UINT8>
244{
245 static constexpr int64_t value = 255L;
246};
247template <>
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100248struct GetQMax<DType_UINT16>
249{
250 static constexpr int64_t value = 65535L;
251};
252template <>
Eric Kunzee5e26762020-10-13 16:11:07 -0700253struct GetQMax<DType_INT4>
254{
255 static constexpr int64_t value = 7L;
256};
257template <>
258struct GetQMax<DType_INT8>
259{
260 static constexpr int64_t value = 127L;
261};
262template <>
263struct GetQMax<DType_INT16>
264{
265 static constexpr int64_t value = 32767L;
266};
267template <>
268struct GetQMax<DType_INT32>
269{
270 static constexpr int64_t value = (1L << 31) - 1;
271};
272template <>
273struct GetQMax<DType_INT48>
274{
275 static constexpr int64_t value = (1L << 47) - 1;
276};
277
Eric Kunzee5e26762020-10-13 16:11:07 -0700278}; // namespace TosaReference
279
280#endif