blob: 30b552ae3cce200b7a88604cd01e1322205f3bf3 [file] [log] [blame]
/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
25#define __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__
26
27#include "Tensor.h"
28#include "TensorOperations.h"
29#include "arm_compute/core/Error.h"
Michalis Spyroubbd9fb92017-06-22 12:57:51 +010030#include "arm_compute/core/Helpers.h"
Isabella Gottardib797fa22017-06-23 15:02:11 +010031#include "arm_compute/runtime/Lut.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010033#include "tests/validation_old/boost_wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034
Georgios Pinitasac4e8732017-07-05 17:02:25 +010035#include <algorithm>
Isabella Gottardib797fa22017-06-23 15:02:11 +010036#include <map>
Georgios Pinitasac4e8732017-07-05 17:02:25 +010037#include <memory>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038#include <ostream>
Georgios Pinitasd4f8c272017-06-30 16:16:19 +010039#include <vector>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040
41namespace arm_compute
42{
43namespace test
44{
45namespace validation
46{
47namespace tensor_visitors
48{
49// Absolute Difference visitor
50struct absolute_difference_visitor : public boost::static_visitor<>
51{
52public:
53 template <typename T1, typename T2, typename T3>
54 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
55 {
56 tensor_operations::absolute_difference(in1, in2, out);
57 }
58};
Anthony Barbier6ff3b192017-09-04 18:44:23 +010059// Pixel-wise Multiplication visitor
60struct pixel_wise_multiplication_visitor : public boost::static_visitor<>
61{
62public:
63 explicit pixel_wise_multiplication_visitor(float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
64 : _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
65 {
66 }
67
68 template <typename T1, typename T2, typename T3>
69 void operator()(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out) const
70 {
71 tensor_operations::pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
72 }
73
74private:
75 float _scale;
76 ConvertPolicy _convert_policy;
77 RoundingPolicy _rounding_policy;
78};
79// Fixed Point Pixel-wise Multiplication visitor
80struct fixed_point_pixel_wise_multiplication_visitor : public boost::static_visitor<>
81{
82public:
83 explicit fixed_point_pixel_wise_multiplication_visitor(const TensorVariant &in1, const TensorVariant &in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
84 : _in1(in1), _in2(in2), _scale(scale), _convert_policy(convert_policy), _rounding_policy(rounding_policy)
85 {
86 }
87
88 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
89 void operator()(Tensor<T> &out) const
90 {
91 const Tensor<T> &in1 = boost::get<Tensor<T>>(_in1);
92 const Tensor<T> &in2 = boost::get<Tensor<T>>(_in2);
93 tensor_operations::fixed_point_pixel_wise_multiplication(in1, in2, out, _scale, _convert_policy, _rounding_policy);
94 }
95 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
96 void operator()(Tensor<T> &out) const
97 {
98 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
99 }
100
101private:
102 const TensorVariant &_in1;
103 const TensorVariant &_in2;
104 float _scale;
105 ConvertPolicy _convert_policy;
106 RoundingPolicy _rounding_policy;
107};
Isabella Gottardib797fa22017-06-23 15:02:11 +0100108
Georgios Pinitas7b7858d2017-06-21 16:44:24 +0100109// ROI Pooling layer
110struct roi_pooling_layer_visitor : public boost::static_visitor<>
111{
112public:
113 explicit roi_pooling_layer_visitor(const TensorVariant &in, const std::vector<ROI> &rois, ROIPoolingLayerInfo pool_info)
114 : _in(in), _rois(rois), _pool_info(pool_info)
115 {
116 }
117
118 template <typename T>
119 void operator()(Tensor<T> &out) const
120 {
121 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
122 tensor_operations::roi_pooling_layer(in, out, _rois, _pool_info);
123 }
124
125private:
126 const TensorVariant &_in;
127 const std::vector<ROI> &_rois;
128 ROIPoolingLayerInfo _pool_info;
129};
130
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100131// Fixed Point operations visitor
132struct fixed_point_operation_visitor : public boost::static_visitor<>
133{
134public:
135 explicit fixed_point_operation_visitor(const TensorVariant &in, FixedPointOp op)
136 : _in(in), _op(op)
137 {
138 }
139
140 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
141 void operator()(Tensor<T> &out) const
142 {
143 const Tensor<T> &in = boost::get<Tensor<T>>(_in);
144 tensor_operations::fixed_point_operation(in, out, _op);
145 }
146 template < typename T, typename std::enable_if < !std::is_integral<T>::value, int >::type = 0 >
147 void operator()(Tensor<T> &out) const
148 {
149 ARM_COMPUTE_ERROR("NOT SUPPORTED!");
150 }
151
152private:
153 const TensorVariant &_in;
154 FixedPointOp _op;
155};
156// Print Tensor visitor
157struct print_visitor : public boost::static_visitor<>
158{
159public:
160 explicit print_visitor(std::ostream &out)
161 : _out(out)
162 {
163 }
164
165 template <typename T>
166 void operator()(const Tensor<T> &in) const
167 {
168 tensor_operations::print(in, _out);
169 }
170
171private:
172 std::ostream &_out;
173};
174} // namespace tensor_visitors
175} // namespace validation
176} // namespace test
177} // namespace arm_compute
178
179#endif /* __ARM_COMPUTE_TEST_TENSOR_VISITORS_H__ */