blob: 66342223acc80b910dd80c6fb5ae0bdecb036ebc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersSerializeFixture.hpp"
Finn Williams85d36712021-01-26 22:30:06 +00008#include <armnnDeserializer/IDeserializer.hpp>
Nattapat Chaimanowongb3485212019-03-04 12:35:39 +00009
10#include <string>
11
12BOOST_AUTO_TEST_SUITE(Deserializer)
13
14struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
15{
16 explicit StridedSliceFixture(const std::string& inputShape,
17 const std::string& begin,
18 const std::string& end,
19 const std::string& stride,
20 const std::string& beginMask,
21 const std::string& endMask,
22 const std::string& shrinkAxisMask,
23 const std::string& ellipsisMask,
24 const std::string& newAxisMask,
25 const std::string& dataLayout,
26 const std::string& outputShape,
27 const std::string& dataType)
28 {
29 m_JsonString = R"(
30 {
31 inputIds: [0],
32 outputIds: [2],
33 layers: [
34 {
35 layer_type: "InputLayer",
36 layer: {
37 base: {
38 layerBindingId: 0,
39 base: {
40 index: 0,
41 layerName: "InputLayer",
42 layerType: "Input",
43 inputSlots: [{
44 index: 0,
45 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
46 }],
47 outputSlots: [{
48 index: 0,
49 tensorInfo: {
50 dimensions: )" + inputShape + R"(,
51 dataType: )" + dataType + R"(
52 }
53 }]
54 }
55 }
56 }
57 },
58 {
59 layer_type: "StridedSliceLayer",
60 layer: {
61 base: {
62 index: 1,
63 layerName: "StridedSliceLayer",
64 layerType: "StridedSlice",
65 inputSlots: [{
66 index: 0,
67 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
68 }],
69 outputSlots: [{
70 index: 0,
71 tensorInfo: {
72 dimensions: )" + outputShape + R"(,
73 dataType: )" + dataType + R"(
74 }
75 }]
76 },
77 descriptor: {
78 begin: )" + begin + R"(,
79 end: )" + end + R"(,
80 stride: )" + stride + R"(,
81 beginMask: )" + beginMask + R"(,
82 endMask: )" + endMask + R"(,
83 shrinkAxisMask: )" + shrinkAxisMask + R"(,
84 ellipsisMask: )" + ellipsisMask + R"(,
85 newAxisMask: )" + newAxisMask + R"(,
86 dataLayout: )" + dataLayout + R"(,
87 }
88 }
89 },
90 {
91 layer_type: "OutputLayer",
92 layer: {
93 base:{
94 layerBindingId: 2,
95 base: {
96 index: 2,
97 layerName: "OutputLayer",
98 layerType: "Output",
99 inputSlots: [{
100 index: 0,
101 connection: {sourceLayerIndex:1, outputSlotIndex:0 },
102 }],
103 outputSlots: [{
104 index: 0,
105 tensorInfo: {
106 dimensions: )" + outputShape + R"(,
107 dataType: )" + dataType + R"(
108 },
109 }],
110 }
111 }
112 },
113 }
114 ]
115 }
116 )";
117 SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
118 }
119};
120
121struct SimpleStridedSliceFixture : StridedSliceFixture
122{
123 SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
124 "[ 0, 0, 0, 0 ]",
125 "[ 3, 2, 3, 1 ]",
126 "[ 2, 2, 2, 1 ]",
127 "0",
128 "0",
129 "0",
130 "0",
131 "0",
132 "NCHW",
133 "[ 2, 1, 2, 1 ]",
134 "Float32") {}
135};
136
137BOOST_FIXTURE_TEST_CASE(SimpleStridedSliceFloat32, SimpleStridedSliceFixture)
138{
139 RunTest<4, armnn::DataType::Float32>(0,
140 {
141 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
142 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
143 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
144 },
145 {
146 1.0f, 1.0f, 5.0f, 5.0f
147 });
148}
149
150struct StridedSliceMaskFixture : StridedSliceFixture
151{
152 StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
153 "[ 1, 1, 1, 1 ]",
154 "[ 1, 1, 1, 1 ]",
155 "[ 1, 1, 1, 1 ]",
156 "15",
157 "15",
158 "0",
159 "0",
160 "0",
161 "NCHW",
162 "[ 3, 2, 3, 1 ]",
163 "Float32") {}
164};
165
166BOOST_FIXTURE_TEST_CASE(StridedSliceMaskFloat32, StridedSliceMaskFixture)
167{
168 RunTest<4, armnn::DataType::Float32>(0,
169 {
170 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
171 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
172 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
173 },
174 {
175 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
176 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
178 });
179}
180
181BOOST_AUTO_TEST_SUITE_END()