blob: c33f61296a4867d4c7b2a0f19e5aa52fb6ddda6c [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "Pooling2DOperator.hpp"
7
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +00008TosaSerializationBasicBlock* ConvertPooling2DToTosaOperator(const Layer* layer,
9 const std::vector<const TensorInfo*>& inputs,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000010 const std::vector<const TensorInfo*>& outputs,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000011 const Pooling2dDescriptor* poolDescriptor)
12{
13 std::string poolType = (poolDescriptor->m_PoolType == PoolingAlgorithm::Max) ? "Op_MAX" : "Op_AVG";
14 Op opcode = (poolDescriptor->m_PoolType == PoolingAlgorithm::Max) ? Op_MAX_POOL2D : Op_AVG_POOL2D;
15
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000016 std::string input0Name = std::string("input0_");
17 std::string outputName = std::string("output0_");
18 std::string blockName = std::string("Op_") + poolType + std::string("_POOL2D_block_") + GetUniqueTosaMappingID();
Cathal Corbettbd18eab2022-11-15 12:56:16 +000019
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000020 // If a layer is present then the block will be used for execution, so input and output names need to be determined
21 // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
22 if(layer != nullptr)
Cathal Corbettbd18eab2022-11-15 12:56:16 +000023 {
Kevin May5b58e312022-12-15 10:15:21 +000024 // Get the layers connected to the input slots and determine unique tensor names.
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000025 Layer& connectedInputLayer = layer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
26 input0Name = GenerateUniqueName(connectedInputLayer, 0);
27
Kevin May5b58e312022-12-15 10:15:21 +000028 // Determine unique output tensor name.
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000029 outputName = GenerateUniqueOutputName(*layer, 0);
Cathal Corbettbd18eab2022-11-15 12:56:16 +000030 }
31
32 std::vector<int> pad = {static_cast<int>(poolDescriptor->m_PadTop),
33 static_cast<int>(poolDescriptor->m_PadBottom),
34 static_cast<int>(poolDescriptor->m_PadLeft),
35 static_cast<int>(poolDescriptor->m_PadRight)};
36 std::vector<int> kernel = {static_cast<int>(poolDescriptor->m_PoolHeight),
37 static_cast<int>(poolDescriptor->m_PoolWidth)};
38 std::vector<int> stride = {static_cast<int>(poolDescriptor->m_StrideY),
39 static_cast<int>(poolDescriptor->m_StrideX)};
40 TosaPoolAttribute attribute(pad, kernel, stride, 0, 0, ArmNNToDType(inputs[0]->GetDataType()));
41
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000042 auto* op = new TosaSerializationOperator(opcode,
43 Attribute_PoolAttribute,
44 &attribute,
45 {input0Name},
46 {outputName});
Cathal Corbettbd18eab2022-11-15 12:56:16 +000047
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000048 std::vector<TosaSerializationTensor*> tensors;
49
50 // Only add input tensors if connected layer is an input layer.
51 // As intermediate or constant tensors will be created separately.
52 // There also can't be duplicate tensor.
53 if(input0Name.find("input0_") != std::string::npos)
54 {
55 std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape());
56 DType inputDType0 = ArmNNToDType(inputs[0]->GetDataType());
57
58 tensors.push_back(new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
59 }
Cathal Corbettbd18eab2022-11-15 12:56:16 +000060
61 std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
62 DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
63
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000064 tensors.push_back(new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
Cathal Corbettbd18eab2022-11-15 12:56:16 +000065
66 // operatorInputNames/operatorOutputNames ends up being the same as
Cathal Corbettb30e6552022-12-07 11:50:50 +000067 // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
Cathal Corbettbd18eab2022-11-15 12:56:16 +000068 return new TosaSerializationBasicBlock(blockName, // name
Narumol Prangnawaratad323af2023-09-29 17:00:38 +010069 mainName, // region name
Cathal Corbettbd18eab2022-11-15 12:56:16 +000070 {op}, // operators
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000071 tensors, // tensors
Cathal Corbettbd18eab2022-11-15 12:56:16 +000072 {input0Name}, // inputs
73 {outputName}); // outputs
74}