blob: 9f70b1c2af7ef68a31d07bebe20981a1c5e561f2 [file] [log] [blame]
John Richardsondd715f22017-09-18 16:10:48 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
John Richardsondd715f22017-09-18 16:10:48 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
Manuel Bottini79fa9a22019-02-22 17:54:22 +000021 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
John Richardsondd715f22017-09-18 16:10:48 +010022 * SOFTWARE.
23 */
#include "PixelWiseMultiplication.h"

#include "tests/validation/Helpers.h"

#include <cstdint>
27
John Richardsondd715f22017-09-18 16:10:48 +010028namespace arm_compute
29{
30namespace test
31{
32namespace validation
33{
34namespace reference
35{
36template <class T>
37struct is_floating_point
38 : std::integral_constant < bool,
39 std::is_same<float, typename std::remove_cv<T>::type>::value || std::is_same<half_float::half, typename std::remove_cv<T>::type>::value
40 || std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value >
41{
42};
43
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000044namespace
45{
46/** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2.
47 *
Vidhya Sudhan Loganathan0fc25452018-06-18 14:40:56 +010048 * @param[in] src1 An input value. Data types supported: U8/S16/F16/F32.
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000049 * @param[in] src2 An input value. Data types supported: same as @p src1.
50 * @param[in] scale Scale to apply after multiplication.
Vidhya Sudhan Loganathan0fc25452018-06-18 14:40:56 +010051 * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000052 * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
53 * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
54 */
Michele Di Giorgio9428a182020-03-30 14:10:20 +010055template <typename T1, typename T2, typename T3>
56T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000057{
Michele Di Giorgio9428a182020-03-30 14:10:20 +010058 using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type;
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000059
60 const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
61
Michele Di Giorgio9428a182020-03-30 14:10:20 +010062 if(is_floating_point<T3>::value)
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000063 {
Michele Di Giorgio9428a182020-03-30 14:10:20 +010064 const auto result = static_cast<T3>(val);
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000065
66 return result;
67 }
68 else
69 {
70 double rounded_val = 0;
71 switch(rounding_policy)
72 {
73 case(RoundingPolicy::TO_ZERO):
74 rounded_val = support::cpp11::trunc(val);
75 break;
76 case(RoundingPolicy::TO_NEAREST_UP):
77 rounded_val = round_half_up(val);
78 break;
79 case(RoundingPolicy::TO_NEAREST_EVEN):
80 rounded_val = round_half_even(val);
81 break;
82 default:
83 ARM_COMPUTE_ERROR("Unsupported rounding policy");
84 }
85
Michele Di Giorgio9428a182020-03-30 14:10:20 +010086 const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val);
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000087
88 return result;
89 }
90}
91
92template <size_t dim>
93struct BroadcastUnroll
94{
Michele Di Giorgio9428a182020-03-30 14:10:20 +010095 template <typename T1, typename T2, typename T3>
96 static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000097 float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
98 Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
99 {
100 const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
101 const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
102
103 id_src1.set(dim - 1, 0);
104 id_src2.set(dim - 1, 0);
105 id_dst.set(dim - 1, 0);
106
107 for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
108 {
109 BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
110
111 id_src1[dim - 1] += !src1_is_broadcast;
112 id_src2[dim - 1] += !src2_is_broadcast;
113 }
114 }
115};
116
117template <>
118struct BroadcastUnroll<0>
119{
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100120 template <typename T1, typename T2, typename T3>
121 static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000122 float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
123 Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
124 {
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100125 dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000126 }
127};
128} // namespace
129
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100130template <typename T1, typename T2, typename T3>
131SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
132 DataType dt_out, const QuantizationInfo &qout)
John Richardsondd715f22017-09-18 16:10:48 +0100133{
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100134 ARM_COMPUTE_UNUSED(qout);
135
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100136 SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out);
John Richardsondd715f22017-09-18 16:10:48 +0100137
138 if(scale < 0)
139 {
140 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
141 }
142
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100143 Coordinates id_src1{};
144 Coordinates id_src2{};
145 Coordinates id_dst{};
John Richardsondd715f22017-09-18 16:10:48 +0100146
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000147 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
John Richardsondd715f22017-09-18 16:10:48 +0100148
149 return dst;
150}
151
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100152template <>
153SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100154 DataType dt_out, const QuantizationInfo &qout)
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100155{
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100156 SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100157
158 if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
159 {
160 SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
161 SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100162 SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
Michele Di Giorgio4aff98f2019-08-28 16:27:26 +0100163 dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout);
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100164 }
165 else
166 {
167 if(scale < 0)
168 {
169 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
170 }
171
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100172 Coordinates id_src1{};
173 Coordinates id_src2{};
174 Coordinates id_dst{};
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100175 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
176 }
177 return dst;
178}
Michele Di Giorgiod8a468f2019-06-19 15:34:41 +0100179
180template <>
Sheri Zhangfcf6f4e2020-06-25 20:01:00 +0100181SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
182 DataType dt_out, const QuantizationInfo &qout)
183{
184 SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
185
186 if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
187 {
188 SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
189 SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
190 SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
191 dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
192 }
193 else
194 {
195 if(scale < 0)
196 {
197 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
198 }
199
200 Coordinates id_src1{};
201 Coordinates id_src2{};
202 Coordinates id_dst{};
203 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
204 }
205 return dst;
206}
207
208template <>
Pablo Tello52ea9c22019-12-10 11:28:53 +0000209SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100210 DataType dt_out, const QuantizationInfo &qout)
Pablo Tello52ea9c22019-12-10 11:28:53 +0000211{
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100212 SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
Pablo Tello52ea9c22019-12-10 11:28:53 +0000213
214 if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED)
215 {
216 SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
217 SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100218 SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
Pablo Tello52ea9c22019-12-10 11:28:53 +0000219 dst = convert_to_asymmetric<int8_t>(dst_tmp, qout);
220 }
221 else
222 {
223 if(scale < 0)
224 {
225 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
226 }
227
228 Coordinates id_src1{};
229 Coordinates id_src2{};
230 Coordinates id_dst{};
231 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
232 }
233 return dst;
234}
235
236template <>
Michele Di Giorgiod8a468f2019-06-19 15:34:41 +0100237SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100238 DataType dt_out, const QuantizationInfo &qout)
Michele Di Giorgiod8a468f2019-06-19 15:34:41 +0100239{
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100240 SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
Michele Di Giorgiod8a468f2019-06-19 15:34:41 +0100241
242 if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16)
243 {
244 SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
245 SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100246 SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
Michele Di Giorgiod8a468f2019-06-19 15:34:41 +0100247 dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
248 }
249 else
250 {
251 if(scale < 0)
252 {
253 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
254 }
255
256 Coordinates id_src1{};
257 Coordinates id_src2{};
258 Coordinates id_dst{};
259 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
260 }
261 return dst;
262}
John Richardsondd715f22017-09-18 16:10:48 +0100263// *INDENT-OFF*
264// clang-format off
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100265template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
266template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
267template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
268template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
John Richardsondd715f22017-09-18 16:10:48 +0100269// clang-format on
270// *INDENT-ON*
271} // namespace reference
272} // namespace validation
273} // namespace test
274} // namespace arm_compute