blob: fb7a6d699798b6c7590e5e06413d1cf05f02a5c7 [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "ReductionOperation.h"
25
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010026#include "tests/validation/Helpers.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010027
28#include <algorithm>
29#include <cmath>
30
31namespace arm_compute
32{
33namespace test
34{
35namespace validation
36{
37namespace reference
38{
39namespace
40{
Michalis Spyrou7930db42018-11-22 17:36:28 +000041template <typename T, typename OT>
42OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, int stride)
Georgios Pinitasd9769582017-08-03 10:19:40 +010043{
Michalis Spyrou7930db42018-11-22 17:36:28 +000044 using type = typename std::remove_cv<OT>::type;
Manuel Bottinib412fab2018-12-10 17:40:23 +000045 auto res = (op == ReductionOperation::PROD) ? type(1) : type(0);
Georgios Pinitasd9769582017-08-03 10:19:40 +010046
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +010047 if(std::is_integral<type>::value)
Michalis Spyrou7e9391b2018-10-05 14:49:28 +010048 {
Manuel Bottinib412fab2018-12-10 17:40:23 +000049 auto int_res = static_cast<uint32_t>(res);
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +010050 for(int i = 0; i < reduce_elements; ++i)
51 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +000052 auto elem = *(ptr + stride * i);
Michalis Spyrou7930db42018-11-22 17:36:28 +000053
54 switch(op)
55 {
56 case ReductionOperation::ARG_IDX_MIN:
Michalis Spyrouaea14c62019-01-03 11:10:25 +000057 if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem)
Michalis Spyrou7930db42018-11-22 17:36:28 +000058 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +000059 int_res = static_cast<uint32_t>(i);
Michalis Spyrou7930db42018-11-22 17:36:28 +000060 }
61 break;
62 case ReductionOperation::ARG_IDX_MAX:
Michalis Spyrouaea14c62019-01-03 11:10:25 +000063 if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem)
Michalis Spyrou7930db42018-11-22 17:36:28 +000064 {
Michalis Spyrouaea14c62019-01-03 11:10:25 +000065 int_res = static_cast<uint32_t>(i);
Michalis Spyrou7930db42018-11-22 17:36:28 +000066 }
67 break;
68 case ReductionOperation::SUM_SQUARE:
69 int_res += elem * elem;
70 break;
71 case ReductionOperation::MEAN_SUM:
72 case ReductionOperation::SUM:
73 int_res += elem;
74 break;
Manuel Bottinib412fab2018-12-10 17:40:23 +000075 case ReductionOperation::PROD:
76 int_res *= elem;
77 break;
Michalis Spyrou7930db42018-11-22 17:36:28 +000078 default:
79 ARM_COMPUTE_ERROR("Operation not supported");
80 }
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +010081 }
82 if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
83 {
84 int_res /= reduce_elements;
85 }
86 res = saturate_cast<type>(int_res);
Michalis Spyrou7e9391b2018-10-05 14:49:28 +010087 }
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +010088 else
89 {
90 for(int i = 0; i < reduce_elements; ++i)
91 {
92 auto elem = *(ptr + stride * i);
Michalis Spyrou7930db42018-11-22 17:36:28 +000093 switch(op)
94 {
95 case ReductionOperation::ARG_IDX_MIN:
96 if(*(ptr + stride * static_cast<uint32_t>(res)) > elem)
97 {
98 res = static_cast<uint32_t>(i);
99 }
100 break;
101 case ReductionOperation::ARG_IDX_MAX:
102 if(*(ptr + stride * static_cast<uint32_t>(res)) < elem)
103 {
104 res = static_cast<uint32_t>(i);
105 }
106 break;
107 case ReductionOperation::SUM_SQUARE:
108 res += elem * elem;
109 break;
110 case ReductionOperation::MEAN_SUM:
111 case ReductionOperation::SUM:
112 res += elem;
113 break;
Manuel Bottinib412fab2018-12-10 17:40:23 +0000114 case ReductionOperation::PROD:
115 res *= elem;
116 break;
Michalis Spyrou7930db42018-11-22 17:36:28 +0000117 default:
118 ARM_COMPUTE_ERROR("Operation not supported");
119 }
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100120 }
121 if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
122 {
123 res /= reduce_elements;
124 }
125 }
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100126 return res;
Georgios Pinitasd9769582017-08-03 10:19:40 +0100127}
128} // namespace
129
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000130template <typename T, typename OT>
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000131SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100132{
133 // Create reference
Michalis Spyrou7930db42018-11-22 17:36:28 +0000134 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
135 DataType output_data_type = is_arg_min_max ? DataType::U32 : src.data_type();
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000136 SimpleTensor<OT> dst{ dst_shape, output_data_type, 1, src.quantization_info() };
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100137 const unsigned int src_width = src.shape().x();
138 const unsigned int src_height = src.shape().y();
139 const unsigned int src_depth = src.shape().z();
140 const unsigned int src_batch = src.shape()[3];
141 const int reduce_elems = src.shape()[axis];
Georgios Pinitasd9769582017-08-03 10:19:40 +0100142
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100143 switch(axis)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100144 {
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100145 case 0:
Georgios Pinitasd9769582017-08-03 10:19:40 +0100146 {
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100147 const unsigned int upper_dims = src.shape().total_size_upper(1);
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100148 for(unsigned int du = 0; du < upper_dims; ++du)
149 {
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100150 const T *src_row_ptr = src.data() + du * reduce_elems;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000151 dst[du] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100152 }
Georgios Pinitasd9769582017-08-03 10:19:40 +0100153 }
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100154 break;
155 case 1:
Georgios Pinitasd9769582017-08-03 10:19:40 +0100156 {
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100157 const unsigned int upper_dims = src.shape().total_size_upper(2);
158 for(unsigned int du = 0; du < upper_dims; ++du)
159 {
160 for(unsigned int x = 0; x < src_width; ++x)
161 {
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100162 const int in_offset = du * src_height * src_width + x;
163 const int out_offset = du * src_width + x;
164 const T *src_row_ptr = src.data() + in_offset;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000165 dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100166 }
167 }
168 }
169 break;
170 case 2:
171 {
172 const unsigned int upper_dims = src.shape().total_size_upper(3);
173 for(unsigned int du = 0; du < upper_dims; ++du)
174 {
175 for(unsigned int x = 0; x < src_width; ++x)
176 {
177 for(unsigned int y = 0; y < src_height; ++y)
178 {
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100179 const int in_offset = du * src_depth * src_height * src_width + y * src_width + x;
180 const int out_offset = du * src_width * src_height + y * src_width + x;
181 const T *src_row_ptr = src.data() + in_offset;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000182 dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width);
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100183 }
184 }
185 }
186 }
187 break;
188 case 3:
189 {
190 const unsigned int upper_dims = src.shape().total_size_upper(4);
191 for(unsigned int du = 0; du < upper_dims; ++du)
192 {
193 for(unsigned int z = 0; z < src_depth; ++z)
194 {
195 for(unsigned int y = 0; y < src_height; ++y)
196 {
197 for(unsigned int x = 0; x < src_width; ++x)
198 {
Michalis Spyrou8aaf93e2018-10-11 17:33:32 +0100199 const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
200 const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
201 const T *src_row_ptr = src.data() + in_offset;
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000202 dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
Michalis Spyrou7e9391b2018-10-05 14:49:28 +0100203 }
204 }
205 }
206 }
207 }
208 break;
209 default:
Georgios Pinitasd9769582017-08-03 10:19:40 +0100210 ARM_COMPUTE_ERROR("Unsupported reduction axis");
Georgios Pinitasd9769582017-08-03 10:19:40 +0100211 }
212
213 return dst;
214}
215
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000216template <typename T, typename OT>
217SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
218{
219 return compute_reduction_operation<T, OT>(src, dst_shape, axis, op);
220}
221
222template <>
223SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
224{
225 if(src.data_type() == DataType::QASYMM8 && op != ReductionOperation::MEAN_SUM)
226 {
227 SimpleTensor<float> src_f = convert_from_asymmetric(src);
228 SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op);
229 return convert_to_asymmetric(dst_f, src.quantization_info());
230 }
231 else
232 {
233 return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op);
234 }
235}
236
237template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
238template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
239
Michalis Spyrouaea14c62019-01-03 11:10:25 +0000240template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
241template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
242template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
Manuel Bottini1d4f3852019-01-14 15:14:43 +0000243
Georgios Pinitasd9769582017-08-03 10:19:40 +0100244} // namespace reference
245} // namespace validation
246} // namespace test
247} // namespace arm_compute