blob: ca423835dc136b96dd7579ab0e82a4a57eb73535 [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
// NOTE(review): this file defines the Rsqrt tests but includes ReshapeTestImpl.hpp;
// confirm whether the Rsqrt tests' own header was intended here.
#include "ReshapeTestImpl.hpp"
#include "ElementwiseUnaryTestImpl.hpp"

#include <limits>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01008
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01009
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010010template<armnn::DataType ArmnnType, typename T>
11LayerTestResult<T, 2> Rsqrt2dTest(
12 armnn::IWorkloadFactory& workloadFactory,
13 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
14{
josh minor4a3c6102020-01-06 16:40:46 -060015 const unsigned int inputShape[] = { 2, 2 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010016
17 std::vector<float> inputValues
18 {
19 1.f, 4.f,
20 16.f, 25.f
21 };
22
23 std::vector<float> expectedOutputValues
24 {
25 1.f, 0.5f,
26 0.25f, 0.2f
27 };
28
josh minor4a3c6102020-01-06 16:40:46 -060029 return ElementwiseUnaryTestHelper<2, ArmnnType>(
30 workloadFactory,
31 memoryManager,
32 armnn::UnaryOperation::Rsqrt,
33 inputShape,
34 inputValues,
35 inputShape,
36 expectedOutputValues);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010037}
38
39template<armnn::DataType ArmnnType, typename T>
40LayerTestResult<T, 3> Rsqrt3dTest(
41 armnn::IWorkloadFactory& workloadFactory,
42 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
43{
josh minor4a3c6102020-01-06 16:40:46 -060044 const unsigned int inputShape[] = { 3, 1, 2 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010045
46 std::vector<float> inputValues
47 {
48 1.f, 4.f, 16.f,
49 25.f, 64.f, 100.f
50 };
51
52 std::vector<float> expectedOutputValues
53 {
54 1.f, 0.5f, 0.25f,
55 0.2f, 0.125f, 0.1f
56 };
57
josh minor4a3c6102020-01-06 16:40:46 -060058 return ElementwiseUnaryTestHelper<3, ArmnnType>(
59 workloadFactory,
60 memoryManager,
61 armnn::UnaryOperation::Rsqrt,
62 inputShape,
63 inputValues,
64 inputShape,
65 expectedOutputValues);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010066}
67
68template<armnn::DataType ArmnnType, typename T>
69LayerTestResult<T, 2> RsqrtZeroTest(
70 armnn::IWorkloadFactory& workloadFactory,
71 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
72{
josh minor4a3c6102020-01-06 16:40:46 -060073 const unsigned int inputShape[] = { 1, 2 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010074
75 std::vector<float> inputValues
76 {
77 0.f, -0.f
78 };
79
80 std::vector<float> expectedOutputValues
81 {
82 INFINITY, -INFINITY
83 };
84
josh minor4a3c6102020-01-06 16:40:46 -060085 return ElementwiseUnaryTestHelper<2, ArmnnType>(
86 workloadFactory,
87 memoryManager,
88 armnn::UnaryOperation::Rsqrt,
89 inputShape,
90 inputValues,
91 inputShape,
92 expectedOutputValues);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010093}
94
95template<armnn::DataType ArmnnType, typename T>
96LayerTestResult<T, 2> RsqrtNegativeTest(
97 armnn::IWorkloadFactory& workloadFactory,
98 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
99{
josh minor4a3c6102020-01-06 16:40:46 -0600100 const unsigned int inputShape[] = { 1, 2 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100101
102 std::vector<float> inputValues
103 {
104 -25.f, -16.f
105 };
106
107 std::vector<float> expectedOutputValues
108 {
109 -NAN, -NAN
110 };
111
josh minor4a3c6102020-01-06 16:40:46 -0600112 return ElementwiseUnaryTestHelper<2, ArmnnType>(
113 workloadFactory,
114 memoryManager,
115 armnn::UnaryOperation::Rsqrt,
116 inputShape,
117 inputValues,
118 inputShape,
119 expectedOutputValues);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100120}
121
//
// Explicit template instantiations
//
125
126template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
127Rsqrt2dTest<armnn::DataType::Float32>(
128 armnn::IWorkloadFactory& workloadFactory,
129 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
130
Matthew Jackson9bff1442019-09-12 09:08:23 +0100131template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
132Rsqrt2dTest<armnn::DataType::Float16>(
133 armnn::IWorkloadFactory& workloadFactory,
134 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
135
Derek Lambertif90c56d2020-01-10 17:14:08 +0000136template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
137Rsqrt2dTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100138 armnn::IWorkloadFactory& workloadFactory,
139 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
140
Derek Lambertif90c56d2020-01-10 17:14:08 +0000141template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
142Rsqrt2dTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100143 armnn::IWorkloadFactory& workloadFactory,
144 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
145
146template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
147Rsqrt3dTest<armnn::DataType::Float32>(
148 armnn::IWorkloadFactory& workloadFactory,
149 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
150
Matthew Jackson9bff1442019-09-12 09:08:23 +0100151template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
152Rsqrt3dTest<armnn::DataType::Float16>(
153 armnn::IWorkloadFactory& workloadFactory,
154 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
155
Derek Lambertif90c56d2020-01-10 17:14:08 +0000156template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
157Rsqrt3dTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100158 armnn::IWorkloadFactory& workloadFactory,
159 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
160
Derek Lambertif90c56d2020-01-10 17:14:08 +0000161template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
162Rsqrt3dTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100163 armnn::IWorkloadFactory& workloadFactory,
164 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
165
166template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
167RsqrtZeroTest<armnn::DataType::Float32>(
168 armnn::IWorkloadFactory& workloadFactory,
169 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
170
171template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
172RsqrtNegativeTest<armnn::DataType::Float32>(
173 armnn::IWorkloadFactory& workloadFactory,
174 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);