// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "ewise_ternary.h"
17
18using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
22template <int Rank, DType Dtype>
23OpSelectBase<Rank, Dtype>::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
24 : GraphNode(Op_SELECT, id_)
25{
26 setRequiredOperands(3, 1);
27 setRequiredRank(0, 6);
28}
29
30template <int Rank, DType Dtype>
31OpSelectBase<Rank, Dtype>::~OpSelectBase()
32{}
33
34template <int Rank, DType Dtype>
35int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
36{
37 if (validateRequiredOperands())
38 return 1;
39
40 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
41 validateRequiredRank(outputs[0]))
42 {
43 return 1;
44 }
45
46 // output and input must be the same types
47 if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) ||
48 inputs[2]->matchRankType(*outputs[0]))
49 {
50 printNodeValidationError("Failure to match input and output rank and type");
51 return 1;
52 }
53
54 cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]);
55 then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
56 else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]);
57 out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]);
58
59 return 0;
60}
61
62template <int Rank, DType Dtype>
63int OpSelectBase<Rank, Dtype>::eval()
64{
65 FATAL_ERROR_NODE("shouldn't be called");
66}
67
68template <int Rank, DType Dtype>
69int OpSelect<Rank, Dtype>::broadcast()
70{
71 std::vector<int> cond_shape = this->cond->getShape();
72 std::vector<int> then_shape = this->then_val->getShape();
73 std::vector<int> else_shape = this->else_val->getShape();
74 std::vector<int> out_shape = this->out->getShape();
75
76 for (int i = 0; i < Rank; i++)
77 {
78 this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1;
79 this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1;
80 this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1;
81 ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
82 ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
83 ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
84 }
85
86 return 0;
87}
88
89template <int Rank, DType Dtype>
90int OpSelect<Rank, Dtype>::eval()
91{
92 this->broadcast();
93 this->out->getTensor() = this->cond->getTensor()
94 .broadcast(this->bcast_cond)
95 .select(this->then_val->getTensor().broadcast(this->bcast_then),
96 this->else_val->getTensor().broadcast(this->bcast_else));
97
98 return GraphNode::eval();
99}
100
101template <DType Dtype>
102int OpSelect<0, Dtype>::eval()
103{
104 this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
105
106 return GraphNode::eval();
107}
108
109// template explicit instantiation
110DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
Eric Kunzee5e26762020-10-13 16:11:07 -0700111DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
112DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
113DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
114DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);