blob: e2fc6e2053066948828ad1953bc6aa99f8e0ef96 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_TYPE_CONVERSION_H
17#define OPS_TYPE_CONVERSION_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25template <int Rank, DType InDtype, DType OutDtype>
26class OpRescale : public GraphNode
27{
28public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000029 OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070030 virtual ~OpRescale();
31
32 virtual int checkTensorAttributes() final;
33 virtual int eval() final;
34
35 using InEigenType = typename GetEigenType<InDtype>::type;
36 using OutEigenType = typename GetEigenType<OutDtype>::type;
37 using TIn = Eigen::Tensor<InEigenType, Rank>;
38 using TOut = Eigen::Tensor<OutEigenType, Rank>;
39
40 static constexpr int32_t QMin = GetQMin<OutDtype>::value;
41 static constexpr int32_t QMax = GetQMax<OutDtype>::value;
42
43protected:
44 TosaRescaleAttribute* attribute;
45 TosaReference::TensorTemplate<TIn>* in;
46 TosaReference::TensorTemplate<TOut>* out;
47};
48
49template <DType InDtype, DType OutDtype>
50class CastHelper
51{
52public:
53 using InEigenType = typename GetEigenType<InDtype>::type;
54 using OutEigenType = typename GetEigenType<OutDtype>::type;
55 using FcnType = std::function<OutEigenType(InEigenType)>;
56 static constexpr int32_t OutBits = GetNumBits<OutDtype>::value;
57 CastHelper();
58 const FcnType& get_fcn() const
59 {
60 return fcn;
61 }
62
63private:
64 FcnType fcn;
65};
66
67template <DType InDtype>
68class CastHelper<InDtype, DType_BOOL>
69{
70public:
71 using InEigenType = typename GetEigenType<InDtype>::type;
72 using OutEigenType = typename GetEigenType<DType_BOOL>::type;
73 using FcnType = std::function<OutEigenType(InEigenType)>;
74 CastHelper();
75 const FcnType& get_fcn() const
76 {
77 return fcn;
78 }
79
80private:
81 FcnType fcn;
82};
83
84template <DType OutDtype>
85class CastHelper<DType_BOOL, OutDtype>
86{
87public:
88 using InEigenType = typename GetEigenType<DType_BOOL>::type;
89 using OutEigenType = typename GetEigenType<OutDtype>::type;
90 using FcnType = std::function<OutEigenType(InEigenType)>;
91 static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
92 static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
93 CastHelper();
94 const FcnType& get_fcn() const
95 {
96 return fcn;
97 }
98
99private:
100 FcnType fcn;
101};
102
103template <DType InDtype>
James Ward8b390432022-08-12 20:48:56 +0100104class CastHelper<InDtype, DType_FP16>
105{
106public:
107 using InEigenType = typename GetEigenType<InDtype>::type;
108 using OutEigenType = typename GetEigenType<DType_FP16>::type;
109 using FcnType = std::function<OutEigenType(InEigenType)>;
110 CastHelper();
111 const FcnType& get_fcn() const
112 {
113 return fcn;
114 }
115
116private:
117 FcnType fcn;
118};
119
120template <DType OutDtype>
121class CastHelper<DType_FP16, OutDtype>
122{
123public:
124 using InEigenType = typename GetEigenType<DType_FP16>::type;
125 using OutEigenType = typename GetEigenType<OutDtype>::type;
126 using FcnType = std::function<OutEigenType(InEigenType)>;
127 static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
128 static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
129 CastHelper();
130 const FcnType& get_fcn() const
131 {
132 return fcn;
133 }
134
135private:
136 FcnType fcn;
137};
138
James Ward736fd1a2023-01-23 17:13:37 +0000139template <>
140class CastHelper<DType_FP32, DType_FP16>
141{
142public:
143 using InEigenType = typename GetEigenType<DType_FP32>::type;
144 using OutEigenType = typename GetEigenType<DType_FP16>::type;
145 using FcnType = std::function<OutEigenType(InEigenType)>;
146 CastHelper();
147 const FcnType& get_fcn() const
148 {
149 return fcn;
150 }
151
152private:
153 FcnType fcn;
154};
155
156template <DType InDtype>
157class CastHelper<InDtype, DType_BF16>
158{
159public:
160 using InEigenType = typename GetEigenType<InDtype>::type;
161 using OutEigenType = typename GetEigenType<DType_BF16>::type;
162 using FcnType = std::function<OutEigenType(InEigenType)>;
163 CastHelper();
164 const FcnType& get_fcn() const
165 {
166 return fcn;
167 }
168
169private:
170 FcnType fcn;
171};
172
173template <DType OutDtype>
174class CastHelper<DType_BF16, OutDtype>
175{
176public:
177 using InEigenType = typename GetEigenType<DType_BF16>::type;
178 using OutEigenType = typename GetEigenType<OutDtype>::type;
179 using FcnType = std::function<OutEigenType(InEigenType)>;
180 static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
181 static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
182 CastHelper();
183 const FcnType& get_fcn() const
184 {
185 return fcn;
186 }
187
188private:
189 FcnType fcn;
190};
191
192template <>
193class CastHelper<DType_FP32, DType_BF16>
194{
195public:
196 using InEigenType = typename GetEigenType<DType_FP32>::type;
197 using OutEigenType = typename GetEigenType<DType_BF16>::type;
198 using FcnType = std::function<OutEigenType(InEigenType)>;
199 CastHelper();
200 const FcnType& get_fcn() const
201 {
202 return fcn;
203 }
204
205private:
206 FcnType fcn;
207};
208
James Ward8b390432022-08-12 20:48:56 +0100209template <DType InDtype>
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100210class CastHelper<InDtype, DType_FP32>
Eric Kunzee5e26762020-10-13 16:11:07 -0700211{
212public:
213 using InEigenType = typename GetEigenType<InDtype>::type;
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100214 using OutEigenType = typename GetEigenType<DType_FP32>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 using FcnType = std::function<OutEigenType(InEigenType)>;
216 CastHelper();
217 const FcnType& get_fcn() const
218 {
219 return fcn;
220 }
221
222private:
223 FcnType fcn;
224};
225
James Ward736fd1a2023-01-23 17:13:37 +0000226template <>
227class CastHelper<DType_FP16, DType_FP32>
228{
229public:
230 using InEigenType = typename GetEigenType<DType_FP16>::type;
231 using OutEigenType = typename GetEigenType<DType_FP32>::type;
232 using FcnType = std::function<OutEigenType(InEigenType)>;
233 CastHelper();
234 const FcnType& get_fcn() const
235 {
236 return fcn;
237 }
238
239private:
240 FcnType fcn;
241};
242
243template <>
244class CastHelper<DType_BF16, DType_FP32>
245{
246public:
247 using InEigenType = typename GetEigenType<DType_BF16>::type;
248 using OutEigenType = typename GetEigenType<DType_FP32>::type;
249 using FcnType = std::function<OutEigenType(InEigenType)>;
250 CastHelper();
251 const FcnType& get_fcn() const
252 {
253 return fcn;
254 }
255
256private:
257 FcnType fcn;
258};
259
Eric Kunzee5e26762020-10-13 16:11:07 -0700260template <DType OutDtype>
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100261class CastHelper<DType_FP32, OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700262{
263public:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100264 using InEigenType = typename GetEigenType<DType_FP32>::type;
Eric Kunzee5e26762020-10-13 16:11:07 -0700265 using OutEigenType = typename GetEigenType<OutDtype>::type;
266 using FcnType = std::function<OutEigenType(InEigenType)>;
267 static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
268 static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
269 CastHelper();
270 const FcnType& get_fcn() const
271 {
272 return fcn;
273 }
274
275private:
276 FcnType fcn;
277};
278
279template <int Rank, DType InDtype, DType OutDtype>
280class OpCast : public GraphNode
281{
282public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000283 OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700284 virtual ~OpCast();
285
286 virtual int checkTensorAttributes() final;
287 virtual int eval() final;
288
289 using InEigenType = typename GetEigenType<InDtype>::type;
290 using OutEigenType = typename GetEigenType<OutDtype>::type;
291 using TIn = Eigen::Tensor<InEigenType, Rank>;
292 using TOut = Eigen::Tensor<OutEigenType, Rank>;
293
294protected:
295 CastHelper<InDtype, OutDtype> cast_helper;
296 TosaReference::TensorTemplate<TIn>* in;
297 TosaReference::TensorTemplate<TOut>* out;
298};
299
300}; // namespace TosaReference
301
302#endif