blob: c26507794c627ecb0a2218aea0cef13071f95534 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_ternary.h"
17
18using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
22template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070023OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
24 TosaAttributeBase* attribute_,
25 TosaQuantInfoBase* qinfo_,
26 uint64_t id_)
27 : GraphNode(sgt_, Op_SELECT, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070028{
29 setRequiredOperands(3, 1);
30 setRequiredRank(0, 6);
31}
32
33template <int Rank, DType Dtype>
34OpSelectBase<Rank, Dtype>::~OpSelectBase()
35{}
36
37template <int Rank, DType Dtype>
38int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
39{
40 if (validateRequiredOperands())
41 return 1;
42
43 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
44 validateRequiredRank(outputs[0]))
45 {
46 return 1;
47 }
48
49 // output and input must be the same types
Kevin Cheng1c3c8472021-11-08 11:19:10 -080050 if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
51 inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
52 inputs[2]->matchRankTypeShape(*outputs[0], true /* broadcastOk */))
Eric Kunzee5e26762020-10-13 16:11:07 -070053 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080054 printNodeValidationError("Failure to match input and output rank/type/shape");
Eric Kunzee5e26762020-10-13 16:11:07 -070055 return 1;
56 }
57
58 cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]);
59 then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
60 else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]);
61 out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]);
62
63 return 0;
64}
65
66template <int Rank, DType Dtype>
67int OpSelectBase<Rank, Dtype>::eval()
68{
Kevin Chengacb550f2021-06-29 15:32:19 -070069 FATAL_ERROR("shouldn't be called");
Eric Kunzee5e26762020-10-13 16:11:07 -070070}
71
72template <int Rank, DType Dtype>
73int OpSelect<Rank, Dtype>::broadcast()
74{
Kevin Cheng1c3c8472021-11-08 11:19:10 -080075 const std::vector<int>& cond_shape = this->cond->getShape();
76 const std::vector<int>& then_shape = this->then_val->getShape();
77 const std::vector<int>& else_shape = this->else_val->getShape();
78 const std::vector<int>& output_shape = this->out->getShape();
Eric Kunzee5e26762020-10-13 16:11:07 -070079
80 for (int i = 0; i < Rank; i++)
81 {
Kevin Cheng1c3c8472021-11-08 11:19:10 -080082 this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1;
83 this->bcast_then[i] = (then_shape[i] != output_shape[i] && then_shape[i] == 1) ? output_shape[i] : 1;
84 this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1;
Eric Kunzee5e26762020-10-13 16:11:07 -070085 }
86
87 return 0;
88}
89
90template <int Rank, DType Dtype>
91int OpSelect<Rank, Dtype>::eval()
92{
93 this->broadcast();
94 this->out->getTensor() = this->cond->getTensor()
95 .broadcast(this->bcast_cond)
96 .select(this->then_val->getTensor().broadcast(this->bcast_then),
97 this->else_val->getTensor().broadcast(this->bcast_else));
98
99 return GraphNode::eval();
100}
101
102template <DType Dtype>
103int OpSelect<0, Dtype>::eval()
104{
105 this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
106
107 return GraphNode::eval();
108}
109
110// template explicit instantiation
111DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
Eric Kunzee5e26762020-10-13 16:11:07 -0700112DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
113DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
114DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
115DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);