blob: f8c60b1b6d0ef10a90e5f3fb6a369a8c001b9939 [file] [log] [blame]
Kevin May1bea6be2023-12-12 11:18:46 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5// Copyright © 2020 The TensorFlow Authors. All Rights Reserved.
6// SPDX-License-Identifier: Apache-2.0
7//
8
9#include "SplitOperator.hpp"
10
11// This function is paraphrased from:
12// tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc from function convertSplitOp
13TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer,
14 const std::vector<const TensorInfo*>& inputs,
15 const std::vector<const TensorInfo*>& outputs,
16 const SplitterDescriptor* splitDescriptor)
17{
18 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE( inputs.size() == 1,
19 "ConvertSplitToTosaOperator: Split must have only one input" );
20
Kevin Mayf0d8ec12023-12-14 14:57:59 +000021 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE( outputs.size() >= 1,
22 "ConvertSplitToTosaOperator: Split must have at least one output" );
Kevin May1bea6be2023-12-12 11:18:46 +000023
24 if (!inputs[0]->GetShape().AreAllDimensionsSpecified())
25 {
26 throw armnn::Exception("ConvertSplitToTosaOperator: Dynamic input dimensions are unsupported.");
27 }
28
29 std::string inputName = std::string("input0_");
30 std::vector<std::string> outputNames;
31 std::string blockName = std::string("Op_SPLIT_block_") + GetUniqueTosaMappingID();
32
33 unsigned int numSplit = splitDescriptor->GetNumViews();
34 // If a layer is present then the block will be used for execution, so input and output names need to be determined
35 // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
36 if(layer != nullptr)
37 {
38 // Get the layers connected to the input slots and determine unique tensor names.
39 Layer& connectedLayer = layer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
40 inputName = GenerateUniqueName(connectedLayer, 0);
41
42 for (unsigned int i=0; i < numSplit; ++i)
43 {
44 // Determine unique output(s) tensor name.
45 std::string outputName = GenerateUniqueOutputName(*layer, i);
46 outputNames.push_back(outputName);
47 }
48 }
49 else
50 {
51 for (unsigned int i=0; i < numSplit; ++i)
52 {
53 // Determine unique output(s) tensor name.
54 std::string outputName = "output" + std::to_string(i) + "_";
55 outputNames.push_back(outputName);
56 }
57 }
58
59 // Each slice op has a different beginning point.
60 // The size is the same for each slice op.
61 std::vector<int32_t> beginVals;
62 beginVals.reserve(inputs[0]->GetNumDimensions());
63 std::vector<int32_t> sizeVals;
64 sizeVals.reserve(inputs[0]->GetNumDimensions());
65 for (unsigned int j = 0; j < inputs[0]->GetNumDimensions(); ++j)
66 {
67 beginVals.emplace_back(0);
68 uint32_t dim = inputs[0]->GetShape()[j];
69 sizeVals.emplace_back(dim);
70 }
71
72 uint32_t axis = static_cast<uint32_t>(splitDescriptor->GetAxis());
73 sizeVals[axis] = sizeVals[axis] / static_cast<int32_t>(numSplit);
74
75 std::vector<TosaSerializationOperator*> ops;
76 for (unsigned int i=0; i < numSplit; ++i)
77 {
78 beginVals[axis] = static_cast<int>(i) * sizeVals[axis];
79 TosaSliceAttribute attribute(beginVals, sizeVals);
80 auto* op = new TosaSerializationOperator(Op_SLICE,
81 Attribute_SliceAttribute,
82 &attribute,
83 {inputName},
84 {outputNames[i]});
85
86 ops.push_back(op);
87 }
88
89 std::vector<TosaSerializationTensor*> tensors;
90 // Only add input tensors if connected layer is an input layer.
91 // As intermediate or constant tensors will be created separately.
92 // There also can't be duplicate tensor.
93 if(inputName.find("input0_") != std::string::npos)
94 {
95 std::vector<int32_t> inputShape = GetTosaTensorShape(inputs[0]->GetShape());
96 DType inputDType = ArmNNToDType(inputs[0]->GetDataType());
97
98 tensors.push_back(new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
99 }
100
101 std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[0]->GetShape());
102 DType outputDType = ArmNNToDType(outputs[0]->GetDataType());
103
104 for (unsigned int i=0; i < numSplit; ++i)
105 {
106 tensors.push_back(new TosaSerializationTensor(outputNames[i], outputShape, outputDType, {}));
107 }
108 // operatorInputNames/operatorOutputNames ends up being the same as
109 // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
110 return new TosaSerializationBasicBlock(blockName, // name
111 mainName, // region name
112 ops, // operators
113 tensors, // tensors
114 {inputName}, // inputs
115 outputNames); // outputs
116}