blob: f5304a557e8f50ae546fa371d0c4e6e0fa27ee11 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2022, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
15
16#include "data_nodes.h"
17
18using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
Kevin Chengacb550f2021-06-29 15:32:19 -070022OpConst::OpConst(SubgraphTraverser* sgt_, uint64_t id_)
23 : GraphNode(sgt_, Op_CONST, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070024{
25 setRequiredOperands(0, 1);
26}
27
28OpConst::~OpConst()
29{}
30
31int OpConst::checkTensorAttributes()
32{
33 if (validateRequiredOperands())
34 return 1;
35
36 return 0;
37}
38
39int OpConst::eval()
40{
41 // Evaluation is trivial for constants
42 return GraphNode::eval();
43}
44
Eric Kunzee5e26762020-10-13 16:11:07 -070045template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070046OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
47 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -070048 uint64_t id_)
49 : GraphNode(sgt_, Op_IDENTITY, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070050{
51 setRequiredOperands(1, 1);
52 setRequiredRank(0, 6);
53}
54
55template <int Rank, DType Dtype>
56OpIdentity<Rank, Dtype>::~OpIdentity()
57{}
58
59template <int Rank, DType Dtype>
60int OpIdentity<Rank, Dtype>::checkTensorAttributes()
61{
62
63 if (inputs.size() != outputs.size())
64 {
65 printNodeValidationError("Input and output tensor list lengths are not equal");
66 return 1;
67 }
68
69 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
70 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
71
72 if (in->matchRankTypeShape(*out))
73 {
74 printNodeValidationError("Input and output tensor rank, type, or shape do not match");
75 return 1;
76 }
77
78 return 0;
79}
80
81template <int Rank, DType Dtype>
82int OpIdentity<Rank, Dtype>::eval()
83{
84 out->getTensor() = in->getTensor();
85
86 return GraphNode::eval();
87}
88
Eric Kunzee5e26762020-10-13 16:11:07 -070089// template explicit instantiation
Kevin Cheng14d7f7a2021-05-12 10:44:49 -070090// note OpConst is not templated
Eric Kunzee5e26762020-10-13 16:11:07 -070091
James Ward8b390432022-08-12 20:48:56 +010092DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
James Ward24dbc422022-10-19 12:20:31 +010093DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010094DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -070095DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
96DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
97DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
98DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);