blob: 6f57c4a61e4010dc71a372ea31eb74a394e11f6e [file] [log] [blame]
Cathal Corbettbd18eab2022-11-15 12:56:16 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TosaTestUtils.hpp"
7
8using namespace armnn;
9using namespace tosa;
10
11void VerifyAvgPool2DIgnoreValue(TosaSerializationBasicBlock* basicBlock,
12 std::vector<std::vector<int32_t>> inputShape,
13 std::vector<std::vector<int32_t>> outputShape,
14 std::vector<std::vector<int32_t>> intermediateShape,
15 const BaseDescriptor& descriptor,
16 DType dataType = DType_FP32)
17{
18 uint32_t numInputs = static_cast<uint32_t>(inputShape.size());
19 uint32_t numOutputs = static_cast<uint32_t>(outputShape.size());
Cathal Corbettbd18eab2022-11-15 12:56:16 +000020
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000021 std::string blockStr = TosaOpToString(Op_AVG_POOL2D) + "_block_";
Cathal Corbettbd18eab2022-11-15 12:56:16 +000022 CHECK(basicBlock->GetName().find(blockStr) != std::string::npos);
23 CHECK(basicBlock->GetInputs().size() == numInputs);
24 CHECK(basicBlock->GetOutputs().size() == numOutputs);
25 CHECK(basicBlock->GetOperators().size() == 2);
26 CHECK(basicBlock->GetTensors().size() == 3);
27
28 //
29 // Verify padding operator first.
30 //
31
32 TosaSerializationOperator* padOp = basicBlock->GetOperators().at(0);
33 uint32_t padOpOutputs = 1;
34 CHECK(padOp->GetInputTensorNames().size() == numInputs);
35 CHECK(padOp->GetOutputTensorNames().size() == padOpOutputs);
36
37 for (uint32_t i = 0; i < numInputs; i++)
38 {
39 std::basic_string<char> blockInputName = basicBlock->GetInputs()[i];
40 std::basic_string<char> operatorInputName = padOp->GetInputTensorNames()[i];
41
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000042 std::string opStr = "input" + std::to_string(i) + "_";
Cathal Corbettbd18eab2022-11-15 12:56:16 +000043
44 CHECK(blockInputName == operatorInputName);
45 CHECK(basicBlock->GetTensorByName(blockInputName));
46 CHECK(blockInputName.find(opStr) != std::string::npos);
47
48 TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorInputName);
49 CHECK(tensor->GetDtype() == dataType);
50 CHECK(tensor->GetData().size() == 0);
51 CHECK(tensor->GetShape() == inputShape[static_cast<unsigned long int>(i)]);
52 }
53
54 for (uint32_t i = 0; i < padOpOutputs; i++)
55 {
56 std::basic_string<char> operatorOutputName = padOp->GetOutputTensorNames()[i];
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000057 std::string opStr = "intermediate" + std::to_string(i) + "_";
Cathal Corbettbd18eab2022-11-15 12:56:16 +000058
59 CHECK(basicBlock->GetTensorByName(operatorOutputName));
60 CHECK(operatorOutputName.find(opStr) != std::string::npos);
61
62 TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorOutputName);
63 CHECK(tensor->GetDtype() == dataType);
64 CHECK(tensor->GetData().size() == 0);
65 CHECK(tensor->GetShape() == intermediateShape[static_cast<unsigned long int>(i)]);
66 }
67
68 CHECK(padOp->GetAttributeType() == Attribute_PadAttribute);
69 CHECK(padOp->GetOp() == Op_PAD);
70
Cathal Corbettb30e6552022-12-07 11:50:50 +000071 VerifyTosaAttribute(descriptor,
72 padOp->GetAttribute(),
73 inputShape[0],
74 outputShape[0],
75 LayerType::Pooling2d);
Cathal Corbettbd18eab2022-11-15 12:56:16 +000076
77 //
78 // Verify average pool operator second.
79 //
80
81 TosaSerializationOperator* poolOp = basicBlock->GetOperators().at(1);
82 uint32_t poolOpInputs = 1;
83 CHECK(poolOp->GetInputTensorNames().size() == poolOpInputs);
84 CHECK(poolOp->GetOutputTensorNames().size() == numOutputs);
85
86 for (uint32_t i = 0; i < poolOpInputs; i++)
87 {
88 std::basic_string<char> operatorInputName = poolOp->GetInputTensorNames()[i];
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000089 std::string opStr = "intermediate" + std::to_string(i) + "_";
Cathal Corbettbd18eab2022-11-15 12:56:16 +000090
91 CHECK(basicBlock->GetTensorByName(operatorInputName));
92 CHECK(operatorInputName.find(opStr) != std::string::npos);
93
94 TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorInputName);
95 CHECK(tensor->GetDtype() == dataType);
96 CHECK(tensor->GetData().size() == 0);
97 CHECK(tensor->GetShape() == intermediateShape[static_cast<unsigned long int>(i)]);
98 }
99
100 for (uint32_t i = 0; i < numOutputs; i++)
101 {
102 std::basic_string<char> blockOutputName = basicBlock->GetOutputs()[i];
103 std::basic_string<char> operatorOutputName = poolOp->GetOutputTensorNames()[i];
104
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000105 std::string opStr = "output" + std::to_string(i) + "_";
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000106
107 CHECK(blockOutputName == operatorOutputName);
108 CHECK(basicBlock->GetTensorByName(blockOutputName));
109 CHECK(blockOutputName.find(opStr) != std::string::npos);
110
111 TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorOutputName);
112 CHECK(tensor->GetDtype() == dataType);
113 CHECK(tensor->GetData().size() == 0);
114 CHECK(tensor->GetShape() == outputShape[static_cast<unsigned long int>(i)]);
115 }
116
117 CHECK(poolOp->GetAttributeType() == Attribute_PoolAttribute);
118 CHECK(poolOp->GetOp() == Op_AVG_POOL2D);
119
Cathal Corbettb30e6552022-12-07 11:50:50 +0000120 VerifyTosaAttribute(descriptor,
121 poolOp->GetAttribute(),
122 inputShape[0],
123 outputShape[0],
124 LayerType::Pooling2d,
125 1);
Cathal Corbettbd18eab2022-11-15 12:56:16 +0000126
127}