blob: 020ddb5b155de79137619f3598a03c522883f22c [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002// Copyright (c) 2020-2022, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef OPS_EWISE_BINARY_H
17#define OPS_EWISE_BINARY_H
18
19#include "graph_node.h"
20
21using namespace tosa;
22
23namespace TosaReference
24{
25
// class BinaryNodeBase: holds the functions common to all binary nodes.
//   When a binary op is created, the virtual OpXXX::register_fcn() is called
//   and 'fcn' is registered with a lambda function that takes two inputs.
// class BinaryNode: the level of indirection used to partially specialize the template for rank 0.
//   eval(), called from the top level, should call .binaryExpr(dims, fcn) here.
//   This needs to be partially specialized, or the
//   compiler will statically fail when trying to broadcast a rank-0 tensor.
// class OpXXX: implements the per-element lambda function based on the data type.
//   Unlike BinaryNode, this doesn't need to be partially specialized.

// Eigen::Tensor does support some binary element-wise operations natively (e.g. CWiseMax, or '+', etc.),
// which might be faster since they could be implemented with SIMD instructions.
// Registering a lambda and calling .binaryExpr() might sacrifice performance here,
// but it avoids partial specialization for each combination of {rankN, rank0} x {FP32/INT32, QU8, ...}.
// This needs to be revisited if performance becomes a bottleneck.
41template <int Rank, DType InDtype, DType OutDtype>
42class BinaryNodeBase : public GraphNode
43{
44public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000045 BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070046 virtual ~BinaryNodeBase();
47
48 virtual int checkTensorAttributes() final;
49 virtual int eval() = 0;
50 virtual int register_fcn() = 0;
51
52 using InEigenType = typename GetEigenType<InDtype>::type;
53 using OutEigenType = typename GetEigenType<OutDtype>::type;
54 using TIn = Eigen::Tensor<InEigenType, Rank>;
55 using TOut = Eigen::Tensor<OutEigenType, Rank>;
56
57protected:
58 int broadcast();
59
60protected:
61 std::function<OutEigenType(InEigenType, InEigenType)> fcn;
62 Eigen::array<int, Rank> bcast_a;
63 Eigen::array<int, Rank> bcast_b;
64 TosaReference::TensorTemplate<TIn>* a;
65 TosaReference::TensorTemplate<TIn>* b;
Eric Kunzee5e26762020-10-13 16:11:07 -070066 TosaReference::TensorTemplate<TOut>* result;
Eric Kunzee5e26762020-10-13 16:11:07 -070067};
68
69// primary class
70template <int Rank, DType InDtype, DType OutDtype>
71class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
72{
73public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000074 BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_)
75 : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070076 {}
77 virtual ~BinaryNode()
78 {}
79
80 virtual int eval();
81
82 using InEigenType = typename GetEigenType<InDtype>::type;
83 using OutEigenType = typename GetEigenType<OutDtype>::type;
84 using TIn = Eigen::Tensor<InEigenType, Rank>;
85 using TOut = Eigen::Tensor<OutEigenType, Rank>;
86};
87
88// partial specialization for rank 0
89template <DType InDtype, DType OutDtype>
90class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
91{
92public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000093 BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_)
94 : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070095 {}
96 virtual ~BinaryNode()
97 {}
98
99 virtual int eval();
100};
101
Kevin Chengaee1fac2020-11-11 13:54:06 -0800102#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \
Eric Kunzee5e26762020-10-13 16:11:07 -0700103 template <int Rank, DType Dtype> \
104 class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
105 { \
106 public: \
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000107 Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
108 : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
Eric Kunzee5e26762020-10-13 16:11:07 -0700109 { \
110 register_fcn(); \
111 } \
112 static constexpr DType InDtype = Dtype; \
113 static constexpr DType OutDtype = Dtype; \
114 using InEigenType = typename GetEigenType<InDtype>::type; \
115 using OutEigenType = typename GetEigenType<OutDtype>::type; \
116 virtual int register_fcn(); \
117 };
118
Kevin Chengaee1fac2020-11-11 13:54:06 -0800119DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD)
120DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND)
121DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR)
122DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR)
Matthew Haddon459443c2021-08-23 16:43:13 +0100123DEF_TEMPLATE_BINARY_OP_DEFAULT(Intdiv, INTDIV)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800124DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND)
125DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
126DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
127DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR)
128DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR)
129DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM)
130DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM)
131DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW)
132DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
Eric Kunzee5e26762020-10-13 16:11:07 -0700133
Kevin Chengaee1fac2020-11-11 13:54:06 -0800134#undef DEF_TEMPLATE_BINARY_OP_DEFAULT
Eric Kunzee5e26762020-10-13 16:11:07 -0700135
Kevin Chengaee1fac2020-11-11 13:54:06 -0800136template <int Rank, DType Dtype>
137class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
138{
139public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700140 OpArithmeticRightShift(SubgraphTraverser* sgt_,
141 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700142 uint64_t id_)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000143 : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, id_)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800144 {
145 INIT_ATTRIBUTE(ArithmeticRightShift);
146 register_fcn();
147 }
148 using InEigenType = typename GetEigenType<Dtype>::type;
149 using OutEigenType = typename GetEigenType<Dtype>::type;
150 virtual int register_fcn();
Jerry Gea6827492022-11-16 10:41:55 -0800151 virtual ~OpArithmeticRightShift();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800152
153protected:
154 TosaArithmeticRightShiftAttribute* attribute;
155};
156
157template <int Rank, DType InDtype, DType OutDtype>
158class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
159{
160public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000161 OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
162 : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800163 {
164 INIT_ATTRIBUTE(Mul);
165 register_fcn();
166 }
Jerry Gea6827492022-11-16 10:41:55 -0800167 virtual ~OpMul();
Kevin Chengaee1fac2020-11-11 13:54:06 -0800168 static constexpr int64_t QMin = GetQMin<OutDtype>::value;
169 static constexpr int64_t QMax = GetQMax<OutDtype>::value;
170 using InEigenType = typename GetEigenType<InDtype>::type;
171 using OutEigenType = typename GetEigenType<OutDtype>::type;
172 virtual int register_fcn();
173
174protected:
175 TosaMulAttribute* attribute;
176};
Eric Kunzee5e26762020-10-13 16:11:07 -0700177
Kevin Cheng571f7182021-05-24 17:20:01 -0700178template <int Rank, DType InDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700179class OpTable : public GraphNode
180{
181public:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000182 OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700183 virtual ~OpTable();
184
185 virtual int checkTensorAttributes();
186 virtual int eval();
187
Kevin Chengfe392ce2021-10-18 21:51:55 +0000188 static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
189 static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
190 static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513;
191 using InEigenType = typename GetEigenType<InDtype>::type;
192 using TableEigenType = typename GetEigenType<TableDtype>::type;
193 using OutEigenType = typename GetEigenType<OutDtype>::type;
194 using TIn = Eigen::Tensor<InEigenType, Rank>;
195 using TTable = Eigen::Tensor<TableEigenType, 1>;
196 using TOut = Eigen::Tensor<OutEigenType, Rank>;
197 static constexpr int32_t IntegerBits = 9;
198 static constexpr int32_t FractionBits = 7;
199 static constexpr int32_t NumTableEntries = (1 << IntegerBits);
200 static constexpr int32_t QInMin = GetQMin<InDtype>::value;
201 static constexpr int32_t QInMax = GetQMax<InDtype>::value;
202 static constexpr int32_t QOutMin = GetQMin<OutDtype>::value;
203 static constexpr int32_t QOutMax = GetQMax<OutDtype>::value;
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
205protected:
206 TosaReference::TensorTemplate<TIn>* in;
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 TosaReference::TensorTemplate<TOut>* out;
Kevin Chengfe392ce2021-10-18 21:51:55 +0000208 TosaTableAttribute* attribute;
209 std::array<TableEigenType, TableNumEntries> table;
Eric Kunzee5e26762020-10-13 16:11:07 -0700210};
211
212}; // namespace TosaReference
213
214#endif