blob: 56e3f3402c5a5925cfc3748d8cb4031860f711d5 [file] [log] [blame]
Cathal Corbettbd18eab2022-11-15 12:56:16 +00001//
Teresa Charlin8cfd0592024-04-23 16:22:47 +01002// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
Cathal Corbettbd18eab2022-11-15 12:56:16 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "Pooling2DOperator.hpp"
7
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +00008TosaSerializationBasicBlock* ConvertPooling2DToTosaOperator(const Layer* layer,
9 const std::vector<const TensorInfo*>& inputs,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000010 const std::vector<const TensorInfo*>& outputs,
Cathal Corbettbd18eab2022-11-15 12:56:16 +000011 const Pooling2dDescriptor* poolDescriptor)
12{
13 std::string poolType = (poolDescriptor->m_PoolType == PoolingAlgorithm::Max) ? "Op_MAX" : "Op_AVG";
14 Op opcode = (poolDescriptor->m_PoolType == PoolingAlgorithm::Max) ? Op_MAX_POOL2D : Op_AVG_POOL2D;
15
Teresa Charlin8cfd0592024-04-23 16:22:47 +010016 std::string input0Name = std::string("input_");
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000017 std::string outputName = std::string("output0_");
18 std::string blockName = std::string("Op_") + poolType + std::string("_POOL2D_block_") + GetUniqueTosaMappingID();
Cathal Corbettbd18eab2022-11-15 12:56:16 +000019
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000020 // If a layer is present then the block will be used for execution, so input and output names need to be determined
21 // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
22 if(layer != nullptr)
Cathal Corbettbd18eab2022-11-15 12:56:16 +000023 {
Teresa Charlin8cfd0592024-04-23 16:22:47 +010024 input0Name = GenerateUniqueInputName(layer->GetInputSlot(0));
25 outputName = GenerateUniqueOutputName(*layer);
Cathal Corbettbd18eab2022-11-15 12:56:16 +000026 }
27
28 std::vector<int> pad = {static_cast<int>(poolDescriptor->m_PadTop),
29 static_cast<int>(poolDescriptor->m_PadBottom),
30 static_cast<int>(poolDescriptor->m_PadLeft),
31 static_cast<int>(poolDescriptor->m_PadRight)};
32 std::vector<int> kernel = {static_cast<int>(poolDescriptor->m_PoolHeight),
33 static_cast<int>(poolDescriptor->m_PoolWidth)};
34 std::vector<int> stride = {static_cast<int>(poolDescriptor->m_StrideY),
35 static_cast<int>(poolDescriptor->m_StrideX)};
36 TosaPoolAttribute attribute(pad, kernel, stride, 0, 0, ArmNNToDType(inputs[0]->GetDataType()));
37
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000038 auto* op = new TosaSerializationOperator(opcode,
39 Attribute_PoolAttribute,
40 &attribute,
41 {input0Name},
42 {outputName});
Cathal Corbettbd18eab2022-11-15 12:56:16 +000043
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000044 std::vector<TosaSerializationTensor*> tensors;
45
46 // Only add input tensors if connected layer is an input layer.
47 // As intermediate or constant tensors will be created separately.
48 // There also can't be duplicate tensor.
Teresa Charlin8cfd0592024-04-23 16:22:47 +010049 if(input0Name.find("input_") != std::string::npos)
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000050 {
51 std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape());
52 DType inputDType0 = ArmNNToDType(inputs[0]->GetDataType());
53
54 tensors.push_back(new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
55 }
Cathal Corbettbd18eab2022-11-15 12:56:16 +000056
57 std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
58 DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
59
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000060 tensors.push_back(new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
Cathal Corbettbd18eab2022-11-15 12:56:16 +000061
62 // operatorInputNames/operatorOutputNames ends up being the same as
Cathal Corbettb30e6552022-12-07 11:50:50 +000063 // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
Cathal Corbettbd18eab2022-11-15 12:56:16 +000064 return new TosaSerializationBasicBlock(blockName, // name
Narumol Prangnawaratad323af2023-09-29 17:00:38 +010065 mainName, // region name
Cathal Corbettbd18eab2022-11-15 12:56:16 +000066 {op}, // operators
Matthew Sloyanda6bf9e2022-12-14 10:16:27 +000067 tensors, // tensors
Cathal Corbettbd18eab2022-11-15 12:56:16 +000068 {input0Name}, // inputs
69 {outputName}); // outputs
70}