blob: 9619ca2e98b5473a3b000a8657827d9632052175 [file] [log] [blame]
Cathal Corbett839b9322022-11-18 08:52:18 +00001//
Colm Donelan7bcae3c2024-01-22 10:07:14 +00002// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
Cathal Corbett839b9322022-11-18 08:52:18 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "StridedSliceTestHelper.hpp"
7
Cathal Corbett839b9322022-11-18 08:52:18 +00008#include <doctest/doctest.h>
9
10namespace armnnDelegate
11{
12
Colm Donelan7bcae3c2024-01-22 10:07:14 +000013void StridedSlice4DTest()
Cathal Corbett839b9322022-11-18 08:52:18 +000014{
15 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
16 std::vector<int32_t> outputShape { 1, 2, 3, 1 };
17 std::vector<int32_t> beginShape { 4 };
18 std::vector<int32_t> endShape { 4 };
19 std::vector<int32_t> strideShape { 4 };
20
21 std::vector<int32_t> beginData { 1, 0, 0, 0 };
22 std::vector<int32_t> endData { 2, 2, 3, 1 };
23 std::vector<int32_t> strideData { 1, 1, 1, 1 };
24 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
25 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
26 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
27 std::vector<float> outputData { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };
28
29 StridedSliceTestImpl<float>(
Cathal Corbett839b9322022-11-18 08:52:18 +000030 inputData,
31 outputData,
32 beginData,
33 endData,
34 strideData,
35 inputShape,
36 beginShape,
37 endShape,
38 strideShape,
39 outputShape
40 );
41}
42
Colm Donelan7bcae3c2024-01-22 10:07:14 +000043void StridedSlice4DReverseTest()
Cathal Corbett839b9322022-11-18 08:52:18 +000044{
45 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
46 std::vector<int32_t> outputShape { 1, 2, 3, 1 };
47 std::vector<int32_t> beginShape { 4 };
48 std::vector<int32_t> endShape { 4 };
49 std::vector<int32_t> strideShape { 4 };
50
51 std::vector<int32_t> beginData { 1, -1, 0, 0 };
52 std::vector<int32_t> endData { 2, -3, 3, 1 };
53 std::vector<int32_t> strideData { 1, -1, 1, 1 };
54 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
55 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
56 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
57 std::vector<float> outputData { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };
58
59 StridedSliceTestImpl<float>(
Cathal Corbett839b9322022-11-18 08:52:18 +000060 inputData,
61 outputData,
62 beginData,
63 endData,
64 strideData,
65 inputShape,
66 beginShape,
67 endShape,
68 strideShape,
69 outputShape
70 );
71}
72
Colm Donelan7bcae3c2024-01-22 10:07:14 +000073void StridedSliceSimpleStrideTest()
Cathal Corbett839b9322022-11-18 08:52:18 +000074{
75 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
76 std::vector<int32_t> outputShape { 2, 1, 2, 1 };
77 std::vector<int32_t> beginShape { 4 };
78 std::vector<int32_t> endShape { 4 };
79 std::vector<int32_t> strideShape { 4 };
80
81 std::vector<int32_t> beginData { 0, 0, 0, 0 };
82 std::vector<int32_t> endData { 3, 2, 3, 1 };
83 std::vector<int32_t> strideData { 2, 2, 2, 1 };
84 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
85 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
86 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
87 std::vector<float> outputData { 1.0f, 1.0f,
88 5.0f, 5.0f };
89
90 StridedSliceTestImpl<float>(
Cathal Corbett839b9322022-11-18 08:52:18 +000091 inputData,
92 outputData,
93 beginData,
94 endData,
95 strideData,
96 inputShape,
97 beginShape,
98 endShape,
99 strideShape,
100 outputShape
101 );
102}
103
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000104void StridedSliceSimpleRangeMaskTest()
Cathal Corbett839b9322022-11-18 08:52:18 +0000105{
106 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
107 std::vector<int32_t> outputShape { 3, 2, 3, 1 };
108 std::vector<int32_t> beginShape { 4 };
109 std::vector<int32_t> endShape { 4 };
110 std::vector<int32_t> strideShape { 4 };
111
112 std::vector<int32_t> beginData { 1, 1, 1, 1 };
113 std::vector<int32_t> endData { 1, 1, 1, 1 };
114 std::vector<int32_t> strideData { 1, 1, 1, 1 };
115
116 int beginMask = -1;
117 int endMask = -1;
118
119 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
120 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
121 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
122 std::vector<float> outputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
123 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
124 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
125
126 StridedSliceTestImpl<float>(
Cathal Corbett839b9322022-11-18 08:52:18 +0000127 inputData,
128 outputData,
129 beginData,
130 endData,
131 strideData,
132 inputShape,
133 beginShape,
134 endShape,
135 strideShape,
136 outputShape,
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000137 {},
Cathal Corbett839b9322022-11-18 08:52:18 +0000138 beginMask,
139 endMask
140 );
141}
142
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000143TEST_SUITE("StridedSliceTests")
Cathal Corbett839b9322022-11-18 08:52:18 +0000144{
145
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000146TEST_CASE ("StridedSlice_4D_Test")
Cathal Corbett839b9322022-11-18 08:52:18 +0000147{
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000148 StridedSlice4DTest();
Cathal Corbett839b9322022-11-18 08:52:18 +0000149}
150
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000151TEST_CASE ("StridedSlice_4D_Reverse_Test")
Cathal Corbett839b9322022-11-18 08:52:18 +0000152{
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000153 StridedSlice4DReverseTest();
Cathal Corbett839b9322022-11-18 08:52:18 +0000154}
155
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000156TEST_CASE ("StridedSlice_SimpleStride_Test")
Cathal Corbett839b9322022-11-18 08:52:18 +0000157{
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000158 StridedSliceSimpleStrideTest();
Cathal Corbett839b9322022-11-18 08:52:18 +0000159}
160
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000161TEST_CASE ("StridedSlice_SimpleRange_Test")
Cathal Corbett839b9322022-11-18 08:52:18 +0000162{
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000163 StridedSliceSimpleRangeMaskTest();
Cathal Corbett839b9322022-11-18 08:52:18 +0000164}
165
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000166} // StridedSliceTests TestSuite
Cathal Corbett839b9322022-11-18 08:52:18 +0000167
168} // namespace armnnDelegate