blob: e3ef71e84737b26107ea5e5c1eccd1c162144c8e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Nattapat Chaimanowongb3485212019-03-04 12:35:39 +00006#include "ParserFlatbuffersSerializeFixture.hpp"
Finn Williams85d36712021-01-26 22:30:06 +00007#include <armnnDeserializer/IDeserializer.hpp>
Nattapat Chaimanowongb3485212019-03-04 12:35:39 +00008
9#include <string>
10
Sadik Armagan1625efc2021-06-10 18:24:34 +010011TEST_SUITE("Deserializer_StridedSlice")
12{
Nattapat Chaimanowongb3485212019-03-04 12:35:39 +000013struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
14{
15 explicit StridedSliceFixture(const std::string& inputShape,
16 const std::string& begin,
17 const std::string& end,
18 const std::string& stride,
19 const std::string& beginMask,
20 const std::string& endMask,
21 const std::string& shrinkAxisMask,
22 const std::string& ellipsisMask,
23 const std::string& newAxisMask,
24 const std::string& dataLayout,
25 const std::string& outputShape,
26 const std::string& dataType)
27 {
28 m_JsonString = R"(
29 {
30 inputIds: [0],
31 outputIds: [2],
32 layers: [
33 {
34 layer_type: "InputLayer",
35 layer: {
36 base: {
37 layerBindingId: 0,
38 base: {
39 index: 0,
40 layerName: "InputLayer",
41 layerType: "Input",
42 inputSlots: [{
43 index: 0,
44 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
45 }],
46 outputSlots: [{
47 index: 0,
48 tensorInfo: {
49 dimensions: )" + inputShape + R"(,
50 dataType: )" + dataType + R"(
51 }
52 }]
53 }
54 }
55 }
56 },
57 {
58 layer_type: "StridedSliceLayer",
59 layer: {
60 base: {
61 index: 1,
62 layerName: "StridedSliceLayer",
63 layerType: "StridedSlice",
64 inputSlots: [{
65 index: 0,
66 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
67 }],
68 outputSlots: [{
69 index: 0,
70 tensorInfo: {
71 dimensions: )" + outputShape + R"(,
72 dataType: )" + dataType + R"(
73 }
74 }]
75 },
76 descriptor: {
77 begin: )" + begin + R"(,
78 end: )" + end + R"(,
79 stride: )" + stride + R"(,
80 beginMask: )" + beginMask + R"(,
81 endMask: )" + endMask + R"(,
82 shrinkAxisMask: )" + shrinkAxisMask + R"(,
83 ellipsisMask: )" + ellipsisMask + R"(,
84 newAxisMask: )" + newAxisMask + R"(,
85 dataLayout: )" + dataLayout + R"(,
86 }
87 }
88 },
89 {
90 layer_type: "OutputLayer",
91 layer: {
92 base:{
93 layerBindingId: 2,
94 base: {
95 index: 2,
96 layerName: "OutputLayer",
97 layerType: "Output",
98 inputSlots: [{
99 index: 0,
100 connection: {sourceLayerIndex:1, outputSlotIndex:0 },
101 }],
102 outputSlots: [{
103 index: 0,
104 tensorInfo: {
105 dimensions: )" + outputShape + R"(,
106 dataType: )" + dataType + R"(
107 },
108 }],
109 }
110 }
111 },
112 }
113 ]
114 }
115 )";
116 SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
117 }
118};
119
120struct SimpleStridedSliceFixture : StridedSliceFixture
121{
122 SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
123 "[ 0, 0, 0, 0 ]",
124 "[ 3, 2, 3, 1 ]",
125 "[ 2, 2, 2, 1 ]",
126 "0",
127 "0",
128 "0",
129 "0",
130 "0",
131 "NCHW",
132 "[ 2, 1, 2, 1 ]",
133 "Float32") {}
134};
135
Sadik Armagan1625efc2021-06-10 18:24:34 +0100136TEST_CASE_FIXTURE(SimpleStridedSliceFixture, "SimpleStridedSliceFloat32")
Nattapat Chaimanowongb3485212019-03-04 12:35:39 +0000137{
138 RunTest<4, armnn::DataType::Float32>(0,
139 {
140 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
141 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
142 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
143 },
144 {
145 1.0f, 1.0f, 5.0f, 5.0f
146 });
147}
148
149struct StridedSliceMaskFixture : StridedSliceFixture
150{
151 StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
152 "[ 1, 1, 1, 1 ]",
153 "[ 1, 1, 1, 1 ]",
154 "[ 1, 1, 1, 1 ]",
155 "15",
156 "15",
157 "0",
158 "0",
159 "0",
160 "NCHW",
161 "[ 3, 2, 3, 1 ]",
162 "Float32") {}
163};
164
Sadik Armagan1625efc2021-06-10 18:24:34 +0100165TEST_CASE_FIXTURE(StridedSliceMaskFixture, "StridedSliceMaskFloat32")
Nattapat Chaimanowongb3485212019-03-04 12:35:39 +0000166{
167 RunTest<4, armnn::DataType::Float32>(0,
168 {
169 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
170 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
171 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
172 },
173 {
174 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
175 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
176 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
177 });
178}
179
Sadik Armagan1625efc2021-06-10 18:24:34 +0100180}