blob: 060e14e53c78c0ce9ea4e548a335b993f3f3f91e [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#ifndef OPS_TYPE_CONVERSION_H
17#define OPS_TYPE_CONVERSION_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25template <int Rank, DType InDtype, DType OutDtype>
26class OpRescale : public GraphNode
27{
28public:
Kevin Chengacb550f2021-06-29 15:32:19 -070029 OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070030 virtual ~OpRescale();
31
32 virtual int checkTensorAttributes() final;
33 virtual int eval() final;
34
35 using InEigenType = typename GetEigenType<InDtype>::type;
36 using OutEigenType = typename GetEigenType<OutDtype>::type;
37 using TIn = Eigen::Tensor<InEigenType, Rank>;
38 using TOut = Eigen::Tensor<OutEigenType, Rank>;
39
40 static constexpr int32_t QMin = GetQMin<OutDtype>::value;
41 static constexpr int32_t QMax = GetQMax<OutDtype>::value;
42
43protected:
44 TosaRescaleAttribute* attribute;
45 TosaReference::TensorTemplate<TIn>* in;
46 TosaReference::TensorTemplate<TOut>* out;
47};
48
49template <DType InDtype, DType OutDtype>
50class CastHelper
51{
52public:
53 using InEigenType = typename GetEigenType<InDtype>::type;
54 using OutEigenType = typename GetEigenType<OutDtype>::type;
55 using FcnType = std::function<OutEigenType(InEigenType)>;
56 static constexpr int32_t OutBits = GetNumBits<OutDtype>::value;
57 CastHelper();
58 const FcnType& get_fcn() const
59 {
60 return fcn;
61 }
62
63private:
64 FcnType fcn;
65};
66
67template <DType InDtype>
68class CastHelper<InDtype, DType_BOOL>
69{
70public:
71 using InEigenType = typename GetEigenType<InDtype>::type;
72 using OutEigenType = typename GetEigenType<DType_BOOL>::type;
73 using FcnType = std::function<OutEigenType(InEigenType)>;
74 CastHelper();
75 const FcnType& get_fcn() const
76 {
77 return fcn;
78 }
79
80private:
81 FcnType fcn;
82};
83
84template <DType OutDtype>
85class CastHelper<DType_BOOL, OutDtype>
86{
87public:
88 using InEigenType = typename GetEigenType<DType_BOOL>::type;
89 using OutEigenType = typename GetEigenType<OutDtype>::type;
90 using FcnType = std::function<OutEigenType(InEigenType)>;
91 static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
92 static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
93 CastHelper();
94 const FcnType& get_fcn() const
95 {
96 return fcn;
97 }
98
99private:
100 FcnType fcn;
101};
102
103template <DType InDtype>
104class CastHelper<InDtype, DType_FLOAT>
105{
106public:
107 using InEigenType = typename GetEigenType<InDtype>::type;
108 using OutEigenType = typename GetEigenType<DType_FLOAT>::type;
109 using FcnType = std::function<OutEigenType(InEigenType)>;
110 CastHelper();
111 const FcnType& get_fcn() const
112 {
113 return fcn;
114 }
115
116private:
117 FcnType fcn;
118};
119
120template <DType OutDtype>
121class CastHelper<DType_FLOAT, OutDtype>
122{
123public:
124 using InEigenType = typename GetEigenType<DType_FLOAT>::type;
125 using OutEigenType = typename GetEigenType<OutDtype>::type;
126 using FcnType = std::function<OutEigenType(InEigenType)>;
127 static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
128 static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
129 CastHelper();
130 const FcnType& get_fcn() const
131 {
132 return fcn;
133 }
134
135private:
136 FcnType fcn;
137};
138
139template <int Rank, DType InDtype, DType OutDtype>
140class OpCast : public GraphNode
141{
142public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700143 OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700144 virtual ~OpCast();
145
146 virtual int checkTensorAttributes() final;
147 virtual int eval() final;
148
149 using InEigenType = typename GetEigenType<InDtype>::type;
150 using OutEigenType = typename GetEigenType<OutDtype>::type;
151 using TIn = Eigen::Tensor<InEigenType, Rank>;
152 using TOut = Eigen::Tensor<OutEigenType, Rank>;
153
154protected:
155 CastHelper<InDtype, OutDtype> cast_helper;
156 TosaReference::TensorTemplate<TIn>* in;
157 TosaReference::TensorTemplate<TOut>* out;
158};
159
160}; // namespace TosaReference
161
162#endif