blob: 07801aba1c2aaf2861f777b7fd5eecaa9e8dee54 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "ReferenceCPP.h"
25
26#include "TensorFactory.h"
27#include "TensorOperations.h"
28#include "TensorVisitors.h"
29#include "TypePrinter.h"
30
31#include "arm_compute/core/Coordinates.h"
32#include "arm_compute/core/Error.h"
33#include "arm_compute/core/TensorInfo.h"
34#include "arm_compute/core/TensorShape.h"
35#include "arm_compute/runtime/Tensor.h"
36
37#include "boost_wrapper.h"
38
Georgios Pinitasac4e8732017-07-05 17:02:25 +010039#include <algorithm>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040#include <functional>
Georgios Pinitasac4e8732017-07-05 17:02:25 +010041#include <memory>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010042#include <numeric>
43#include <vector>
44
45using namespace arm_compute::test::validation::tensor_visitors;
46
47namespace arm_compute
48{
49namespace test
50{
51namespace validation
52{
Giorgio Arena50f9fd72017-06-19 17:05:30 +010053// Sobel 3x3
54void ReferenceCPP::sobel_3x3(RawTensor &src, RawTensor &dst_x, RawTensor &dst_y, BorderMode border_mode, uint8_t constant_border_value)
55{
56 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst_x.data_type() != DataType::S16 || dst_y.data_type() != DataType::S16);
57 Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
58 Tensor<int16_t> dx(dst_x.shape(), dst_x.data_type(), dst_x.fixed_point_position(), reinterpret_cast<int16_t *>(dst_x.data()));
59 Tensor<int16_t> dy(dst_y.shape(), dst_y.data_type(), dst_y.fixed_point_position(), reinterpret_cast<int16_t *>(dst_y.data()));
60 tensor_operations::sobel_3x3(s, dx, dy, border_mode, constant_border_value);
61}
62
63// Sobel 5x5
64void ReferenceCPP::sobel_5x5(RawTensor &src, RawTensor &dst_x, RawTensor &dst_y, BorderMode border_mode, uint8_t constant_border_value)
65{
66 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst_x.data_type() != DataType::S16 || dst_y.data_type() != DataType::S16);
67 Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
68 Tensor<int16_t> dx(dst_x.shape(), dst_x.data_type(), dst_x.fixed_point_position(), reinterpret_cast<int16_t *>(dst_x.data()));
69 Tensor<int16_t> dy(dst_y.shape(), dst_y.data_type(), dst_y.fixed_point_position(), reinterpret_cast<int16_t *>(dst_y.data()));
70 tensor_operations::sobel_5x5(s, dx, dy, border_mode, constant_border_value);
71}
72
Giorgio Arena2ca209e2017-06-13 15:49:37 +010073// Minimum maximum location
Giorgio Arena935deee2017-06-14 13:40:36 +010074void ReferenceCPP::min_max_location(const RawTensor &src, int32_t &min, int32_t &max, IArray<Coordinates2D> &min_loc, IArray<Coordinates2D> &max_loc, uint32_t &min_count, uint32_t &max_count)
Giorgio Arena2ca209e2017-06-13 15:49:37 +010075{
76 const TensorVariant s = TensorFactory::get_tensor(src);
77 boost::apply_visitor(tensor_visitors::min_max_location_visitor(min, max, min_loc, max_loc, min_count, max_count), s);
78}
79
Anthony Barbier6ff3b192017-09-04 18:44:23 +010080// Absolute difference
81void ReferenceCPP::absolute_difference(const RawTensor &src1, const RawTensor &src2, RawTensor &dst)
82{
83 const TensorVariant s1 = TensorFactory::get_tensor(src1);
84 const TensorVariant s2 = TensorFactory::get_tensor(src2);
85 TensorVariant d = TensorFactory::get_tensor(dst);
86 boost::apply_visitor(absolute_difference_visitor(), s1, s2, d);
87}
Giorgio Arenaf7959862017-06-13 15:19:51 +010088
89// Mean and standard deviation
90void ReferenceCPP::mean_and_standard_deviation(const RawTensor &src, float &mean, float &std_dev)
91{
92 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8);
93 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
94 tensor_operations::mean_and_standard_deviation(s, mean, std_dev);
95}
96
Anthony Barbier6ff3b192017-09-04 18:44:23 +010097// Integral image
98void ReferenceCPP::integral_image(const RawTensor &src, RawTensor &dst)
99{
100 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U32);
101 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
102 Tensor<uint32_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint32_t *>(dst.data()));
103 tensor_operations::integral_image(s, d);
104}
Giorgio Arenaf7959862017-06-13 15:19:51 +0100105
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100106// Accumulate
107void ReferenceCPP::accumulate(const RawTensor &src, RawTensor &dst)
108{
109 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::S16);
110 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
111 Tensor<int16_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<int16_t *>(dst.data()));
112 tensor_operations::accumulate(s, d);
113}
114
115// Accumulate squared
116void ReferenceCPP::accumulate_squared(const RawTensor &src, RawTensor &dst, uint32_t shift)
117{
118 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::S16);
119 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
120 Tensor<int16_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<int16_t *>(dst.data()));
121 tensor_operations::accumulate_squared(s, d, shift);
122}
123
124// Accumulate weighted
125void ReferenceCPP::accumulate_weighted(const RawTensor &src, RawTensor &dst, float alpha)
126{
127 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
128 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
129 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
130 tensor_operations::accumulate_weighted(s, d, alpha);
131}
132
133// Arithmetic addition
134void ReferenceCPP::arithmetic_addition(const RawTensor &src1, const RawTensor &src2, RawTensor &dst, ConvertPolicy convert_policy)
135{
136 const TensorVariant s1 = TensorFactory::get_tensor(src1);
137 const TensorVariant s2 = TensorFactory::get_tensor(src2);
138 TensorVariant d = TensorFactory::get_tensor(dst);
139 boost::apply_visitor(arithmetic_addition_visitor(convert_policy), s1, s2, d);
140}
141
142// Arithmetic subtraction
143void ReferenceCPP::arithmetic_subtraction(const RawTensor &src1, const RawTensor &src2, RawTensor &dst, ConvertPolicy convert_policy)
144{
145 const TensorVariant s1 = TensorFactory::get_tensor(src1);
146 const TensorVariant s2 = TensorFactory::get_tensor(src2);
147 TensorVariant d = TensorFactory::get_tensor(dst);
148 boost::apply_visitor(arithmetic_subtraction_visitor(convert_policy), s1, s2, d);
149}
150
151// Bitwise and
152void ReferenceCPP::bitwise_and(const RawTensor &src1, const RawTensor &src2, RawTensor &dst)
153{
154 ARM_COMPUTE_ERROR_ON(src1.data_type() != DataType::U8 || src2.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
155 const Tensor<uint8_t> s1(src1.shape(), src1.data_type(), src1.fixed_point_position(), reinterpret_cast<const uint8_t *>(src1.data()));
156 const Tensor<uint8_t> s2(src2.shape(), src2.data_type(), src2.fixed_point_position(), reinterpret_cast<const uint8_t *>(src2.data()));
157 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
158 tensor_operations::bitwise_and(s1, s2, d);
159}
160
161// Bitwise or
162void ReferenceCPP::bitwise_or(const RawTensor &src1, const RawTensor &src2, RawTensor &dst)
163{
164 ARM_COMPUTE_ERROR_ON(src1.data_type() != DataType::U8 || src2.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
165 const Tensor<uint8_t> s1(src1.shape(), src1.data_type(), src1.fixed_point_position(), reinterpret_cast<const uint8_t *>(src1.data()));
166 const Tensor<uint8_t> s2(src2.shape(), src2.data_type(), src2.fixed_point_position(), reinterpret_cast<const uint8_t *>(src2.data()));
167 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
168 tensor_operations::bitwise_or(s1, s2, d);
169}
170
171// Bitwise xor
172void ReferenceCPP::bitwise_xor(const RawTensor &src1, const RawTensor &src2, RawTensor &dst)
173{
174 ARM_COMPUTE_ERROR_ON(src1.data_type() != DataType::U8 || src2.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
175 const Tensor<uint8_t> s1(src1.shape(), src1.data_type(), src1.fixed_point_position(), reinterpret_cast<const uint8_t *>(src1.data()));
176 const Tensor<uint8_t> s2(src2.shape(), src2.data_type(), src2.fixed_point_position(), reinterpret_cast<const uint8_t *>(src2.data()));
177 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
178 tensor_operations::bitwise_xor(s1, s2, d);
179}
180
181// Bitwise not
182void ReferenceCPP::bitwise_not(const RawTensor &src, RawTensor &dst)
183{
184 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
185 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
186 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
187 tensor_operations::bitwise_not(s, d);
188}
189
SiCong Libacaf9a2017-06-19 13:41:45 +0100190// Box3x3 filter
191void ReferenceCPP::box3x3(const RawTensor &src, RawTensor &dst, BorderMode border_mode, uint8_t constant_border_value)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100192{
193 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
194 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
195 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
SiCong Libacaf9a2017-06-19 13:41:45 +0100196 tensor_operations::box3x3(s, d, border_mode, constant_border_value);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197}
198
199// Depth conversion
200void ReferenceCPP::depth_convert(const RawTensor &src, RawTensor &dst, ConvertPolicy policy, uint32_t shift)
201{
202 const TensorVariant s = TensorFactory::get_tensor(src);
203 TensorVariant d = TensorFactory::get_tensor(dst);
204 boost::apply_visitor(tensor_visitors::depth_convert_visitor(policy, shift), s, d);
205}
206
SiCong Li5a536642017-06-19 14:47:05 +0100207// Gaussian3x3 filter
208void ReferenceCPP::gaussian3x3(const RawTensor &src, RawTensor &dst, BorderMode border_mode, uint8_t constant_border_value)
209{
210 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
211 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
212 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
213 tensor_operations::gaussian3x3(s, d, border_mode, constant_border_value);
214}
215
SiCong Li3eb263e2017-06-19 15:31:43 +0100216// Gaussian5x5 filter
217void ReferenceCPP::gaussian5x5(const RawTensor &src, RawTensor &dst, BorderMode border_mode, uint8_t constant_border_value)
218{
219 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
220 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
221 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
222 tensor_operations::gaussian5x5(s, d, border_mode, constant_border_value);
223}
224
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100225// GEMM
226void ReferenceCPP::gemm(const RawTensor &src1, const RawTensor &src2, const RawTensor &src3,
227 RawTensor &dst, float alpha, float beta)
228{
229 const TensorVariant s1 = TensorFactory::get_tensor(src1);
230 const TensorVariant s2 = TensorFactory::get_tensor(src2);
231 const TensorVariant s3 = TensorFactory::get_tensor(src3);
232 TensorVariant d = TensorFactory::get_tensor(dst);
233
234 boost::apply_visitor(tensor_visitors::gemm_visitor(s1, s2, s3, alpha, beta), d);
235}
Isabella Gottardi3b77e9d2017-06-22 11:05:41 +0100236// Non linear filter
237void ReferenceCPP::non_linear_filter(const RawTensor &src, RawTensor &dst, NonLinearFilterFunction function, unsigned int mask_size,
238 MatrixPattern pattern, const uint8_t *mask, BorderMode border_mode, uint8_t constant_border_value)
239{
240 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
241 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
242 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
243 tensor_operations::non_linear_filter(s, d, function, mask_size, pattern, mask, border_mode, constant_border_value);
244}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100245
246// Pixel-wise multiplication
247void ReferenceCPP::pixel_wise_multiplication(const RawTensor &src1, const RawTensor &src2, RawTensor &dst, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
248{
249 const TensorVariant s1 = TensorFactory::get_tensor(src1);
250 const TensorVariant s2 = TensorFactory::get_tensor(src2);
251 TensorVariant d = TensorFactory::get_tensor(dst);
252 boost::apply_visitor(pixel_wise_multiplication_visitor(scale, convert_policy, rounding_policy), s1, s2, d);
253}
254
255// Fixed-point Pixel-wise multiplication
256void ReferenceCPP::fixed_point_pixel_wise_multiplication(const RawTensor &src1, const RawTensor &src2, RawTensor &dst, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
257{
258 const TensorVariant s1 = TensorFactory::get_tensor(src1);
259 const TensorVariant s2 = TensorFactory::get_tensor(src2);
260 TensorVariant d = TensorFactory::get_tensor(dst);
261 boost::apply_visitor(tensor_visitors::fixed_point_pixel_wise_multiplication_visitor(s1, s2, scale, convert_policy, rounding_policy), d);
262}
263
Isabella Gottardib797fa22017-06-23 15:02:11 +0100264// Table lookup
265template <typename T>
266void ReferenceCPP::table_lookup(const RawTensor &src, RawTensor &dst, std::map<T, T> &lut)
267{
268 const TensorVariant s = TensorFactory::get_tensor(src);
269 TensorVariant d = TensorFactory::get_tensor(dst);
270 boost::apply_visitor(tensor_visitors::table_lookup<T>(s, lut), d);
271}
272#ifndef DOXYGEN_SKIP_THIS
273template void arm_compute::test::validation::ReferenceCPP::table_lookup<uint8_t>(const RawTensor &src, RawTensor &dst, std::map<uint8_t, uint8_t> &lut);
274template void arm_compute::test::validation::ReferenceCPP::table_lookup<int16_t>(const RawTensor &src, RawTensor &dst, std::map<int16_t, int16_t> &lut);
275#endif /* DOXYGEN_SKIP_THIS */
276
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100277// Threshold
278void ReferenceCPP::threshold(const RawTensor &src, RawTensor &dst, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper)
279{
280 ARM_COMPUTE_ERROR_ON(src.data_type() != DataType::U8 || dst.data_type() != DataType::U8);
281 const Tensor<uint8_t> s(src.shape(), src.data_type(), src.fixed_point_position(), reinterpret_cast<const uint8_t *>(src.data()));
282 Tensor<uint8_t> d(dst.shape(), dst.data_type(), dst.fixed_point_position(), reinterpret_cast<uint8_t *>(dst.data()));
Isabella Gottardib797fa22017-06-23 15:02:11 +0100283 tensor_operations::threshold(s, d, threshold, false_value, true_value, type, upper);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100284}
285
286// Activation layer
287void ReferenceCPP::activation_layer(const RawTensor &input, RawTensor &output, ActivationLayerInfo act_info)
288{
289 const TensorVariant s = TensorFactory::get_tensor(input);
290 TensorVariant d = TensorFactory::get_tensor(output);
291 boost::apply_visitor(tensor_visitors::activation_layer_visitor(s, act_info), d);
292}
293
294// Batch Normalization Layer
295void ReferenceCPP::batch_normalization_layer(const RawTensor &src, RawTensor &dst, const RawTensor &mean, const RawTensor &var, const RawTensor &beta, const RawTensor &gamma, float epsilon,
296 int fixed_point_position)
297{
298 const TensorVariant s = TensorFactory::get_tensor(src);
299 TensorVariant d = TensorFactory::get_tensor(dst);
300 const TensorVariant m = TensorFactory::get_tensor(mean);
301 const TensorVariant v = TensorFactory::get_tensor(var);
302 const TensorVariant b = TensorFactory::get_tensor(beta);
303 const TensorVariant g = TensorFactory::get_tensor(gamma);
304 boost::apply_visitor(tensor_visitors::batch_normalization_layer_visitor(s, m, v, b, g, epsilon, fixed_point_position), d);
305}
306
307// Convolution Layer
308void ReferenceCPP::convolution_layer(const RawTensor &src, const RawTensor &weights, const RawTensor &bias, RawTensor &dst, const PadStrideInfo &conv_info)
309{
310 const TensorVariant s = TensorFactory::get_tensor(src);
311 const TensorVariant w = TensorFactory::get_tensor(weights);
312 const TensorVariant b = TensorFactory::get_tensor(bias);
313 TensorVariant d = TensorFactory::get_tensor(dst);
314 boost::apply_visitor(tensor_visitors::convolution_layer_visitor(s, w, b, conv_info), d);
315}
316
Georgios Pinitasac4e8732017-07-05 17:02:25 +0100317// Depth concatenate layer
318void ReferenceCPP::depth_concatenate_layer(const std::vector<std::unique_ptr<RawTensor>> &srcs, RawTensor &dst)
319{
320 std::vector<TensorVariant> ss;
321 ss.resize(srcs.size());
322 std::transform(srcs.begin(), srcs.end(), ss.begin(), [](std::unique_ptr<RawTensor> const & t)
323 {
324 return TensorFactory::get_tensor(*t);
325 });
326 TensorVariant d = TensorFactory::get_tensor(dst);
327 boost::apply_visitor(tensor_visitors::depth_concatenate_layer_visitor(ss), d);
328}
329
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100330// Fully connected layer
331void ReferenceCPP::fully_connected_layer(const RawTensor &src, const RawTensor &weights, const RawTensor &bias, RawTensor &dst)
332{
333 const TensorVariant s = TensorFactory::get_tensor(src);
334 const TensorVariant w = TensorFactory::get_tensor(weights);
335 const TensorVariant b = TensorFactory::get_tensor(bias);
336 TensorVariant d = TensorFactory::get_tensor(dst);
337 boost::apply_visitor(tensor_visitors::fully_connected_layer_visitor(s, w, b), d);
338}
339
340// Normalization Layer
341void ReferenceCPP::normalization_layer(const RawTensor &src, RawTensor &dst, NormalizationLayerInfo norm_info)
342{
343 const TensorVariant s = TensorFactory::get_tensor(src);
344 TensorVariant d = TensorFactory::get_tensor(dst);
345 boost::apply_visitor(tensor_visitors::normalization_layer_visitor(s, norm_info), d);
346}
347
348// Pooling Layer
349void ReferenceCPP::pooling_layer(const RawTensor &src, RawTensor &dst, PoolingLayerInfo pool_info, int fixed_point_position)
350{
351 const TensorVariant s = TensorFactory::get_tensor(src);
352 TensorVariant d = TensorFactory::get_tensor(dst);
353 boost::apply_visitor(tensor_visitors::pooling_layer_visitor(s, pool_info, fixed_point_position), d);
354}
355
Georgios Pinitas7b7858d2017-06-21 16:44:24 +0100356// ROI Pooling Layer
357void ReferenceCPP::roi_pooling_layer(const RawTensor &src, RawTensor &dst, const std::vector<ROI> &rois, const ROIPoolingLayerInfo &pool_info)
358{
359 const TensorVariant s = TensorFactory::get_tensor(src);
360 TensorVariant d = TensorFactory::get_tensor(dst);
361 boost::apply_visitor(tensor_visitors::roi_pooling_layer_visitor(s, rois, pool_info), d);
362}
363
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100364// Softmax Layer
365void ReferenceCPP::softmax_layer(const RawTensor &src, RawTensor &dst)
366{
367 const TensorVariant s = TensorFactory::get_tensor(src);
368 TensorVariant d = TensorFactory::get_tensor(dst);
369 boost::apply_visitor(tensor_visitors::softmax_layer_visitor(s), d);
370}
371
372// Fixed point operation
373void ReferenceCPP::fixed_point_operation(const RawTensor &src, RawTensor &dst, FixedPointOp op)
374{
375 const TensorVariant s = TensorFactory::get_tensor(src);
376 TensorVariant d = TensorFactory::get_tensor(dst);
377 boost::apply_visitor(tensor_visitors::fixed_point_operation_visitor(s, op), d);
378}
379
380} // namespace validation
381} // namespace test
382} // namespace arm_compute