blob: 62fdc5cd6089a4c8079bad98fa6a685fe7cdf72b [file] [log] [blame]
Tamas Nyirid998a1c2021-11-05 14:55:33 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ParserFlatbuffersSerializeFixture.hpp"
7#include <armnnDeserializer/IDeserializer.hpp>
8
9#include <string>
10
11TEST_SUITE("Deserializer_Pooling3d")
12{
/// Builds a three-layer serialized network, InputLayer (index 0) -> Pooling3dLayer
/// (index 1) -> OutputLayer (index 2), from a JSON description, then calls
/// SetupSingleInputSingleOutput on it.
///
/// The Pooling3d descriptor is fixed:
///   - a 2x2x2 window (poolWidth/poolHeight/poolDepth = 2);
///   - stride 2 along X, Y and Z;
///   - zero padding on every side;
///   - "Floor" output-shape rounding and the Exclude padding method.
/// Only the tensor shapes, data type, data layout and pooling algorithm vary between
/// the derived fixtures.
///
/// NOTE(review): the InputLayer entry declares an input slot connected to
/// sourceLayerIndex 0, i.e. to itself. Presumably the deserializer ignores input
/// slots on input layers; confirm before relying on it.
struct Pooling3dFixture : public ParserFlatbuffersSerializeFixture
{
    /// @param inputShape       JSON array text used verbatim as the input tensor's
    ///                         `dimensions`, e.g. "[ 1, 2, 2, 2, 1 ]".
    /// @param outputShape      JSON array text used as `dimensions` for both the
    ///                         Pooling3d output tensor and the OutputLayer tensor.
    /// @param dataType         Data-type token (e.g. "Float32") applied to every tensor.
    /// @param dataLayout       Layout token for the descriptor (e.g. "NDHWC", "NCDHW").
    /// @param poolingAlgorithm Token stored in the descriptor's `poolType`
    ///                         (e.g. "Average", "Max", "L2").
    explicit Pooling3dFixture(const std::string &inputShape,
                              const std::string &outputShape,
                              const std::string &dataType,
                              const std::string &dataLayout,
                              const std::string &poolingAlgorithm)
    {
        // The parameters are spliced into the JSON unquoted and unescaped, so callers
        // must pass well-formed JSON fragments.
        m_JsonString = R"(
    {
        inputIds: [0],
        outputIds: [2],
        layers: [
            {
                layer_type: "InputLayer",
                layer: {
                    base: {
                        layerBindingId: 0,
                        base: {
                            index: 0,
                            layerName: "InputLayer",
                            layerType: "Input",
                            inputSlots: [{
                                index: 0,
                                connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                            }],
                            outputSlots: [ {
                                index: 0,
                                tensorInfo: {
                                    dimensions: )" + inputShape + R"(,
                                    dataType: )" + dataType + R"(
                            }}]
                        }
                    }}},
            {
                layer_type: "Pooling3dLayer",
                layer: {
                    base: {
                        index: 1,
                        layerName: "Pooling3dLayer",
                        layerType: "Pooling3d",
                        inputSlots: [{
                            index: 0,
                            connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                        }],
                        outputSlots: [ {
                            index: 0,
                            tensorInfo: {
                                dimensions: )" + outputShape + R"(,
                                dataType: )" + dataType + R"(

                            }}]},
                    descriptor: {
                        poolType: )" + poolingAlgorithm + R"(,
                        outputShapeRounding: "Floor",
                        paddingMethod: Exclude,
                        dataLayout: )" + dataLayout + R"(,
                        padLeft: 0,
                        padRight: 0,
                        padTop: 0,
                        padBottom: 0,
                        padFront: 0,
                        padBack: 0,
                        poolWidth: 2,
                        poolHeight: 2,
                        poolDepth: 2,
                        strideX: 2,
                        strideY: 2,
                        strideZ: 2
                    }
                }},
            {
                layer_type: "OutputLayer",
                layer: {
                    base:{
                        layerBindingId: 0,
                        base: {
                            index: 2,
                            layerName: "OutputLayer",
                            layerType: "Output",
                            inputSlots: [{
                                index: 0,
                                connection: {sourceLayerIndex:1, outputSlotIndex:0 },
                            }],
                            outputSlots: [ {
                                index: 0,
                                tensorInfo: {
                                    dimensions: )" + outputShape + R"(,
                                    dataType: )" + dataType + R"(
                                },
                            }],
                        }}},
            }]
    }
    )";
        // Helper from ParserFlatbuffersSerializeFixture. Presumably it parses
        // m_JsonString and binds the named input/output layers; confirm against the
        // fixture header.
        SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
    }
};
111
112struct SimpleAvgPooling3dFixture : Pooling3dFixture
113{
114 SimpleAvgPooling3dFixture() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]",
115 "[ 1, 1, 1, 1, 1 ]",
116 "Float32", "NDHWC", "Average") {}
117};
118
119struct SimpleAvgPooling3dFixture2 : Pooling3dFixture
120{
121 SimpleAvgPooling3dFixture2() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]",
122 "[ 1, 1, 1, 1, 1 ]",
123 "QuantisedAsymm8", "NDHWC", "Average") {}
124};
125
126struct SimpleMaxPooling3dFixture : Pooling3dFixture
127{
128 SimpleMaxPooling3dFixture() : Pooling3dFixture("[ 1, 1, 2, 2, 2 ]",
129 "[ 1, 1, 1, 1, 1 ]",
130 "Float32", "NCDHW", "Max") {}
131};
132
133struct SimpleMaxPooling3dFixture2 : Pooling3dFixture
134{
135 SimpleMaxPooling3dFixture2() : Pooling3dFixture("[ 1, 1, 2, 2, 2 ]",
136 "[ 1, 1, 1, 1, 1 ]",
137 "QuantisedAsymm8", "NCDHW", "Max") {}
138};
139
140struct SimpleL2Pooling3dFixture : Pooling3dFixture
141{
142 SimpleL2Pooling3dFixture() : Pooling3dFixture("[ 1, 2, 2, 2, 1 ]",
143 "[ 1, 1, 1, 1, 1 ]",
144 "Float32", "NDHWC", "L2") {}
145};
146
147TEST_CASE_FIXTURE(SimpleAvgPooling3dFixture, "Pooling3dFloat32Avg")
148{
149 RunTest<5, armnn::DataType::Float32>(0, { 2, 3, 5, 2, 3, 2, 3, 4 }, { 3 });
150}
151
152TEST_CASE_FIXTURE(SimpleAvgPooling3dFixture2, "Pooling3dQuantisedAsymm8Avg")
153{
154 RunTest<5, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80, 50, 60, 70, 30 },{ 50 });
155}
156
157TEST_CASE_FIXTURE(SimpleMaxPooling3dFixture, "Pooling3dFloat32Max")
158{
159 RunTest<5, armnn::DataType::Float32>(0, { 2, 5, 5, 2, 1, 3, 4, 0 }, { 5 });
160}
161
162TEST_CASE_FIXTURE(SimpleMaxPooling3dFixture2, "Pooling3dQuantisedAsymm8Max")
163{
164 RunTest<5, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80, 10, 40, 0, 70 },{ 80 });
165}
166
167TEST_CASE_FIXTURE(SimpleL2Pooling3dFixture, "Pooling3dFloat32L2")
168{
169 RunTest<5, armnn::DataType::Float32>(0, { 2, 3, 5, 2, 4, 1, 1, 3 }, { 2.93683503112f });
170}
171
172}
173