blob: 2ffb0faa75f727ec2c35973f81faec62deb8a3be [file] [log] [blame]
giuros01164a2722018-11-20 18:34:46 +00001/*
George Worta1e7e282019-01-15 11:00:29 +00002 * Copyright (c) 2018-2019 ARM Limited.
giuros01164a2722018-11-20 18:34:46 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
#include "ElementwiseOperations.h"

#include "arm_compute/core/Types.h"
#include "tests/validation/Helpers.h"

#include <algorithm>
#include <cstdint>
#include <type_traits>
28
29namespace arm_compute
30{
31namespace test
32{
33namespace validation
34{
35namespace reference
36{
37namespace
38{
39template <typename T>
40T arithm_op(ArithmeticOperation op, T src1, T src2, ConvertPolicy convert_policy)
41{
42 using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type;
43
44 intermediate_type val;
45
46 if(op == ArithmeticOperation::ADD)
47 {
48 val = static_cast<intermediate_type>(src1) + static_cast<intermediate_type>(src2);
49 }
50 else if(op == ArithmeticOperation::SUB)
51 {
52 val = static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2);
53 }
54 else if(op == ArithmeticOperation::MIN)
55 {
56 val = std::min(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
57 }
58 else if(op == ArithmeticOperation::MAX)
59 {
60 val = std::max(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
61 }
62 else if(op == ArithmeticOperation::SQUARED_DIFF)
63 {
64 intermediate_type tmp = (static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2));
65 val = tmp * tmp;
66 }
67 else if(op == ArithmeticOperation::DIV)
68 {
69 val = (static_cast<intermediate_type>(src1) / static_cast<intermediate_type>(src2));
70 }
71 else
72 {
73 ARM_COMPUTE_ERROR("Not handled");
74 }
75
76 T result;
George Worta1e7e282019-01-15 11:00:29 +000077 if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV)
giuros01164a2722018-11-20 18:34:46 +000078 {
79 result = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val);
80 }
81 else
82 {
83 result = static_cast<T>(val);
84 }
85 return result;
86}
87
88template <size_t dim>
89struct BroadcastUnroll
90{
91 template <typename T>
92 static void unroll(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst,
93 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
94 {
95 const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
96 const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
97
98 id_src1.set(dim - 1, 0);
99 id_src2.set(dim - 1, 0);
100 id_dst.set(dim - 1, 0);
101
102 for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
103 {
104 BroadcastUnroll < dim - 1 >::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
105
106 id_src1[dim - 1] += !src1_is_broadcast;
107 id_src2[dim - 1] += !src2_is_broadcast;
108 }
109 }
110};
111
112template <>
113struct BroadcastUnroll<0>
114{
115 template <typename T>
116 static void unroll(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst,
117 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
118 {
119 dst[coord2index(dst.shape(), id_dst)] = arithm_op(op, src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], convert_policy);
120 }
121};
122} // namespace
123
124template <typename T>
125SimpleTensor<T> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst, ConvertPolicy convert_policy)
126{
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100127 Coordinates id_src1{};
128 Coordinates id_src2{};
129 Coordinates id_dst{};
giuros01164a2722018-11-20 18:34:46 +0000130
131 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
132
133 return dst;
134}
135
136template <>
137SimpleTensor<uint8_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, SimpleTensor<uint8_t> &dst, ConvertPolicy convert_policy)
138{
139 if(dst.data_type() == DataType::QASYMM8)
140 {
141 SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
142 SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
143 SimpleTensor<float> dst_tmp(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dst.data_type());
144
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100145 Coordinates id_src1{};
146 Coordinates id_src2{};
147 Coordinates id_dst{};
giuros01164a2722018-11-20 18:34:46 +0000148
149 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
150
151 dst = convert_to_asymmetric(dst_tmp, dst.quantization_info());
152 return dst;
153 }
154 else
155 {
156 // DataType::U8
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100157 Coordinates id_src1{};
158 Coordinates id_src2{};
159 Coordinates id_dst{};
giuros01164a2722018-11-20 18:34:46 +0000160
161 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
162
163 return dst;
164 }
165}
166
giuros0192fd9432018-12-03 17:30:00 +0000167template SimpleTensor<int32_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int32_t> &src1, const SimpleTensor<int32_t> &src2, SimpleTensor<int32_t> &dst,
168 ConvertPolicy convert_policy);
giuros01164a2722018-11-20 18:34:46 +0000169template SimpleTensor<int16_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, SimpleTensor<int16_t> &dst,
170 ConvertPolicy convert_policy);
171template SimpleTensor<int8_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, SimpleTensor<int8_t> &dst,
172 ConvertPolicy convert_policy);
173template SimpleTensor<half> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, SimpleTensor<half> &dst, ConvertPolicy convert_policy);
174template SimpleTensor<float> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, SimpleTensor<float> &dst, ConvertPolicy convert_policy);
175
176template <typename T>
177SimpleTensor<T> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
178{
179 ARM_COMPUTE_ERROR_ON_MSG(dst_data_type == DataType::QASYMM8, "For QASYMM8, the quantized output tensor should be passed directly.");
180
181 SimpleTensor<T> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dst_data_type);
182 arithmetic_operation<T>(op, src1, src2, dst, convert_policy);
183 return dst;
184}
185
giuros0192fd9432018-12-03 17:30:00 +0000186template SimpleTensor<int32_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int32_t> &src1, const SimpleTensor<int32_t> &src2, DataType dst_data_type,
187 ConvertPolicy convert_policy);
giuros01164a2722018-11-20 18:34:46 +0000188template SimpleTensor<int16_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type,
189 ConvertPolicy convert_policy);
190template SimpleTensor<int8_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
191template SimpleTensor<half> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
192template SimpleTensor<float> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
193
194} // namespace reference
195} // namespace validation
196} // namespace test
197} // namespace arm_compute