| // Copyright (c) 2023-2024, ARM Limited. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "shape.h" |
| |
| using namespace TosaReference; |
| using namespace Eigen; |
| using namespace tosa; |
| |
| OpConstShape::OpConstShape(SubgraphTraverser* sgt_, uint64_t id_) |
| : GraphNode(sgt_, Op_CONST, id_) |
| { |
| setRequiredOperands(0, 1); |
| } |
| |
| OpConstShape::~OpConstShape() |
| {} |
| |
| int OpConstShape::checkTensorAttributes() |
| { |
| if (validateRequiredOperands()) |
| return 1; |
| |
| return 0; |
| } |
| |
| int OpConstShape::eval() |
| { |
| for (auto ct : getOutputs()) |
| { |
| if (!ct->getIsValid()) |
| { |
| std::string err = "Constant Shape tensor " + ct->getName() + " not correctly initialized"; |
| printNodeValidationError(err.c_str()); |
| return 1; |
| } |
| } |
| |
| // Evaluation is trivial for constants |
| return GraphNode::eval(); |
| } |
| |
| OpConcatShape::OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_) |
| : GraphNode(sgt_, Op_CONCAT_SHAPE, id_) |
| { |
| setRequiredOperands(-1, 1); |
| setRequiredRank(1, 1); |
| } |
| |
| OpConcatShape::~OpConcatShape() |
| {} |
| |
| int OpConcatShape::checkTensorAttributes() |
| { |
| if (validateRequiredOperands()) |
| return 1; |
| |
| if (inputs.empty()) |
| { |
| printNodeValidationError("ConcatShape operator must have at least one input tensor"); |
| return 1; |
| } |
| |
| int32_t num_inputs = inputs.size(); |
| int32_t elements_count = 0; |
| for (int32_t i = 0; i < num_inputs; i++) |
| { |
| if (validateRequiredRank(inputs[i])) |
| return 1; |
| ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i])); |
| elements_count += inputs[i]->getShape()[0]; |
| } |
| |
| ERROR_IF(elements_count != outputs[0]->getShape()[0], |
| "OpConcatShape: sum of input elements not equal to output number of elements"); |
| |
| num_dims = outputs[0]->getShape()[0]; |
| out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); |
| |
| return 0; |
| } |
| |
| int OpConcatShape::eval() |
| { |
| ETensor1<EigenType> out_tensor(num_dims); |
| int32_t out_idx = 0; |
| for (size_t i = 0; i < ins.size(); i++) |
| { |
| // all tosa.shape values are 1-d tensors |
| // interate in_idx in range of [0, rank of 1-d tensor] |
| for (int32_t in_idx = 0; in_idx < inputs[i]->getShape()[0]; in_idx++) |
| { |
| out_tensor(out_idx) = ins[i]->getTensor()(in_idx); |
| out_idx++; |
| } |
| } |
| out->getTensor() = out_tensor; |
| return GraphNode::eval(); |
| } |
| |
| ShapeBinaryNodeBase::ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) |
| : GraphNode(sgt_, op_, id_) |
| { |
| setRequiredOperands(2, 1); |
| setRequiredRank(1, 1); |
| |
| fcn = [](EigenType a, EigenType b) -> EigenType { return EigenType(); }; |
| } |
| |
| ShapeBinaryNodeBase::~ShapeBinaryNodeBase() |
| {} |
| |
| int ShapeBinaryNodeBase::checkTensorAttributes() |
| { |
| if (validateRequiredOperands()) |
| return 1; |
| if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) |
| return 1; |
| |
| num_dims = outputs[0]->getShape()[0]; |
| |
| if (inputs[0]->getShape()[0] != num_dims) |
| { |
| std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) + |
| " lhs input and output rank/shape must match"; |
| printNodeValidationError(err.c_str()); |
| return 1; |
| } |
| |
| if (inputs[1]->getShape()[0] != num_dims) |
| { |
| std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) + |
| " rhs input and output rank/shape must match"; |
| printNodeValidationError(err.c_str()); |
| return 1; |
| } |
| |
| a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); |
| b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); |
| result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); |
| |
| ASSERT_MEM(a && b && result); |
| |
| return 0; |
| } |
| |
| int ShapeBinaryNodeBase::eval() |
| { |
| auto ia = a->getTensor(); |
| auto ib = b->getTensor(); |
| ETensor1<EigenType> out_tens(num_dims); |
| for (int32_t i = 0; i < num_dims; i++) |
| { |
| EigenType lhs = ia(i); |
| EigenType rhs = ib(i); |
| out_tens(i) = (lhs < 0 || rhs < 0) ? static_cast<EigenType>(-1) : fcn(lhs, rhs); |
| } |
| |
| result->getTensor() = out_tens; |
| return GraphNode::eval(); |
| } |
| |
| int OpAddShape::register_fcn() |
| { |
| fcn = [](EigenType a, EigenType b) -> EigenType { return a + b; }; |
| return 0; |
| } |
| |
| int OpSubShape::register_fcn() |
| { |
| fcn = [](EigenType a, EigenType b) -> EigenType { return a - b; }; |
| return 0; |
| } |
| |
| int OpMulShape::register_fcn() |
| { |
| fcn = [](EigenType a, EigenType b) -> EigenType { return a * b; }; |
| return 0; |
| } |
| |
| int OpDivShape::register_fcn() |
| { |
| fcn = [](EigenType a, EigenType b) -> EigenType { |
| return (b == static_cast<EigenType>(0)) ? static_cast<EigenType>(-1) : (a / b); |
| }; |
| return 0; |
| } |