blob: 1b58a16e4c78a2bc13a6cce2f5a699563ab93582 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
25#define __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
26
27#include "Tensor.h"
28#include "TensorOperations.h"
29#include "arm_compute/core/Error.h"
30
31#include "boost_wrapper.h"
32
33#include <ostream>
34
35namespace arm_compute
36{
37namespace test
38{
39namespace validation
40{
41namespace tensor_visitors
42{
43// Absolute Difference visitor
44struct absolute_difference_visitor : public boost::static_visitor<>
45{
46public:
47 template <typename T1, typename T2, typename T3>
48 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
49 {
50 tensor_operations::absolute_difference(in1, in2, out);
51 }
52};
53// Arithmetic Addition visitor
54struct arithmetic_addition_visitor : public boost::static_visitor<>
55{
56public:
57 explicit arithmetic_addition_visitor(ConvertPolicy convert_policy)
58 : _policy(convert_policy)
59 {
60 }
61
62 template <typename T1, typename T2, typename T3>
63 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
64 {
65 tensor_operations::arithmetic_addition(in1, in2, out, _policy);
66 }
67
68private:
69 ConvertPolicy _policy;
70};
71// Arithmetic Subtraction visitor
72struct arithmetic_subtraction_visitor : public boost::static_visitor<>
73{
74public:
75 explicit arithmetic_subtraction_visitor(ConvertPolicy convert_policy)
76 : _policy(convert_policy)
77 {
78 }
79
80 template <typename T1, typename T2, typename T3>
81 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
82 {
83 tensor_operations::arithmetic_subtraction(in1, in2, out, _policy);
84 }
85
86private:
87 ConvertPolicy _policy;
88};
89// Depth Convert visitor
90struct depth_convert_visitor : public boost::static_visitor<>
91{
92public:
93 explicit depth_convert_visitor(ConvertPolicy policy, uint32_t shift)
94 : _policy(policy), _shift(shift)
95 {
96 }
97
98 template <typename T1, typename T2>
99 void operator()(const Tensor<T1> &in, Tensor<T2> &out) const
100 {
101 tensor_operations::depth_convert(in, out, _policy, _shift);
102 }
103
104private:
105 ConvertPolicy _policy;
106 uint32_t _shift;
107};
108// GEMM visitor
109struct gemm_visitor : public boost::static_visitor<>
110{
111public:
112 explicit gemm_visitor(const TensorVariant &in1, const TensorVariant &in2, const TensorVariant &in3, float alpha, float beta)
113 : _in1(in1), _in2(in2), _in3(in3), _alpha(alpha), _beta(beta)
114 {
115 }
116
117 template <typename T>
118 void operator()(Tensor<T> &out) const
119 {
120 const Tensor<T> &in1 = boost::get<Tensor<T>>(_in1);
121 const Tensor<T> &in2 = boost::get<Tensor<T>>(_in2);
122 const Tensor<T> &in3 = boost::get<Tensor<T>>(_in3);
123 tensor_operations::gemm(in1, in2, in3, out, _alpha, _beta);
124 }
125
126private:
127 const TensorVariant &_in1, &_in2, &_in3;
128 float _alpha;
129 float _beta;
130};
131// Pixel-wise Multiplication visitor
132struct pixel_wise_multiplication_visitor : public boost::static_visitor<>
133{
134public:
135 explicit pixel_wise_multiplication_visitor(float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
136 : _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
137 {
138 }
139
140 template <typename T1, typename T2, typename T3>
141 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
142 {
143 tensor_operations::pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
144 }
145
146private:
147 float _scale;
148 ConvertPolicy _convert_policy;
149 RoundingPolicy _rounding_policy;
150};
151// Fixed Point Pixel-wise Multiplication visitor
152struct fixed_point_pixel_wise_multiplication_visitor : public boost::static_visitor<>
153{
154public:
155 explicit fixed_point_pixel_wise_multiplication_visitor(const TensorVariant &in1, const TensorVariant &in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
156 : _in1(in1), _in2(in2), _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
157 {
158 }
159
160 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
161 void operator()(Tensor<T> &out) const
162 {
163 const Tensor<T> &in1 = boost::get<Tensor<T>>(_in1);
164 const Tensor<T> &in2 = boost::get<Tensor<T>>(_in2);
165 tensor_operations::fixed_point_pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
166 }
167 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
168 void operator()(Tensor<T> &out) const
169 {
170 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
171 }
172
173private:
174 const TensorVariant &_in1;
175 const TensorVariant &_in2;
176 float _scale;
177 ConvertPolicy _convert_policy;
178 RoundingPolicy _rounding_policy;
179};
180// Threshold operation
181void threshold_operation(const Tensor<uint8_t> &in, Tensor<uint8_t> &out, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper)
182{
183 tensor_operations::threshold(in, out, threshold, false_value, true_value, type, upper);
184}
185// Activation layer visitor
186struct activation_layer_visitor : public boost::static_visitor<>
187{
188public:
189 explicit activation_layer_visitor(const TensorVariant &in, ActivationLayerInfo act_info)
190 : _in(in), _act_info(act_info)
191 {
192 }
193
194 template <typename T>
195 void operator()(Tensor<T> &out) const
196 {
197 const auto &in = boost::get<Tensor<T>>(_in);
198 tensor_operations::activation_layer(in, out, _act_info);
199 }
200
201private:
202 const TensorVariant &_in;
203 const ActivationLayerInfo _act_info;
204};
205// Batch Normalization Layer visitor
206struct batch_normalization_layer_visitor : public boost::static_visitor<>
207{
208public:
209 explicit batch_normalization_layer_visitor(const TensorVariant &in, const TensorVariant &mean, const TensorVariant &var, const TensorVariant &beta, const TensorVariant &gamma, float epsilon,
210 int fixed_point_position = 0)
211 : _in(in), _mean(mean), _var(var), _beta(beta), _gamma(gamma), _epsilon(epsilon), _fixed_point_position(fixed_point_position)
212 {
213 }
214
215 template <typename T>
216 void operator()(Tensor<T> &out) const
217 {
218 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
219 const Tensor<T> &mean = boost::get<Tensor<T>>(_mean);
220 const Tensor<T> &var = boost::get<Tensor<T>>(_var);
221 const Tensor<T> &beta = boost::get<Tensor<T>>(_beta);
222 const Tensor<T> &gamma = boost::get<Tensor<T>>(_gamma);
223 tensor_operations::batch_normalization_layer(in, out, mean, var, beta, gamma, _epsilon, _fixed_point_position);
224 }
225
226private:
227 const TensorVariant &_in, &_mean, &_var, &_beta, &_gamma;
228 float _epsilon;
229 int _fixed_point_position;
230};
231// Convolution Layer visitor
232struct convolution_layer_visitor : public boost::static_visitor<>
233{
234public:
235 explicit convolution_layer_visitor(const TensorVariant &in, const TensorVariant &weights, const TensorVariant &bias, PadStrideInfo conv_info)
236 : _in(in), _weights(weights), _bias(bias), _conv_info(conv_info)
237 {
238 }
239
240 template <typename T>
241 void operator()(Tensor<T> &out) const
242 {
243 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
244 const Tensor<T> &weights = boost::get<Tensor<T>>(_weights);
245 const Tensor<T> &bias = boost::get<Tensor<T>>(_bias);
246 tensor_operations::convolution_layer(in, weights, bias, out, _conv_info);
247 }
248
249private:
250 const TensorVariant &_in;
251 const TensorVariant &_weights;
252 const TensorVariant &_bias;
253 PadStrideInfo _conv_info;
254};
255
256struct fully_connected_layer_visitor : public boost::static_visitor<>
257{
258public:
259 explicit fully_connected_layer_visitor(const TensorVariant &in, const TensorVariant &weights, const TensorVariant &bias)
260 : _in(in), _weights(weights), _bias(bias)
261 {
262 }
263 template <typename T>
264 void operator()(Tensor<T> &out) const
265 {
266 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
267 const Tensor<T> &weights = boost::get<Tensor<T>>(_weights);
268 const Tensor<T> &bias = boost::get<Tensor<T>>(_bias);
269 tensor_operations::fully_connected_layer(in, weights, bias, out);
270 }
271
272private:
273 const TensorVariant &_in;
274 const TensorVariant &_weights;
275 const TensorVariant &_bias;
276};
277
278// Normalization Layer visitor
279struct normalization_layer_visitor : public boost::static_visitor<>
280{
281public:
282 explicit normalization_layer_visitor(const TensorVariant &in, NormalizationLayerInfo norm_info)
283 : _in(in), _norm_info(norm_info)
284 {
285 }
286
287 template <typename T>
288 void operator()(Tensor<T> &out) const
289 {
290 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
291 tensor_operations::normalization_layer(in, out, _norm_info);
292 }
293
294private:
295 const TensorVariant &_in;
296 NormalizationLayerInfo _norm_info;
297};
298// Pooling layer
299struct pooling_layer_visitor : public boost::static_visitor<>
300{
301public:
302 explicit pooling_layer_visitor(const TensorVariant &in, PoolingLayerInfo pool_info, int fixed_point_position = 0)
303 : _in(in), _pool_info(pool_info), _fixed_point_position(fixed_point_position)
304 {
305 }
306
307 template <typename T>
308 void operator()(Tensor<T> &out) const
309 {
310 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
311 tensor_operations::pooling_layer(in, out, _pool_info, _fixed_point_position);
312 }
313
314private:
315 const TensorVariant &_in;
316 PoolingLayerInfo _pool_info;
317 int _fixed_point_position;
318};
Georgios Pinitas7b7858d2017-06-21 16:44:24 +0100319
320// ROI Pooling layer
321struct roi_pooling_layer_visitor : public boost::static_visitor<>
322{
323public:
324 explicit roi_pooling_layer_visitor(const TensorVariant &in, const std::vector<ROI> &rois, ROIPoolingLayerInfo pool_info)
325 : _in(in), _rois(rois), _pool_info(pool_info)
326 {
327 }
328
329 template <typename T>
330 void operator()(Tensor<T> &out) const
331 {
332 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
333 tensor_operations::roi_pooling_layer(in, out, _rois, _pool_info);
334 }
335
336private:
337 const TensorVariant &_in;
338 const std::vector<ROI> &_rois;
339 ROIPoolingLayerInfo _pool_info;
340};
341
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100342// Softmax Layer visitor
343struct softmax_layer_visitor : public boost::static_visitor<>
344{
345public:
346 explicit softmax_layer_visitor(const TensorVariant &in)
347 : _in(in)
348 {
349 }
350
351 template <typename T>
352 void operator()(Tensor<T> &out) const
353 {
354 const auto &in = boost::get<Tensor<T>>(_in);
355 tensor_operations::softmax_layer(in, out);
356 }
357
358private:
359 const TensorVariant &_in;
360};
361// Fixed Point operations visitor
362struct fixed_point_operation_visitor : public boost::static_visitor<>
363{
364public:
365 explicit fixed_point_operation_visitor(const TensorVariant &in, FixedPointOp op)
366 : _in(in), _op(op)
367 {
368 }
369
370 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
371 void operator()(Tensor<T> &out) const
372 {
373 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
374 tensor_operations::fixed_point_operation(in, out, _op);
375 }
376 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
377 void operator()(Tensor<T> &out) const
378 {
379 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
380 }
381
382private:
383 const TensorVariant &_in;
384 FixedPointOp _op;
385};
386// Print Tensor visitor
387struct print_visitor : public boost::static_visitor<>
388{
389public:
390 explicit print_visitor(std::ostream &out)
391 : _out(out)
392 {
393 }
394
395 template <typename T>
396 void operator()(const Tensor<T> &in) const
397 {
398 tensor_operations::print(in, _out);
399 }
400
401private:
402 std::ostream &_out;
403};
404} // namespace tensor_visitors
405} // namespace validation
406} // namespace test
407} // namespace arm_compute
408
409#endif /* __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__ */