blob: 723302c973a652e7786f0d7d788cd9365583be20 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
25#define __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
26
27#include "Tensor.h"
28#include "TensorOperations.h"
29#include "arm_compute/core/Error.h"
30
31#include "boost_wrapper.h"
32
33#include <ostream>
Georgios Pinitasd4f8c272017-06-30 16:16:19 +010034#include <vector>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035
36namespace arm_compute
37{
38namespace test
39{
40namespace validation
41{
42namespace tensor_visitors
43{
44// Absolute Difference visitor
45struct absolute_difference_visitor : public boost::static_visitor<>
46{
47public:
48 template <typename T1, typename T2, typename T3>
49 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
50 {
51 tensor_operations::absolute_difference(in1, in2, out);
52 }
53};
54// Arithmetic Addition visitor
55struct arithmetic_addition_visitor : public boost::static_visitor<>
56{
57public:
58 explicit arithmetic_addition_visitor(ConvertPolicy convert_policy)
59 : _policy(convert_policy)
60 {
61 }
62
63 template <typename T1, typename T2, typename T3>
64 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
65 {
66 tensor_operations::arithmetic_addition(in1, in2, out, _policy);
67 }
68
69private:
70 ConvertPolicy _policy;
71};
72// Arithmetic Subtraction visitor
73struct arithmetic_subtraction_visitor : public boost::static_visitor<>
74{
75public:
76 explicit arithmetic_subtraction_visitor(ConvertPolicy convert_policy)
77 : _policy(convert_policy)
78 {
79 }
80
81 template <typename T1, typename T2, typename T3>
82 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
83 {
84 tensor_operations::arithmetic_subtraction(in1, in2, out, _policy);
85 }
86
87private:
88 ConvertPolicy _policy;
89};
90// Depth Convert visitor
91struct depth_convert_visitor : public boost::static_visitor<>
92{
93public:
94 explicit depth_convert_visitor(ConvertPolicy policy, uint32_t shift)
95 : _policy(policy), _shift(shift)
96 {
97 }
98
99 template <typename T1, typename T2>
100 void operator()(const Tensor<T1> &in, Tensor<T2> &out) const
101 {
102 tensor_operations::depth_convert(in, out, _policy, _shift);
103 }
104
105private:
106 ConvertPolicy _policy;
107 uint32_t _shift;
108};
109// GEMM visitor
110struct gemm_visitor : public boost::static_visitor<>
111{
112public:
113 explicit gemm_visitor(const TensorVariant &in1, const TensorVariant &in2, const TensorVariant &in3, float alpha, float beta)
114 : _in1(in1), _in2(in2), _in3(in3), _alpha(alpha), _beta(beta)
115 {
116 }
117
118 template <typename T>
119 void operator()(Tensor<T> &out) const
120 {
121 const Tensor<T> &in1 = boost::get<Tensor<T>>(_in1);
122 const Tensor<T> &in2 = boost::get<Tensor<T>>(_in2);
123 const Tensor<T> &in3 = boost::get<Tensor<T>>(_in3);
124 tensor_operations::gemm(in1, in2, in3, out, _alpha, _beta);
125 }
126
127private:
128 const TensorVariant &_in1, &_in2, &_in3;
129 float _alpha;
130 float _beta;
131};
132// Pixel-wise Multiplication visitor
133struct pixel_wise_multiplication_visitor : public boost::static_visitor<>
134{
135public:
136 explicit pixel_wise_multiplication_visitor(float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
137 : _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
138 {
139 }
140
141 template <typename T1, typename T2, typename T3>
142 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
143 {
144 tensor_operations::pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
145 }
146
147private:
148 float _scale;
149 ConvertPolicy _convert_policy;
150 RoundingPolicy _rounding_policy;
151};
152// Fixed Point Pixel-wise Multiplication visitor
153struct fixed_point_pixel_wise_multiplication_visitor : public boost::static_visitor<>
154{
155public:
156 explicit fixed_point_pixel_wise_multiplication_visitor(const TensorVariant &in1, const TensorVariant &in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
157 : _in1(in1), _in2(in2), _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
158 {
159 }
160
161 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
162 void operator()(Tensor<T> &out) const
163 {
164 const Tensor<T> &in1 = boost::get<Tensor<T>>(_in1);
165 const Tensor<T> &in2 = boost::get<Tensor<T>>(_in2);
166 tensor_operations::fixed_point_pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
167 }
168 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
169 void operator()(Tensor<T> &out) const
170 {
171 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
172 }
173
174private:
175 const TensorVariant &_in1;
176 const TensorVariant &_in2;
177 float _scale;
178 ConvertPolicy _convert_policy;
179 RoundingPolicy _rounding_policy;
180};
181// Threshold operation
182void threshold_operation(const Tensor<uint8_t> &in, Tensor<uint8_t> &out, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper)
183{
184 tensor_operations::threshold(in, out, threshold, false_value, true_value, type, upper);
185}
186// Activation layer visitor
187struct activation_layer_visitor : public boost::static_visitor<>
188{
189public:
190 explicit activation_layer_visitor(const TensorVariant &in, ActivationLayerInfo act_info)
191 : _in(in), _act_info(act_info)
192 {
193 }
194
195 template <typename T>
196 void operator()(Tensor<T> &out) const
197 {
198 const auto &in = boost::get<Tensor<T>>(_in);
199 tensor_operations::activation_layer(in, out, _act_info);
200 }
201
202private:
203 const TensorVariant &_in;
204 const ActivationLayerInfo _act_info;
205};
206// Batch Normalization Layer visitor
207struct batch_normalization_layer_visitor : public boost::static_visitor<>
208{
209public:
210 explicit batch_normalization_layer_visitor(const TensorVariant &in, const TensorVariant &mean, const TensorVariant &var, const TensorVariant &beta, const TensorVariant &gamma, float epsilon,
211 int fixed_point_position = 0)
212 : _in(in), _mean(mean), _var(var), _beta(beta), _gamma(gamma), _epsilon(epsilon), _fixed_point_position(fixed_point_position)
213 {
214 }
215
216 template <typename T>
217 void operator()(Tensor<T> &out) const
218 {
219 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
220 const Tensor<T> &mean = boost::get<Tensor<T>>(_mean);
221 const Tensor<T> &var = boost::get<Tensor<T>>(_var);
222 const Tensor<T> &beta = boost::get<Tensor<T>>(_beta);
223 const Tensor<T> &gamma = boost::get<Tensor<T>>(_gamma);
224 tensor_operations::batch_normalization_layer(in, out, mean, var, beta, gamma, _epsilon, _fixed_point_position);
225 }
226
227private:
228 const TensorVariant &_in, &_mean, &_var, &_beta, &_gamma;
229 float _epsilon;
230 int _fixed_point_position;
231};
232// Convolution Layer visitor
233struct convolution_layer_visitor : public boost::static_visitor<>
234{
235public:
236 explicit convolution_layer_visitor(const TensorVariant &in, const TensorVariant &weights, const TensorVariant &bias, PadStrideInfo conv_info)
237 : _in(in), _weights(weights), _bias(bias), _conv_info(conv_info)
238 {
239 }
240
241 template <typename T>
242 void operator()(Tensor<T> &out) const
243 {
244 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
245 const Tensor<T> &weights = boost::get<Tensor<T>>(_weights);
246 const Tensor<T> &bias = boost::get<Tensor<T>>(_bias);
247 tensor_operations::convolution_layer(in, weights, bias, out, _conv_info);
248 }
249
250private:
251 const TensorVariant &_in;
252 const TensorVariant &_weights;
253 const TensorVariant &_bias;
254 PadStrideInfo _conv_info;
255};
256
257struct fully_connected_layer_visitor : public boost::static_visitor<>
258{
259public:
260 explicit fully_connected_layer_visitor(const TensorVariant &in, const TensorVariant &weights, const TensorVariant &bias)
261 : _in(in), _weights(weights), _bias(bias)
262 {
263 }
264 template <typename T>
265 void operator()(Tensor<T> &out) const
266 {
267 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
268 const Tensor<T> &weights = boost::get<Tensor<T>>(_weights);
269 const Tensor<T> &bias = boost::get<Tensor<T>>(_bias);
270 tensor_operations::fully_connected_layer(in, weights, bias, out);
271 }
272
273private:
274 const TensorVariant &_in;
275 const TensorVariant &_weights;
276 const TensorVariant &_bias;
277};
278
279// Normalization Layer visitor
280struct normalization_layer_visitor : public boost::static_visitor<>
281{
282public:
283 explicit normalization_layer_visitor(const TensorVariant &in, NormalizationLayerInfo norm_info)
284 : _in(in), _norm_info(norm_info)
285 {
286 }
287
288 template <typename T>
289 void operator()(Tensor<T> &out) const
290 {
291 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
292 tensor_operations::normalization_layer(in, out, _norm_info);
293 }
294
295private:
296 const TensorVariant &_in;
297 NormalizationLayerInfo _norm_info;
298};
299// Pooling layer
300struct pooling_layer_visitor : public boost::static_visitor<>
301{
302public:
303 explicit pooling_layer_visitor(const TensorVariant &in, PoolingLayerInfo pool_info, int fixed_point_position = 0)
304 : _in(in), _pool_info(pool_info), _fixed_point_position(fixed_point_position)
305 {
306 }
307
308 template <typename T>
309 void operator()(Tensor<T> &out) const
310 {
311 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
312 tensor_operations::pooling_layer(in, out, _pool_info, _fixed_point_position);
313 }
314
315private:
316 const TensorVariant &_in;
317 PoolingLayerInfo _pool_info;
318 int _fixed_point_position;
319};
Georgios Pinitas7b7858d2017-06-21 16:44:24 +0100320
321// ROI Pooling layer
322struct roi_pooling_layer_visitor : public boost::static_visitor<>
323{
324public:
325 explicit roi_pooling_layer_visitor(const TensorVariant &in, const std::vector<ROI> &rois, ROIPoolingLayerInfo pool_info)
326 : _in(in), _rois(rois), _pool_info(pool_info)
327 {
328 }
329
330 template <typename T>
331 void operator()(Tensor<T> &out) const
332 {
333 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
334 tensor_operations::roi_pooling_layer(in, out, _rois, _pool_info);
335 }
336
337private:
338 const TensorVariant &_in;
339 const std::vector<ROI> &_rois;
340 ROIPoolingLayerInfo _pool_info;
341};
342
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100343// Softmax Layer visitor
344struct softmax_layer_visitor : public boost::static_visitor<>
345{
346public:
347 explicit softmax_layer_visitor(const TensorVariant &in)
348 : _in(in)
349 {
350 }
351
352 template <typename T>
353 void operator()(Tensor<T> &out) const
354 {
355 const auto &in = boost::get<Tensor<T>>(_in);
356 tensor_operations::softmax_layer(in, out);
357 }
358
359private:
360 const TensorVariant &_in;
361};
362// Fixed Point operations visitor
363struct fixed_point_operation_visitor : public boost::static_visitor<>
364{
365public:
366 explicit fixed_point_operation_visitor(const TensorVariant &in, FixedPointOp op)
367 : _in(in), _op(op)
368 {
369 }
370
371 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
372 void operator()(Tensor<T> &out) const
373 {
374 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
375 tensor_operations::fixed_point_operation(in, out, _op);
376 }
377 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
378 void operator()(Tensor<T> &out) const
379 {
380 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
381 }
382
383private:
384 const TensorVariant &_in;
385 FixedPointOp _op;
386};
387// Print Tensor visitor
388struct print_visitor : public boost::static_visitor<>
389{
390public:
391 explicit print_visitor(std::ostream &out)
392 : _out(out)
393 {
394 }
395
396 template <typename T>
397 void operator()(const Tensor<T> &in) const
398 {
399 tensor_operations::print(in, _out);
400 }
401
402private:
403 std::ostream &_out;
404};
405} // namespace tensor_visitors
406} // namespace validation
407} // namespace test
408} // namespace arm_compute
409
410#endif /* __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__ */