blob: 98c01e2cb8b23724d2094fba99f90e771df9fd05 [file] [log] [blame]
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <Layer.hpp>

#include <tosa_serialization_handler.h>
#include "TosaOperatorUtils.hpp"

#include <memory>
#include <string>
#include <vector>
12
13using namespace armnn;
14using namespace tosa;
15
16TosaSerializationBasicBlock* ConvertAdditionToTosaOperator(const std::vector<const TensorInfo*>& inputs,
17 const std::vector<const TensorInfo*>& outputs)
18{
19 // A helper function with static global variables ensures uniqueness
20 // for dynamically generating input, output and block names
21 std::string input0Name = std::string("Op_ADD_input0_") + GetUniqueTosaMappingID();
22 std::string input1Name = std::string("Op_ADD_input1_") + GetUniqueTosaMappingID();
23 std::string outputName = std::string("Op_ADD_output0_") + GetUniqueTosaMappingID();
24 std::string blockName = std::string("Op_ADD_block_") + GetUniqueTosaMappingID();
25
26 TosaSerializationOperator* op = new TosaSerializationOperator(Op_ADD,
27 Attribute_NONE,
28 nullptr,
29 {input0Name, input1Name},
30 {outputName});
31
32 std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape());
33 DType inputDType0 = ArmNNToDType(inputs[0]->GetDataType());
34
35 std::vector<int32_t> inputShape1 = GetTosaTensorShape(inputs[1]->GetShape());
36 DType inputDType1 = ArmNNToDType(inputs[1]->GetDataType());
37
38 std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
39 DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
40
41 TosaSerializationTensor* inputTensor0 = new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {});
42 TosaSerializationTensor* inputTensor1 = new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {});
43 TosaSerializationTensor* outputTensor0 = new TosaSerializationTensor(outputName, outputShape0, outputDType0, {});
44
45 // operatorInputNames/operatorOutputNames ends up being the same as
46 // blockInputNames/blockOutputNames for one-to-one ArmNN to Tosa mappings
47 return new TosaSerializationBasicBlock(blockName, // name
48 {op}, // operators
49 {inputTensor0, inputTensor1, outputTensor0}, // tensors
50 {input0Name, input1Name}, // inputs
51 {outputName}); // outputs
52}