blob: d9895e5ed939a1cb30ef2db8ae3336756b550324 [file] [log] [blame]
John Richardsondd715f22017-09-18 16:10:48 +01001/*
Manuel Bottini79fa9a22019-02-22 17:54:22 +00002 * Copyright (c) 2017-2019 ARM Limited.
John Richardsondd715f22017-09-18 16:10:48 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
Manuel Bottini79fa9a22019-02-22 17:54:22 +000021 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
John Richardsondd715f22017-09-18 16:10:48 +010022 * SOFTWARE.
23 */
24#include "PixelWiseMultiplication.h"
25
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +010026#include "tests/validation/Helpers.h"
27
John Richardsondd715f22017-09-18 16:10:48 +010028namespace arm_compute
29{
30namespace test
31{
32namespace validation
33{
34namespace reference
35{
36template <class T>
37struct is_floating_point
38 : std::integral_constant < bool,
39 std::is_same<float, typename std::remove_cv<T>::type>::value || std::is_same<half_float::half, typename std::remove_cv<T>::type>::value
40 || std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value >
41{
42};
43
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000044namespace
45{
46/** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2.
47 *
Vidhya Sudhan Loganathan0fc25452018-06-18 14:40:56 +010048 * @param[in] src1 An input value. Data types supported: U8/S16/F16/F32.
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000049 * @param[in] src2 An input value. Data types supported: same as @p src1.
50 * @param[in] scale Scale to apply after multiplication.
Vidhya Sudhan Loganathan0fc25452018-06-18 14:40:56 +010051 * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +000052 * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
53 * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
54 */
55template <typename T1, typename T2>
56T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
57{
58 using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
59
60 const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
61
62 if(is_floating_point<T2>::value)
63 {
64 const auto result = static_cast<T2>(val);
65
66 return result;
67 }
68 else
69 {
70 double rounded_val = 0;
71 switch(rounding_policy)
72 {
73 case(RoundingPolicy::TO_ZERO):
74 rounded_val = support::cpp11::trunc(val);
75 break;
76 case(RoundingPolicy::TO_NEAREST_UP):
77 rounded_val = round_half_up(val);
78 break;
79 case(RoundingPolicy::TO_NEAREST_EVEN):
80 rounded_val = round_half_even(val);
81 break;
82 default:
83 ARM_COMPUTE_ERROR("Unsupported rounding policy");
84 }
85
86 const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val);
87
88 return result;
89 }
90}
91
92template <size_t dim>
93struct BroadcastUnroll
94{
95 template <typename T1, typename T2>
96 static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
97 float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
98 Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
99 {
100 const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
101 const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
102
103 id_src1.set(dim - 1, 0);
104 id_src2.set(dim - 1, 0);
105 id_dst.set(dim - 1, 0);
106
107 for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
108 {
109 BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
110
111 id_src1[dim - 1] += !src1_is_broadcast;
112 id_src2[dim - 1] += !src2_is_broadcast;
113 }
114 }
115};
116
117template <>
118struct BroadcastUnroll<0>
119{
120 template <typename T1, typename T2>
121 static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
122 float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
123 Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
124 {
125 dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
126 }
127};
128} // namespace
129
John Richardsondd715f22017-09-18 16:10:48 +0100130template <typename T1, typename T2>
Michalis Spyrou6260e192019-06-06 13:47:38 +0100131SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
132 const QuantizationInfo &qout)
John Richardsondd715f22017-09-18 16:10:48 +0100133{
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100134 ARM_COMPUTE_UNUSED(qout);
135
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000136 SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type());
John Richardsondd715f22017-09-18 16:10:48 +0100137
138 if(scale < 0)
139 {
140 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
141 }
142
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100143 Coordinates id_src1{};
144 Coordinates id_src2{};
145 Coordinates id_dst{};
John Richardsondd715f22017-09-18 16:10:48 +0100146
Michele Di Giorgio6259e5f2018-01-17 17:29:33 +0000147 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
John Richardsondd715f22017-09-18 16:10:48 +0100148
149 return dst;
150}
151
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100152template <>
153SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Michalis Spyrou6260e192019-06-06 13:47:38 +0100154 const QuantizationInfo &qout)
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100155{
156 SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
157
158 if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
159 {
160 SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
161 SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
162 SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
Michele Di Giorgio4aff98f2019-08-28 16:27:26 +0100163 dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout);
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100164 }
165 else
166 {
167 if(scale < 0)
168 {
169 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
170 }
171
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100172 Coordinates id_src1{};
173 Coordinates id_src2{};
174 Coordinates id_dst{};
Georgios Pinitasbf28a3c2018-09-18 14:34:48 +0100175 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
176 }
177 return dst;
178}
Michele Di Giorgiod8a468f2019-06-19 15:34:41 +0100179
180template <>
181SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
182 const QuantizationInfo &qout)
183{
184 SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
185
186 if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16)
187 {
188 SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
189 SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
190 SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
191 dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
192 }
193 else
194 {
195 if(scale < 0)
196 {
197 ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
198 }
199
200 Coordinates id_src1{};
201 Coordinates id_src2{};
202 Coordinates id_dst{};
203 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
204 }
205 return dst;
206}
John Richardsondd715f22017-09-18 16:10:48 +0100207// *INDENT-OFF*
208// clang-format off
Michalis Spyrou6260e192019-06-06 13:47:38 +0100209template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
Michalis Spyrou6260e192019-06-06 13:47:38 +0100210template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
211template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
John Richardsondd715f22017-09-18 16:10:48 +0100212// clang-format on
213// *INDENT-ON*
214} // namespace reference
215} // namespace validation
216} // namespace test
217} // namespace arm_compute