blob: fcc584dd46bdf6a4839d9694a85732c19c982b3f [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
25#define __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
26
27#include "Tensor.h"
28#include "TensorOperations.h"
29#include "arm_compute/core/Error.h"
30
31#include "boost_wrapper.h"
32
Georgios Pinitasac4e8732017-07-05 17:02:25 +010033#include <algorithm>
34#include <memory>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035#include <ostream>
Georgios Pinitasd4f8c272017-06-30 16:16:19 +010036#include <vector>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010037
38namespace arm_compute
39{
40namespace test
41{
42namespace validation
43{
44namespace tensor_visitors
45{
46// Absolute Difference visitor
47struct absolute_difference_visitor : public boost::static_visitor<>
48{
49public:
50 template <typename T1, typename T2, typename T3>
51 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
52 {
53 tensor_operations::absolute_difference(in1, in2, out);
54 }
55};
56// Arithmetic Addition visitor
57struct arithmetic_addition_visitor : public boost::static_visitor<>
58{
59public:
60 explicit arithmetic_addition_visitor(ConvertPolicy convert_policy)
61 : _policy(convert_policy)
62 {
63 }
64
65 template <typename T1, typename T2, typename T3>
66 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
67 {
68 tensor_operations::arithmetic_addition(in1, in2, out, _policy);
69 }
70
71private:
72 ConvertPolicy _policy;
73};
74// Arithmetic Subtraction visitor
75struct arithmetic_subtraction_visitor : public boost::static_visitor<>
76{
77public:
78 explicit arithmetic_subtraction_visitor(ConvertPolicy convert_policy)
79 : _policy(convert_policy)
80 {
81 }
82
83 template <typename T1, typename T2, typename T3>
84 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
85 {
86 tensor_operations::arithmetic_subtraction(in1, in2, out, _policy);
87 }
88
89private:
90 ConvertPolicy _policy;
91};
92// Depth Convert visitor
93struct depth_convert_visitor : public boost::static_visitor<>
94{
95public:
96 explicit depth_convert_visitor(ConvertPolicy policy, uint32_t shift)
97 : _policy(policy), _shift(shift)
98 {
99 }
100
101 template <typename T1, typename T2>
102 void operator()(const Tensor<T1> &in, Tensor<T2> &out) const
103 {
104 tensor_operations::depth_convert(in, out, _policy, _shift);
105 }
106
107private:
108 ConvertPolicy _policy;
109 uint32_t _shift;
110};
111// GEMM visitor
112struct gemm_visitor : public boost::static_visitor<>
113{
114public:
115 explicit gemm_visitor(const TensorVariant &in1, const TensorVariant &in2, const TensorVariant &in3, float alpha, float beta)
116 : _in1(in1), _in2(in2), _in3(in3), _alpha(alpha), _beta(beta)
117 {
118 }
119
120 template <typename T>
121 void operator()(Tensor<T> &out) const
122 {
123 const Tensor<T> &in1 = boost::get<Tensor<T>>(_in1);
124 const Tensor<T> &in2 = boost::get<Tensor<T>>(_in2);
125 const Tensor<T> &in3 = boost::get<Tensor<T>>(_in3);
126 tensor_operations::gemm(in1, in2, in3, out, _alpha, _beta);
127 }
128
129private:
130 const TensorVariant &_in1, &_in2, &_in3;
131 float _alpha;
132 float _beta;
133};
134// Pixel-wise Multiplication visitor
135struct pixel_wise_multiplication_visitor : public boost::static_visitor<>
136{
137public:
138 explicit pixel_wise_multiplication_visitor(float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
139 : _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
140 {
141 }
142
143 template <typename T1, typename T2, typename T3>
144 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
145 {
146 tensor_operations::pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
147 }
148
149private:
150 float _scale;
151 ConvertPolicy _convert_policy;
152 RoundingPolicy _rounding_policy;
153};
154// Fixed Point Pixel-wise Multiplication visitor
155struct fixed_point_pixel_wise_multiplication_visitor : public boost::static_visitor<>
156{
157public:
158 explicit fixed_point_pixel_wise_multiplication_visitor(const TensorVariant &in1, const TensorVariant &in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
159 : _in1(in1), _in2(in2), _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
160 {
161 }
162
163 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
164 void operator()(Tensor<T> &out) const
165 {
166 const Tensor<T> &in1 = boost::get<Tensor<T>>(_in1);
167 const Tensor<T> &in2 = boost::get<Tensor<T>>(_in2);
168 tensor_operations::fixed_point_pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
169 }
170 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
171 void operator()(Tensor<T> &out) const
172 {
173 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
174 }
175
176private:
177 const TensorVariant &_in1;
178 const TensorVariant &_in2;
179 float _scale;
180 ConvertPolicy _convert_policy;
181 RoundingPolicy _rounding_policy;
182};
183// Threshold operation
184void threshold_operation(const Tensor<uint8_t> &in, Tensor<uint8_t> &out, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper)
185{
186 tensor_operations::threshold(in, out, threshold, false_value, true_value, type, upper);
187}
188// Activation layer visitor
189struct activation_layer_visitor : public boost::static_visitor<>
190{
191public:
192 explicit activation_layer_visitor(const TensorVariant &in, ActivationLayerInfo act_info)
193 : _in(in), _act_info(act_info)
194 {
195 }
196
197 template <typename T>
198 void operator()(Tensor<T> &out) const
199 {
200 const auto &in = boost::get<Tensor<T>>(_in);
201 tensor_operations::activation_layer(in, out, _act_info);
202 }
203
204private:
205 const TensorVariant &_in;
206 const ActivationLayerInfo _act_info;
207};
208// Batch Normalization Layer visitor
209struct batch_normalization_layer_visitor : public boost::static_visitor<>
210{
211public:
212 explicit batch_normalization_layer_visitor(const TensorVariant &in, const TensorVariant &mean, const TensorVariant &var, const TensorVariant &beta, const TensorVariant &gamma, float epsilon,
213 int fixed_point_position = 0)
214 : _in(in), _mean(mean), _var(var), _beta(beta), _gamma(gamma), _epsilon(epsilon), _fixed_point_position(fixed_point_position)
215 {
216 }
217
218 template <typename T>
219 void operator()(Tensor<T> &out) const
220 {
221 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
222 const Tensor<T> &mean = boost::get<Tensor<T>>(_mean);
223 const Tensor<T> &var = boost::get<Tensor<T>>(_var);
224 const Tensor<T> &beta = boost::get<Tensor<T>>(_beta);
225 const Tensor<T> &gamma = boost::get<Tensor<T>>(_gamma);
226 tensor_operations::batch_normalization_layer(in, out, mean, var, beta, gamma, _epsilon, _fixed_point_position);
227 }
228
229private:
230 const TensorVariant &_in, &_mean, &_var, &_beta, &_gamma;
231 float _epsilon;
232 int _fixed_point_position;
233};
234// Convolution Layer visitor
235struct convolution_layer_visitor : public boost::static_visitor<>
236{
237public:
238 explicit convolution_layer_visitor(const TensorVariant &in, const TensorVariant &weights, const TensorVariant &bias, PadStrideInfo conv_info)
239 : _in(in), _weights(weights), _bias(bias), _conv_info(conv_info)
240 {
241 }
242
243 template <typename T>
244 void operator()(Tensor<T> &out) const
245 {
246 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
247 const Tensor<T> &weights = boost::get<Tensor<T>>(_weights);
248 const Tensor<T> &bias = boost::get<Tensor<T>>(_bias);
249 tensor_operations::convolution_layer(in, weights, bias, out, _conv_info);
250 }
251
252private:
253 const TensorVariant &_in;
254 const TensorVariant &_weights;
255 const TensorVariant &_bias;
256 PadStrideInfo _conv_info;
257};
Georgios Pinitasac4e8732017-07-05 17:02:25 +0100258// Depth Concatenate Layer visitor
259struct depth_concatenate_layer_visitor : public boost::static_visitor<>
260{
261public:
262 explicit depth_concatenate_layer_visitor(const std::vector<TensorVariant> &srcs)
263 : _srcs(srcs)
264 {
265 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100266
Georgios Pinitasac4e8732017-07-05 17:02:25 +0100267 template <typename T>
268 void operator()(Tensor<T> &out) const
269 {
270 std::vector<const Tensor<T> *> srcs;
271 srcs.resize(_srcs.size());
272 std::transform(_srcs.begin(), _srcs.end(), srcs.begin(), [](const TensorVariant & t)
273 {
274 return &(boost::get<Tensor<T>>(t));
275 });
276 tensor_operations::depth_concatenate_layer(srcs, out);
277 }
278
279private:
280 const std::vector<TensorVariant> &_srcs;
281};
282// Fully Connected Layer visitor
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100283struct fully_connected_layer_visitor : public boost::static_visitor<>
284{
285public:
286 explicit fully_connected_layer_visitor(const TensorVariant &in, const TensorVariant &weights, const TensorVariant &bias)
287 : _in(in), _weights(weights), _bias(bias)
288 {
289 }
290 template <typename T>
291 void operator()(Tensor<T> &out) const
292 {
293 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
294 const Tensor<T> &weights = boost::get<Tensor<T>>(_weights);
295 const Tensor<T> &bias = boost::get<Tensor<T>>(_bias);
296 tensor_operations::fully_connected_layer(in, weights, bias, out);
297 }
298
299private:
300 const TensorVariant &_in;
301 const TensorVariant &_weights;
302 const TensorVariant &_bias;
303};
304
305// Normalization Layer visitor
306struct normalization_layer_visitor : public boost::static_visitor<>
307{
308public:
309 explicit normalization_layer_visitor(const TensorVariant &in, NormalizationLayerInfo norm_info)
310 : _in(in), _norm_info(norm_info)
311 {
312 }
313
314 template <typename T>
315 void operator()(Tensor<T> &out) const
316 {
317 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
318 tensor_operations::normalization_layer(in, out, _norm_info);
319 }
320
321private:
322 const TensorVariant &_in;
323 NormalizationLayerInfo _norm_info;
324};
325// Pooling layer
326struct pooling_layer_visitor : public boost::static_visitor<>
327{
328public:
329 explicit pooling_layer_visitor(const TensorVariant &in, PoolingLayerInfo pool_info, int fixed_point_position = 0)
330 : _in(in), _pool_info(pool_info), _fixed_point_position(fixed_point_position)
331 {
332 }
333
334 template <typename T>
335 void operator()(Tensor<T> &out) const
336 {
337 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
338 tensor_operations::pooling_layer(in, out, _pool_info, _fixed_point_position);
339 }
340
341private:
342 const TensorVariant &_in;
343 PoolingLayerInfo _pool_info;
344 int _fixed_point_position;
345};
Georgios Pinitas7b7858d2017-06-21 16:44:24 +0100346
347// ROI Pooling layer
348struct roi_pooling_layer_visitor : public boost::static_visitor<>
349{
350public:
351 explicit roi_pooling_layer_visitor(const TensorVariant &in, const std::vector<ROI> &rois, ROIPoolingLayerInfo pool_info)
352 : _in(in), _rois(rois), _pool_info(pool_info)
353 {
354 }
355
356 template <typename T>
357 void operator()(Tensor<T> &out) const
358 {
359 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
360 tensor_operations::roi_pooling_layer(in, out, _rois, _pool_info);
361 }
362
363private:
364 const TensorVariant &_in;
365 const std::vector<ROI> &_rois;
366 ROIPoolingLayerInfo _pool_info;
367};
368
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100369// Softmax Layer visitor
370struct softmax_layer_visitor : public boost::static_visitor<>
371{
372public:
373 explicit softmax_layer_visitor(const TensorVariant &in)
374 : _in(in)
375 {
376 }
377
378 template <typename T>
379 void operator()(Tensor<T> &out) const
380 {
381 const auto &in = boost::get<Tensor<T>>(_in);
382 tensor_operations::softmax_layer(in, out);
383 }
384
385private:
386 const TensorVariant &_in;
387};
388// Fixed Point operations visitor
389struct fixed_point_operation_visitor : public boost::static_visitor<>
390{
391public:
392 explicit fixed_point_operation_visitor(const TensorVariant &in, FixedPointOp op)
393 : _in(in), _op(op)
394 {
395 }
396
397 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
398 void operator()(Tensor<T> &out) const
399 {
400 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
401 tensor_operations::fixed_point_operation(in, out, _op);
402 }
403 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
404 void operator()(Tensor<T> &out) const
405 {
406 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
407 }
408
409private:
410 const TensorVariant &_in;
411 FixedPointOp _op;
412};
413// Print Tensor visitor
414struct print_visitor : public boost::static_visitor<>
415{
416public:
417 explicit print_visitor(std::ostream &out)
418 : _out(out)
419 {
420 }
421
422 template <typename T>
423 void operator()(const Tensor<T> &in) const
424 {
425 tensor_operations::print(in, _out);
426 }
427
428private:
429 std::ostream &_out;
430};
431} // namespace tensor_visitors
432} // namespace validation
433} // namespace test
434} // namespace arm_compute
435
436#endif /* __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__ */