blob: 4d53ae4566a6379deb935e38e2ae9f3508f2c18b [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
James Ward8b390432022-08-12 20:48:56 +01002// Copyright (c) 2020-2022, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "ewise_ternary.h"
17
18using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
22template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070023OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
24 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -070025 uint64_t id_)
26 : GraphNode(sgt_, Op_SELECT, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070027{
28 setRequiredOperands(3, 1);
29 setRequiredRank(0, 6);
30}
31
32template <int Rank, DType Dtype>
33OpSelectBase<Rank, Dtype>::~OpSelectBase()
34{}
35
36template <int Rank, DType Dtype>
37int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
38{
Jerry Gea793f462023-04-11 00:05:02 +000039 // Check Tosa Level
40 auto tosa_level = g_func_config.tosa_level;
41 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 if (validateRequiredOperands())
44 return 1;
45
46 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
47 validateRequiredRank(outputs[0]))
48 {
49 return 1;
50 }
51
52 // output and input must be the same types
Kevin Cheng1c3c8472021-11-08 11:19:10 -080053 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
54 inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
55 inputs[2]->matchRankTypeShape(*outputs[0], true /* broadcastOk */))
Eric Kunzee5e26762020-10-13 16:11:07 -070056 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080057 printNodeValidationError("Failure to match input and output rank/type/shape");
Eric Kunzee5e26762020-10-13 16:11:07 -070058 return 1;
59 }
60
61 cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]);
62 then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
63 else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]);
64 out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]);
65
66 return 0;
67}
68
69template <int Rank, DType Dtype>
70int OpSelectBase<Rank, Dtype>::eval()
71{
Kevin Chengacb550f2021-06-29 15:32:19 -070072 FATAL_ERROR("shouldn't be called");
Eric Kunzee5e26762020-10-13 16:11:07 -070073}
74
75template <int Rank, DType Dtype>
76int OpSelect<Rank, Dtype>::broadcast()
77{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080078 const std::vector<int>& cond_shape = this->cond->getShape();
79 const std::vector<int>& then_shape = this->then_val->getShape();
80 const std::vector<int>& else_shape = this->else_val->getShape();
81 const std::vector<int>& output_shape = this->out->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070082
83 for (int i = 0; i < Rank; i++)
84 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080085 this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1;
86 this->bcast_then[i] = (then_shape[i] != output_shape[i] && then_shape[i] == 1) ? output_shape[i] : 1;
87 this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -070088 }
89
90 return 0;
91}
92
93template <int Rank, DType Dtype>
94int OpSelect<Rank, Dtype>::eval()
95{
96 this->broadcast();
97 this->out->getTensor() = this->cond->getTensor()
98 .broadcast(this->bcast_cond)
99 .select(this->then_val->getTensor().broadcast(this->bcast_then),
100 this->else_val->getTensor().broadcast(this->bcast_else));
101
102 return GraphNode::eval();
103}
104
105template <DType Dtype>
106int OpSelect<0, Dtype>::eval()
107{
108 this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
109
110 return GraphNode::eval();
111}
112
113// template explicit instantiation
Jared Smolens98c281f2022-12-20 15:09:25 -0800114DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP16);
115DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BF16);
116DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP32);
117DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT8);
118DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT16);
119DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT32);
120DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BOOL);
121
James Ward8b390432022-08-12 20:48:56 +0100122DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100123DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100124DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700125DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
126DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
127DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
128DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);