blob: 6d533edea5628639d20efb52b0d5f5e5e33c10a0 [file] [log] [blame]
giuros01164a2722018-11-20 18:34:46 +00001/*
George Worta1e7e282019-01-15 11:00:29 +00002 * Copyright (c) 2018-2019 ARM Limited.
giuros01164a2722018-11-20 18:34:46 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "ElementwiseOperations.h"
25
26#include "arm_compute/core/Types.h"
27#include "tests/validation/Helpers.h"
28
29namespace arm_compute
30{
31namespace test
32{
33namespace validation
34{
35namespace reference
36{
37namespace
38{
39template <typename T>
40T arithm_op(ArithmeticOperation op, T src1, T src2, ConvertPolicy convert_policy)
41{
42 using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type;
43
44 intermediate_type val;
45
46 if(op == ArithmeticOperation::ADD)
47 {
48 val = static_cast<intermediate_type>(src1) + static_cast<intermediate_type>(src2);
49 }
50 else if(op == ArithmeticOperation::SUB)
51 {
52 val = static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2);
53 }
54 else if(op == ArithmeticOperation::MIN)
55 {
56 val = std::min(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
57 }
58 else if(op == ArithmeticOperation::MAX)
59 {
60 val = std::max(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
61 }
62 else if(op == ArithmeticOperation::SQUARED_DIFF)
63 {
64 intermediate_type tmp = (static_cast<intermediate_type>(src1) - static_cast<intermediate_type>(src2));
65 val = tmp * tmp;
66 }
67 else if(op == ArithmeticOperation::DIV)
68 {
69 val = (static_cast<intermediate_type>(src1) / static_cast<intermediate_type>(src2));
70 }
71 else
72 {
73 ARM_COMPUTE_ERROR("Not handled");
74 }
75
76 T result;
George Worta1e7e282019-01-15 11:00:29 +000077 if(op == ArithmeticOperation::ADD || op == ArithmeticOperation::SUB || op == ArithmeticOperation::DIV)
giuros01164a2722018-11-20 18:34:46 +000078 {
79 result = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val);
80 }
81 else
82 {
83 result = static_cast<T>(val);
84 }
85 return result;
86}
87
88template <size_t dim>
89struct BroadcastUnroll
90{
91 template <typename T>
92 static void unroll(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst,
93 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
94 {
95 const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
96 const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
97
98 id_src1.set(dim - 1, 0);
99 id_src2.set(dim - 1, 0);
100 id_dst.set(dim - 1, 0);
101
102 for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
103 {
104 BroadcastUnroll < dim - 1 >::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
105
106 id_src1[dim - 1] += !src1_is_broadcast;
107 id_src2[dim - 1] += !src2_is_broadcast;
108 }
109 }
110};
111
112template <>
113struct BroadcastUnroll<0>
114{
115 template <typename T>
116 static void unroll(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst,
117 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
118 {
119 dst[coord2index(dst.shape(), id_dst)] = arithm_op(op, src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], convert_policy);
120 }
121};
122} // namespace
123
124template <typename T>
125SimpleTensor<T> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, SimpleTensor<T> &dst, ConvertPolicy convert_policy)
126{
127 Coordinates id_src1, id_src2, id_dst;
128
129 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
130
131 return dst;
132}
133
134template <>
135SimpleTensor<uint8_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, SimpleTensor<uint8_t> &dst, ConvertPolicy convert_policy)
136{
137 if(dst.data_type() == DataType::QASYMM8)
138 {
139 SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
140 SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
141 SimpleTensor<float> dst_tmp(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dst.data_type());
142
143 Coordinates id_src1, id_src2, id_dst;
144
145 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
146
147 dst = convert_to_asymmetric(dst_tmp, dst.quantization_info());
148 return dst;
149 }
150 else
151 {
152 // DataType::U8
153 Coordinates id_src1, id_src2, id_dst;
154
155 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
156
157 return dst;
158 }
159}
160
giuros0192fd9432018-12-03 17:30:00 +0000161template SimpleTensor<int32_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int32_t> &src1, const SimpleTensor<int32_t> &src2, SimpleTensor<int32_t> &dst,
162 ConvertPolicy convert_policy);
giuros01164a2722018-11-20 18:34:46 +0000163template SimpleTensor<int16_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, SimpleTensor<int16_t> &dst,
164 ConvertPolicy convert_policy);
165template SimpleTensor<int8_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, SimpleTensor<int8_t> &dst,
166 ConvertPolicy convert_policy);
167template SimpleTensor<half> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, SimpleTensor<half> &dst, ConvertPolicy convert_policy);
168template SimpleTensor<float> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, SimpleTensor<float> &dst, ConvertPolicy convert_policy);
169
170template <typename T>
171SimpleTensor<T> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
172{
173 ARM_COMPUTE_ERROR_ON_MSG(dst_data_type == DataType::QASYMM8, "For QASYMM8, the quantized output tensor should be passed directly.");
174
175 SimpleTensor<T> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dst_data_type);
176 arithmetic_operation<T>(op, src1, src2, dst, convert_policy);
177 return dst;
178}
179
giuros0192fd9432018-12-03 17:30:00 +0000180template SimpleTensor<int32_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int32_t> &src1, const SimpleTensor<int32_t> &src2, DataType dst_data_type,
181 ConvertPolicy convert_policy);
giuros01164a2722018-11-20 18:34:46 +0000182template SimpleTensor<int16_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type,
183 ConvertPolicy convert_policy);
184template SimpleTensor<int8_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
185template SimpleTensor<half> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
186template SimpleTensor<float> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
187
188} // namespace reference
189} // namespace validation
190} // namespace test
191} // namespace arm_compute