John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 1 | /* |
Michele Di Giorgio | d9eaf61 | 2020-07-08 11:12:57 +0100 | [diff] [blame] | 2 | * Copyright (c) 2017-2020 Arm Limited. |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
Manuel Bottini | 79fa9a2 | 2019-02-22 17:54:22 +0000 | [diff] [blame] | 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 22 | * SOFTWARE. |
| 23 | */ |
| 24 | #include "PixelWiseMultiplication.h" |
| 25 | |
Georgios Pinitas | bf28a3c | 2018-09-18 14:34:48 +0100 | [diff] [blame] | 26 | #include "tests/validation/Helpers.h" |
| 27 | |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 28 | namespace arm_compute |
| 29 | { |
| 30 | namespace test |
| 31 | { |
| 32 | namespace validation |
| 33 | { |
| 34 | namespace reference |
| 35 | { |
| 36 | template <class T> |
| 37 | struct is_floating_point |
| 38 | : std::integral_constant < bool, |
| 39 | std::is_same<float, typename std::remove_cv<T>::type>::value || std::is_same<half_float::half, typename std::remove_cv<T>::type>::value |
| 40 | || std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value > |
| 41 | { |
| 42 | }; |
| 43 | |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 44 | namespace |
| 45 | { |
// Reference value the scale is compared against to detect the "scale == 1" case.
constexpr float scale1_constant{ 1.0f };
| 47 | |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 48 | /** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2. |
| 49 | * |
Vidhya Sudhan Loganathan | 0fc2545 | 2018-06-18 14:40:56 +0100 | [diff] [blame] | 50 | * @param[in] src1 An input value. Data types supported: U8/S16/F16/F32. |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 51 | * @param[in] src2 An input value. Data types supported: same as @p src1. |
| 52 | * @param[in] scale Scale to apply after multiplication. |
Vidhya Sudhan Loganathan | 0fc2545 | 2018-06-18 14:40:56 +0100 | [diff] [blame] | 53 | * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 54 | * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate |
| 55 | * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. |
| 56 | */ |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 57 | template <typename T1, typename T2, typename T3> |
| 58 | T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 59 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 60 | using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type; |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 61 | |
| 62 | const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale); |
| 63 | |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 64 | if(is_floating_point<T3>::value) |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 65 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 66 | const auto result = static_cast<T3>(val); |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 67 | |
| 68 | return result; |
| 69 | } |
| 70 | else |
| 71 | { |
| 72 | double rounded_val = 0; |
| 73 | switch(rounding_policy) |
| 74 | { |
| 75 | case(RoundingPolicy::TO_ZERO): |
| 76 | rounded_val = support::cpp11::trunc(val); |
| 77 | break; |
| 78 | case(RoundingPolicy::TO_NEAREST_UP): |
| 79 | rounded_val = round_half_up(val); |
| 80 | break; |
| 81 | case(RoundingPolicy::TO_NEAREST_EVEN): |
| 82 | rounded_val = round_half_even(val); |
| 83 | break; |
| 84 | default: |
| 85 | ARM_COMPUTE_ERROR("Unsupported rounding policy"); |
| 86 | } |
| 87 | |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 88 | const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val); |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 89 | |
| 90 | return result; |
| 91 | } |
| 92 | } |
| 93 | |
SiCong Li | bb88f89 | 2020-08-28 11:18:47 +0100 | [diff] [blame] | 94 | template <> |
| 95 | int32_t mul(const int32_t src1, const int32_t src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) |
| 96 | { |
| 97 | const int64_t intermediate_val = static_cast<int64_t>(src1) * static_cast<int64_t>(src2); |
| 98 | |
| 99 | if(std::abs(scale - scale1_constant) < 0.00001f) |
| 100 | { |
| 101 | // Use bit-accurate integer arithmetic for scale == 1 |
| 102 | // Apply conversion |
| 103 | if(convert_policy == ConvertPolicy::SATURATE) |
| 104 | { |
| 105 | return saturate_cast<int32_t>(intermediate_val); |
| 106 | } |
| 107 | else |
| 108 | { |
| 109 | // Correct wrapping behaviour for int32_t |
| 110 | const auto i32_hi = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); |
| 111 | const auto i32_lo = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest()); |
| 112 | const auto i32_wi = static_cast<int64_t>(1) << 32; |
| 113 | int64_t wrapped_rounded_val = intermediate_val - i32_wi * static_cast<int64_t>(support::cpp11::trunc(static_cast<double>(intermediate_val) / i32_wi)); |
| 114 | if(wrapped_rounded_val <= i32_hi) |
| 115 | { |
| 116 | return static_cast<int32_t>(wrapped_rounded_val); |
| 117 | } |
| 118 | else |
| 119 | { |
| 120 | // Values beyond i32_hi wrap around to negatives |
| 121 | return static_cast<int32_t>((wrapped_rounded_val - i32_hi) + i32_lo - 1); |
| 122 | } |
| 123 | } |
| 124 | } |
| 125 | else |
| 126 | { |
| 127 | // Use double arithmetic for scale != 1; may not be bit-accurate |
| 128 | // Apply scaling |
| 129 | // scale == 1 / 2^scale_exponent |
| 130 | int scale_exponent = 0; |
| 131 | std::frexp(scale, &scale_exponent); |
| 132 | // Store the positive exponent. We know that we compute 1/2^n |
| 133 | // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5 |
| 134 | scale_exponent = std::abs(scale_exponent - 1); |
| 135 | const double scale_inv = static_cast<int64_t>(1) << scale_exponent; |
| 136 | const double val = intermediate_val / scale_inv; |
| 137 | // Apply rounding |
| 138 | double rounded_val = 0; |
| 139 | switch(rounding_policy) |
| 140 | { |
| 141 | case(RoundingPolicy::TO_ZERO): |
| 142 | rounded_val = support::cpp11::trunc(val); |
| 143 | break; |
| 144 | case(RoundingPolicy::TO_NEAREST_UP): |
| 145 | rounded_val = round_half_up(val); |
| 146 | break; |
| 147 | case(RoundingPolicy::TO_NEAREST_EVEN): |
| 148 | rounded_val = round_half_even(val); |
| 149 | break; |
| 150 | default: |
| 151 | ARM_COMPUTE_ERROR("Unsupported rounding policy"); |
| 152 | } |
| 153 | // Apply conversion |
| 154 | if(convert_policy == ConvertPolicy::SATURATE) |
| 155 | { |
| 156 | return saturate_cast<int32_t>(rounded_val); |
| 157 | } |
| 158 | else |
| 159 | { |
| 160 | // Correct wrapping behaviour for int32_t |
| 161 | const auto i32_hi = static_cast<double>(std::numeric_limits<int32_t>::max()); |
| 162 | const auto i32_lo = static_cast<double>(std::numeric_limits<int32_t>::lowest()); |
| 163 | const auto i32_wi = static_cast<double>(static_cast<int64_t>(1) << 32); |
| 164 | double wrapped_rounded_val = rounded_val - i32_wi * std::floor(rounded_val / i32_wi); |
| 165 | if(wrapped_rounded_val <= i32_hi) |
| 166 | { |
| 167 | return static_cast<int32_t>(wrapped_rounded_val); |
| 168 | } |
| 169 | else |
| 170 | { |
| 171 | // Values beyond i32_hi wrap around to negatives |
| 172 | return static_cast<int32_t>((wrapped_rounded_val - i32_hi) + i32_lo - 1); |
| 173 | } |
| 174 | } |
| 175 | } |
| 176 | } |
| 177 | |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 178 | template <size_t dim> |
| 179 | struct BroadcastUnroll |
| 180 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 181 | template <typename T1, typename T2, typename T3> |
| 182 | static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst, |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 183 | float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, |
| 184 | Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) |
| 185 | { |
| 186 | const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]); |
| 187 | const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]); |
| 188 | |
| 189 | id_src1.set(dim - 1, 0); |
| 190 | id_src2.set(dim - 1, 0); |
| 191 | id_dst.set(dim - 1, 0); |
| 192 | |
| 193 | for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1]) |
| 194 | { |
| 195 | BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); |
| 196 | |
| 197 | id_src1[dim - 1] += !src1_is_broadcast; |
| 198 | id_src2[dim - 1] += !src2_is_broadcast; |
| 199 | } |
| 200 | } |
| 201 | }; |
| 202 | |
| 203 | template <> |
| 204 | struct BroadcastUnroll<0> |
| 205 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 206 | template <typename T1, typename T2, typename T3> |
| 207 | static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst, |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 208 | float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, |
| 209 | Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) |
| 210 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 211 | dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy); |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 212 | } |
| 213 | }; |
| 214 | } // namespace |
| 215 | |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 216 | template <typename T1, typename T2, typename T3> |
| 217 | SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, |
| 218 | DataType dt_out, const QuantizationInfo &qout) |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 219 | { |
Georgios Pinitas | bf28a3c | 2018-09-18 14:34:48 +0100 | [diff] [blame] | 220 | ARM_COMPUTE_UNUSED(qout); |
| 221 | |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 222 | SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out); |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 223 | |
| 224 | if(scale < 0) |
| 225 | { |
| 226 | ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); |
| 227 | } |
| 228 | |
Michalis Spyrou | bcfd09a | 2019-05-01 13:03:59 +0100 | [diff] [blame] | 229 | Coordinates id_src1{}; |
| 230 | Coordinates id_src2{}; |
| 231 | Coordinates id_dst{}; |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 232 | |
Michele Di Giorgio | 6259e5f | 2018-01-17 17:29:33 +0000 | [diff] [blame] | 233 | BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 234 | |
| 235 | return dst; |
| 236 | } |
| 237 | |
Georgios Pinitas | bf28a3c | 2018-09-18 14:34:48 +0100 | [diff] [blame] | 238 | template <> |
| 239 | SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 240 | DataType dt_out, const QuantizationInfo &qout) |
Georgios Pinitas | bf28a3c | 2018-09-18 14:34:48 +0100 | [diff] [blame] | 241 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 242 | SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); |
Georgios Pinitas | bf28a3c | 2018-09-18 14:34:48 +0100 | [diff] [blame] | 243 | |
| 244 | if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8) |
| 245 | { |
| 246 | SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1); |
| 247 | SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2); |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 248 | SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); |
Michele Di Giorgio | 4aff98f | 2019-08-28 16:27:26 +0100 | [diff] [blame] | 249 | dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout); |
Georgios Pinitas | bf28a3c | 2018-09-18 14:34:48 +0100 | [diff] [blame] | 250 | } |
| 251 | else |
| 252 | { |
| 253 | if(scale < 0) |
| 254 | { |
| 255 | ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); |
| 256 | } |
| 257 | |
Michalis Spyrou | bcfd09a | 2019-05-01 13:03:59 +0100 | [diff] [blame] | 258 | Coordinates id_src1{}; |
| 259 | Coordinates id_src2{}; |
| 260 | Coordinates id_dst{}; |
Georgios Pinitas | bf28a3c | 2018-09-18 14:34:48 +0100 | [diff] [blame] | 261 | BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); |
| 262 | } |
| 263 | return dst; |
| 264 | } |
Michele Di Giorgio | d8a468f | 2019-06-19 15:34:41 +0100 | [diff] [blame] | 265 | |
| 266 | template <> |
Sheri Zhang | fcf6f4e | 2020-06-25 20:01:00 +0100 | [diff] [blame] | 267 | SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, |
| 268 | DataType dt_out, const QuantizationInfo &qout) |
| 269 | { |
| 270 | SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); |
| 271 | |
| 272 | if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8) |
| 273 | { |
| 274 | SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1); |
| 275 | SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2); |
| 276 | SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); |
| 277 | dst = convert_to_symmetric<int16_t>(dst_tmp, qout); |
| 278 | } |
| 279 | else |
| 280 | { |
| 281 | if(scale < 0) |
| 282 | { |
| 283 | ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); |
| 284 | } |
| 285 | |
| 286 | Coordinates id_src1{}; |
| 287 | Coordinates id_src2{}; |
| 288 | Coordinates id_dst{}; |
| 289 | BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); |
| 290 | } |
| 291 | return dst; |
| 292 | } |
| 293 | |
| 294 | template <> |
Pablo Tello | 52ea9c2 | 2019-12-10 11:28:53 +0000 | [diff] [blame] | 295 | SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 296 | DataType dt_out, const QuantizationInfo &qout) |
Pablo Tello | 52ea9c2 | 2019-12-10 11:28:53 +0000 | [diff] [blame] | 297 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 298 | SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); |
Pablo Tello | 52ea9c2 | 2019-12-10 11:28:53 +0000 | [diff] [blame] | 299 | |
| 300 | if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED) |
| 301 | { |
| 302 | SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1); |
| 303 | SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2); |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 304 | SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); |
Pablo Tello | 52ea9c2 | 2019-12-10 11:28:53 +0000 | [diff] [blame] | 305 | dst = convert_to_asymmetric<int8_t>(dst_tmp, qout); |
| 306 | } |
| 307 | else |
| 308 | { |
| 309 | if(scale < 0) |
| 310 | { |
| 311 | ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); |
| 312 | } |
| 313 | |
| 314 | Coordinates id_src1{}; |
| 315 | Coordinates id_src2{}; |
| 316 | Coordinates id_dst{}; |
| 317 | BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); |
| 318 | } |
| 319 | return dst; |
| 320 | } |
| 321 | |
| 322 | template <> |
Michele Di Giorgio | d8a468f | 2019-06-19 15:34:41 +0100 | [diff] [blame] | 323 | SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 324 | DataType dt_out, const QuantizationInfo &qout) |
Michele Di Giorgio | d8a468f | 2019-06-19 15:34:41 +0100 | [diff] [blame] | 325 | { |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 326 | SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); |
Michele Di Giorgio | d8a468f | 2019-06-19 15:34:41 +0100 | [diff] [blame] | 327 | |
| 328 | if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16) |
| 329 | { |
| 330 | SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1); |
| 331 | SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2); |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 332 | SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); |
Michele Di Giorgio | d8a468f | 2019-06-19 15:34:41 +0100 | [diff] [blame] | 333 | dst = convert_to_symmetric<int16_t>(dst_tmp, qout); |
| 334 | } |
| 335 | else |
| 336 | { |
| 337 | if(scale < 0) |
| 338 | { |
| 339 | ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); |
| 340 | } |
| 341 | |
| 342 | Coordinates id_src1{}; |
| 343 | Coordinates id_src2{}; |
| 344 | Coordinates id_dst{}; |
| 345 | BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); |
| 346 | } |
| 347 | return dst; |
| 348 | } |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 349 | // *INDENT-OFF* |
| 350 | // clang-format off |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 351 | template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); |
| 352 | template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); |
SiCong Li | bb88f89 | 2020-08-28 11:18:47 +0100 | [diff] [blame] | 353 | template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int32_t> &src1, const SimpleTensor<int32_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); |
Michele Di Giorgio | 9428a18 | 2020-03-30 14:10:20 +0100 | [diff] [blame] | 354 | template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); |
| 355 | template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); |
John Richardson | dd715f2 | 2017-09-18 16:10:48 +0100 | [diff] [blame] | 356 | // clang-format on |
| 357 | // *INDENT-ON* |
| 358 | } // namespace reference |
| 359 | } // namespace validation |
| 360 | } // namespace test |
| 361 | } // namespace arm_compute |