blob: d0a6b81de928644ebcb0530eb1c8d0738ce1a5fa [file] [log] [blame]
Teresa Charlin03027232022-05-09 17:27:08 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
//#include "ReshapeTestImpl.hpp"
#include "ElementwiseUnaryTestImpl.hpp"

#include <limits>
#include <vector>
8
9
10template<armnn::DataType ArmnnType, typename T>
11LayerTestResult<T, 2> Sqrt2dTest(
12 armnn::IWorkloadFactory& workloadFactory,
13 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
14 const armnn::ITensorHandleFactory& tensorHandleFactory)
15{
16 const unsigned int inputShape[] = { 2, 2 };
17
18 std::vector<float> inputValues
19 {
20 1.f, 4.f,
21 16.f, 25.f
22 };
23
24 std::vector<float> expectedOutputValues
25 {
26 1.f, 2.f,
27 4.f, 5.f
28 };
29
30 return ElementwiseUnaryTestHelper<2, ArmnnType>(
31 workloadFactory,
32 memoryManager,
33 armnn::UnaryOperation::Sqrt,
34 inputShape,
35 inputValues,
36 inputShape,
37 expectedOutputValues,
38 tensorHandleFactory);
39}
40
41template<armnn::DataType ArmnnType, typename T>
42LayerTestResult<T, 3> Sqrt3dTest(
43 armnn::IWorkloadFactory& workloadFactory,
44 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
45 const armnn::ITensorHandleFactory& tensorHandleFactory)
46{
47 const unsigned int inputShape[] = { 3, 1, 2 };
48
49 std::vector<float> inputValues
50 {
51 1.f, 4.f, 16.f,
52 25.f, 64.f, 100.f
53 };
54
55 std::vector<float> expectedOutputValues
56 {
57 1.f, 2.f, 4.f,
58 5.f, 8.f, 10.f
59 };
60
61 return ElementwiseUnaryTestHelper<3, ArmnnType>(
62 workloadFactory,
63 memoryManager,
64 armnn::UnaryOperation::Sqrt,
65 inputShape,
66 inputValues,
67 inputShape,
68 expectedOutputValues,
69 tensorHandleFactory);
70}
71
72template<armnn::DataType ArmnnType, typename T>
73LayerTestResult<T, 2> SqrtZeroTest(
74 armnn::IWorkloadFactory& workloadFactory,
75 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
76 const armnn::ITensorHandleFactory& tensorHandleFactory)
77{
78 const unsigned int inputShape[] = { 1, 2 };
79
80 std::vector<float> inputValues
81 {
82 0.f, -0.f
83 };
84
85 std::vector<float> expectedOutputValues
86 {
87 0, 0
88 };
89
90 return ElementwiseUnaryTestHelper<2, ArmnnType>(
91 workloadFactory,
92 memoryManager,
93 armnn::UnaryOperation::Sqrt,
94 inputShape,
95 inputValues,
96 inputShape,
97 expectedOutputValues,
98 tensorHandleFactory);
99}
100
101template<armnn::DataType ArmnnType, typename T>
102LayerTestResult<T, 2> SqrtNegativeTest(
103 armnn::IWorkloadFactory& workloadFactory,
104 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
105 const armnn::ITensorHandleFactory& tensorHandleFactory)
106{
107 const unsigned int inputShape[] = { 1, 2 };
108
109 std::vector<float> inputValues
110 {
111 -25.f, -16.f
112 };
113
114 std::vector<float> expectedOutputValues
115 {
116 -NAN, -NAN
117 };
118
119 return ElementwiseUnaryTestHelper<2, ArmnnType>(
120 workloadFactory,
121 memoryManager,
122 armnn::UnaryOperation::Sqrt,
123 inputShape,
124 inputValues,
125 inputShape,
126 expectedOutputValues,
127 tensorHandleFactory);
128}
129
//
// Explicit template instantiations
//
133
134template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
135Sqrt2dTest<armnn::DataType::Float32>(
136 armnn::IWorkloadFactory& workloadFactory,
137 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
138 const armnn::ITensorHandleFactory& tensorHandleFactory);
139
140template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
141Sqrt2dTest<armnn::DataType::Float16>(
142 armnn::IWorkloadFactory& workloadFactory,
143 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
144 const armnn::ITensorHandleFactory& tensorHandleFactory);
145
146template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
147Sqrt2dTest<armnn::DataType::QAsymmS8>(
148 armnn::IWorkloadFactory& workloadFactory,
149 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
150 const armnn::ITensorHandleFactory& tensorHandleFactory);
151
152template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
153Sqrt2dTest<armnn::DataType::QAsymmU8>(
154 armnn::IWorkloadFactory& workloadFactory,
155 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
156 const armnn::ITensorHandleFactory& tensorHandleFactory);
157
158template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
159Sqrt2dTest<armnn::DataType::QSymmS16>(
160 armnn::IWorkloadFactory& workloadFactory,
161 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
162 const armnn::ITensorHandleFactory& tensorHandleFactory);
163
164template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
165Sqrt3dTest<armnn::DataType::Float32>(
166 armnn::IWorkloadFactory& workloadFactory,
167 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
168 const armnn::ITensorHandleFactory& tensorHandleFactory);
169
170template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
171Sqrt3dTest<armnn::DataType::Float16>(
172 armnn::IWorkloadFactory& workloadFactory,
173 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
174 const armnn::ITensorHandleFactory& tensorHandleFactory);
175
176template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
177Sqrt3dTest<armnn::DataType::QAsymmS8>(
178 armnn::IWorkloadFactory& workloadFactory,
179 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
180 const armnn::ITensorHandleFactory& tensorHandleFactory);
181
182template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
183Sqrt3dTest<armnn::DataType::QAsymmU8>(
184 armnn::IWorkloadFactory& workloadFactory,
185 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
186 const armnn::ITensorHandleFactory& tensorHandleFactory);
187
188template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
189Sqrt3dTest<armnn::DataType::QSymmS16>(
190 armnn::IWorkloadFactory& workloadFactory,
191 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
192 const armnn::ITensorHandleFactory& tensorHandleFactory);
193
194template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
195SqrtZeroTest<armnn::DataType::Float32>(
196 armnn::IWorkloadFactory& workloadFactory,
197 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
198 const armnn::ITensorHandleFactory& tensorHandleFactory);
199
200template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
201SqrtNegativeTest<armnn::DataType::Float32>(
202 armnn::IWorkloadFactory& workloadFactory,
203 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
204 const armnn::ITensorHandleFactory& tensorHandleFactory);