blob: cd707edb3a2241ef908fab479203b053aa161fb2 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "Pooling2DOperator.hpp"
7
8TosaSerializationBasicBlock* ConvertPooling2DToTosaOperator(const std::vector<const TensorInfo*>& inputs,
9 const std::vector<const TensorInfo*>& outputs,
10 bool isMain,
11 const Pooling2dDescriptor* poolDescriptor)
12{
13 std::string poolType = (poolDescriptor->m_PoolType == PoolingAlgorithm::Max) ? "Op_MAX" : "Op_AVG";
14 Op opcode = (poolDescriptor->m_PoolType == PoolingAlgorithm::Max) ? Op_MAX_POOL2D : Op_AVG_POOL2D;
15
16 // A helper function with static global variables ensures uniqueness
17 // for dynamically generating input, output and block names
18 std::string input0Name = poolType + std::string("_POOL2D_input0_") + GetUniqueTosaMappingID();
19 std::string outputName = poolType + std::string("_POOL2D_output0_") + GetUniqueTosaMappingID();
20 std::string blockName = poolType + std::string("_POOL2D_block_") + GetUniqueTosaMappingID();
21
22 // If it's the first block, overwrite block name with main.
23 if (isMain)
24 {
25 blockName = std::string("main");
26 }
27
28 std::vector<int> pad = {static_cast<int>(poolDescriptor->m_PadTop),
29 static_cast<int>(poolDescriptor->m_PadBottom),
30 static_cast<int>(poolDescriptor->m_PadLeft),
31 static_cast<int>(poolDescriptor->m_PadRight)};
32 std::vector<int> kernel = {static_cast<int>(poolDescriptor->m_PoolHeight),
33 static_cast<int>(poolDescriptor->m_PoolWidth)};
34 std::vector<int> stride = {static_cast<int>(poolDescriptor->m_StrideY),
35 static_cast<int>(poolDescriptor->m_StrideX)};
36 TosaPoolAttribute attribute(pad, kernel, stride, 0, 0, ArmNNToDType(inputs[0]->GetDataType()));
37
38 TosaSerializationOperator* op = new TosaSerializationOperator(opcode,
39 Attribute_PoolAttribute,
40 &attribute,
41 {input0Name},
42 {outputName});
43
44 std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape());
45 DType inputDType0 = ArmNNToDType(inputs[0]->GetDataType());
46
47 std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
48 DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
49
50 TosaSerializationTensor* inputTensor0 = new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {});
51 TosaSerializationTensor* outputTensor0 = new TosaSerializationTensor(outputName, outputShape0, outputDType0, {});
52
53 // operatorInputNames/operatorOutputNames ends up being the same as
54 // blockInputNames/blockOutputNames for one-to-one ArmNN to Tosa mappings
55 return new TosaSerializationBasicBlock(blockName, // name
56 {op}, // operators
57 {inputTensor0, outputTensor0}, // tensors
58 {input0Name}, // inputs
59 {outputName}); // outputs
60}