blob: 5861cb2ee1ba807a790bda66f227219c651a4f1f [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
15
16#include "ewise_ternary.h"
17
18using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
Tai Lya4d748b2023-03-28 22:06:56 +000022template <int Rank, TOSA_REF_TYPE Dtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +000023OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070024 : GraphNode(sgt_, Op_SELECT, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070025{
26 setRequiredOperands(3, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070027}
28
Tai Lya4d748b2023-03-28 22:06:56 +000029template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070030OpSelectBase<Rank, Dtype>::~OpSelectBase()
31{}
32
Tai Lya4d748b2023-03-28 22:06:56 +000033template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070034int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
35{
Jerry Gea793f462023-04-11 00:05:02 +000036 // Check Tosa Level
37 auto tosa_level = g_func_config.tosa_level;
38 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
39
Eric Kunzee5e26762020-10-13 16:11:07 -070040 if (validateRequiredOperands())
41 return 1;
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 // output and input must be the same types
Kevin Cheng1c3c8472021-11-08 11:19:10 -080044 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
45 inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
46 inputs[2]->matchRankTypeShape(*outputs[0], true /* broadcastOk */))
Eric Kunzee5e26762020-10-13 16:11:07 -070047 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080048 printNodeValidationError("Failure to match input and output rank/type/shape");
Eric Kunzee5e26762020-10-13 16:11:07 -070049 return 1;
50 }
51
52 cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]);
53 then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
54 else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]);
55 out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]);
56
57 return 0;
58}
59
Tai Lya4d748b2023-03-28 22:06:56 +000060template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070061int OpSelectBase<Rank, Dtype>::eval()
62{
Kevin Chengacb550f2021-06-29 15:32:19 -070063 FATAL_ERROR("shouldn't be called");
Eric Kunzee5e26762020-10-13 16:11:07 -070064}
65
Tai Lya4d748b2023-03-28 22:06:56 +000066template <int Rank, TOSA_REF_TYPE Dtype>
Jerry Ge135c9552023-05-23 20:59:32 +000067int OpSelect<Rank, Dtype>::broadcast(std::vector<int>& calculated_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -070068{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080069 const std::vector<int>& cond_shape = this->cond->getShape();
70 const std::vector<int>& then_shape = this->then_val->getShape();
71 const std::vector<int>& else_shape = this->else_val->getShape();
72 const std::vector<int>& output_shape = this->out->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070073
Jerry Ge135c9552023-05-23 20:59:32 +000074 // calculates the multipliers for Eigen
Eric Kunzee5e26762020-10-13 16:11:07 -070075 for (int i = 0; i < Rank; i++)
76 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080077 this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1;
78 this->bcast_then[i] = (then_shape[i] != output_shape[i] && then_shape[i] == 1) ? output_shape[i] : 1;
79 this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -070080 }
81
Jerry Ge135c9552023-05-23 20:59:32 +000082 // calculates the broadcasted output shape
83 calculated_shape = cond_shape;
Jerry Ge9c9c8da2023-07-19 23:08:16 +000084 for (size_t i = 0; i < calculated_shape.size(); i++)
85 {
86 if (calculated_shape[i] == 1)
87 {
Jerry Ge135c9552023-05-23 20:59:32 +000088 calculated_shape[i] = then_shape[i];
Jerry Ge9c9c8da2023-07-19 23:08:16 +000089 }
90 else
91 {
92 ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i],
93 "Broadcast_shape failure, input shapes are not compatible");
Jerry Ge135c9552023-05-23 20:59:32 +000094 }
95
Jerry Ge9c9c8da2023-07-19 23:08:16 +000096 if (calculated_shape[i] == 1)
97 {
Jerry Ge135c9552023-05-23 20:59:32 +000098 calculated_shape[i] = else_shape[i];
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 }
100 else
101 {
102 ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i],
103 "Broadcast_shape failure, input shapes are not compatible");
Jerry Ge135c9552023-05-23 20:59:32 +0000104 }
105 }
106
Eric Kunzee5e26762020-10-13 16:11:07 -0700107 return 0;
108}
109
Tai Lya4d748b2023-03-28 22:06:56 +0000110template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700111int OpSelect<Rank, Dtype>::eval()
112{
Jerry Ge135c9552023-05-23 20:59:32 +0000113 std::vector<int> calculated_shape;
114 this->broadcast(calculated_shape);
115
116 auto result_shape = this->out->getShape();
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000117 ERROR_IF(calculated_shape != result_shape,
118 "Broadcast_shape failure, calculated_shape and result_shape don't match");
Jerry Ge135c9552023-05-23 20:59:32 +0000119
Eric Kunzee5e26762020-10-13 16:11:07 -0700120 this->out->getTensor() = this->cond->getTensor()
121 .broadcast(this->bcast_cond)
122 .select(this->then_val->getTensor().broadcast(this->bcast_then),
123 this->else_val->getTensor().broadcast(this->bcast_else));
124
125 return GraphNode::eval();
126}
127
Tai Lya4d748b2023-03-28 22:06:56 +0000128template <TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700129int OpSelect<0, Dtype>::eval()
130{
131 this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
132
133 return GraphNode::eval();
134}
135
136// template explicit instantiation
Jared Smolens98c281f2022-12-20 15:09:25 -0800137DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP16);
138DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BF16);
139DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP32);
140DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT8);
141DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT16);
142DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT32);
143DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000144DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP64);
Jared Smolens98c281f2022-12-20 15:09:25 -0800145
James Ward8b390432022-08-12 20:48:56 +0100146DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100147DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100148DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700149DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
150DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
151DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
152DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
Tai Lya4d748b2023-03-28 22:06:56 +0000153DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64);