blob: 9927b75032668581c14e06fce3e8f3ed2885a8a6 [file] [log] [blame]
John Richardsondd715f22017-09-18 16:10:48 +01001/*
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +00002 * Copyright (c) 2017-2018 ARM Limited.
John Richardsondd715f22017-09-18 16:10:48 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_PIXEL_WISE_MULTIPLICATION_FIXTURE
25#define ARM_COMPUTE_TEST_PIXEL_WISE_MULTIPLICATION_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
29#include "tests/AssetsLibrary.h"
30#include "tests/Globals.h"
31#include "tests/IAccessor.h"
32#include "tests/framework/Asserts.h"
33#include "tests/framework/Fixture.h"
Georgios Pinitas5a7e7762017-12-01 16:27:29 +000034#include "tests/validation/reference/PixelWiseMultiplication.h"
John Richardsondd715f22017-09-18 16:10:48 +010035
36namespace arm_compute
37{
38namespace test
39{
40namespace validation
41{
42template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +010043class PixelWiseMultiplicationGenericValidationFixture : public framework::Fixture
John Richardsondd715f22017-09-18 16:10:48 +010044{
45public:
46 template <typename...>
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000047 void setup(const TensorShape &shape0,
48 const TensorShape &shape1,
49 DataType dt_in1,
50 DataType dt_in2,
51 float scale,
52 ConvertPolicy convert_policy,
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +010053 RoundingPolicy rounding_policy,
54 QuantizationInfo qinfo0,
55 QuantizationInfo qinfo1,
56 QuantizationInfo qinfo_out)
John Richardsondd715f22017-09-18 16:10:48 +010057 {
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +010058 _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
59 _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
John Richardsondd715f22017-09-18 16:10:48 +010060 }
61
62protected:
63 template <typename U>
64 void fill(U &&tensor, unsigned int seed_offset)
65 {
66 library->fill_tensor_uniform(tensor, seed_offset);
67 }
68
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000069 TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +010070 float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
71 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
John Richardsondd715f22017-09-18 16:10:48 +010072 {
73 // Create tensors
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +010074 TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0);
75 TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1);
76 TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2, 1, qinfo_out);
John Richardsondd715f22017-09-18 16:10:48 +010077
78 // Create and configure function
79 FunctionType multiply;
80 multiply.configure(&src1, &src2, &dst, scale, convert_policy, rounding_policy);
81
82 ARM_COMPUTE_EXPECT(src1.info()->is_resizable(), framework::LogLevel::ERRORS);
83 ARM_COMPUTE_EXPECT(src2.info()->is_resizable(), framework::LogLevel::ERRORS);
84 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
85
86 // Allocate tensors
87 src1.allocator()->allocate();
88 src2.allocator()->allocate();
89 dst.allocator()->allocate();
90
91 ARM_COMPUTE_EXPECT(!src1.info()->is_resizable(), framework::LogLevel::ERRORS);
92 ARM_COMPUTE_EXPECT(!src2.info()->is_resizable(), framework::LogLevel::ERRORS);
93 ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
94
95 // Fill tensors
96 fill(AccessorType(src1), 0);
97 fill(AccessorType(src2), 1);
98
99 // Compute function
100 multiply.run();
101
102 return dst;
103 }
104
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000105 SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100106 float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
107 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
John Richardsondd715f22017-09-18 16:10:48 +0100108 {
109 // Create reference
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100110 SimpleTensor<T1> src1{ shape0, dt_in1, 1, qinfo0 };
111 SimpleTensor<T2> src2{ shape1, dt_in2, 1, qinfo1 };
John Richardsondd715f22017-09-18 16:10:48 +0100112
113 // Fill reference
114 fill(src1, 0);
115 fill(src2, 1);
116
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100117 return reference::pixel_wise_multiplication<T1, T2>(src1, src2, scale, convert_policy, rounding_policy, qinfo_out);
John Richardsondd715f22017-09-18 16:10:48 +0100118 }
119
120 TensorType _target{};
121 SimpleTensor<T2> _reference{};
122};
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000123
124template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100125class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000126{
127public:
128 template <typename...>
129 void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
130 {
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100131 PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
132 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000133 }
134};
135
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100136template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
137class PixelWiseMultiplicationBroadcastValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
138{
139public:
140 template <typename...>
141 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
142 {
143 PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
144 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
145 }
146};
147
148template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
149class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
150{
151public:
152 template <typename...>
153 void setup(const TensorShape &shape, DataType dt, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
154 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
155 {
156 PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt, dt, scale, convert_policy, rounding_policy,
157 qinfo0, qinfo1, qinfo_out);
158 }
159};
John Richardsondd715f22017-09-18 16:10:48 +0100160} // namespace validation
161} // namespace test
162} // namespace arm_compute
163#endif /* ARM_COMPUTE_TEST_PIXEL_WISE_MULTIPLICATION_FIXTURE */