blob: e212c7bd9b50eb86c07195e41cb8046b8c617825 [file] [log] [blame]
/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef ARM_COMPUTE_TEST_ACTIVATION_LAYER_FIXTURE
25#define ARM_COMPUTE_TEST_ACTIVATION_LAYER_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
Moritz Pflanzer572ade72017-07-21 17:36:33 +010029#include "tests/AssetsLibrary.h"
30#include "tests/Globals.h"
31#include "tests/IAccessor.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010032#include "tests/framework/Asserts.h"
33#include "tests/framework/Fixture.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010034#include "tests/validation/Helpers.h"
Georgios Pinitas5a7e7762017-12-01 16:27:29 +000035#include "tests/validation/reference/ActivationLayer.h"
Moritz Pflanzer572ade72017-07-21 17:36:33 +010036
37#include <random>
38
39namespace arm_compute
40{
41namespace test
42{
43namespace validation
44{
45template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Michel Iwaniec66cc12f2017-12-07 17:26:40 +000046class ActivationValidationGenericFixture : public framework::Fixture
Moritz Pflanzer572ade72017-07-21 17:36:33 +010047{
48public:
49 template <typename...>
Michel Iwaniec66cc12f2017-12-07 17:26:40 +000050 void setup(TensorShape shape, bool in_place, ActivationLayerInfo::ActivationFunction function, float alpha_beta, DataType data_type, int fractional_bits, QuantizationInfo quantization_info)
Moritz Pflanzer572ade72017-07-21 17:36:33 +010051 {
Michel Iwaniec66cc12f2017-12-07 17:26:40 +000052 _fractional_bits = fractional_bits;
53 _quantization_info = quantization_info;
54 _data_type = data_type;
55 _function = function;
Moritz Pflanzer572ade72017-07-21 17:36:33 +010056
57 ActivationLayerInfo info(function, alpha_beta, alpha_beta);
58
Michel Iwaniec66cc12f2017-12-07 17:26:40 +000059 _target = compute_target(shape, in_place, info, data_type, fractional_bits, quantization_info);
60 _reference = compute_reference(shape, info, data_type, fractional_bits, quantization_info);
Moritz Pflanzer572ade72017-07-21 17:36:33 +010061 }
62
63protected:
64 template <typename U>
65 void fill(U &&tensor)
66 {
67 if(is_data_type_float(_data_type))
68 {
69 float min_bound = 0;
70 float max_bound = 0;
71 std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<T>(_function, _data_type);
72 std::uniform_real_distribution<> distribution(min_bound, max_bound);
73 library->fill(tensor, distribution, 0);
74 }
Michel Iwaniec66cc12f2017-12-07 17:26:40 +000075 else if(is_data_type_quantized_asymmetric(tensor.data_type()))
76 {
77 library->fill_tensor_uniform(tensor, 0);
78 }
Moritz Pflanzer572ade72017-07-21 17:36:33 +010079 else
80 {
81 int min_bound = 0;
82 int max_bound = 0;
83 std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<T>(_function, _data_type, _fractional_bits);
84 std::uniform_int_distribution<> distribution(min_bound, max_bound);
85 library->fill(tensor, distribution, 0);
86 }
87 }
88
Michel Iwaniec66cc12f2017-12-07 17:26:40 +000089 TensorType compute_target(const TensorShape &shape, bool in_place, ActivationLayerInfo info, DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
Moritz Pflanzer572ade72017-07-21 17:36:33 +010090 {
91 // Create tensors
Michel Iwaniec66cc12f2017-12-07 17:26:40 +000092 TensorType src = create_tensor<TensorType>(shape, data_type, 1, fixed_point_position, quantization_info);
93 TensorType dst = create_tensor<TensorType>(shape, data_type, 1, fixed_point_position, quantization_info);
Moritz Pflanzer572ade72017-07-21 17:36:33 +010094
95 // Create and configure function
96 FunctionType act_layer;
97
98 TensorType *dst_ptr = in_place ? &src : &dst;
99
100 act_layer.configure(&src, dst_ptr, info);
101
102 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
103 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
104
105 // Allocate tensors
106 src.allocator()->allocate();
107 ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
108
109 if(!in_place)
110 {
111 dst.allocator()->allocate();
112 ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
113 }
114
115 // Fill tensors
116 fill(AccessorType(src));
117
118 // Compute function
119 act_layer.run();
120
121 if(in_place)
122 {
123 return src;
124 }
125 else
126 {
127 return dst;
128 }
129 }
130
Michel Iwaniec66cc12f2017-12-07 17:26:40 +0000131 SimpleTensor<T> compute_reference(const TensorShape &shape, ActivationLayerInfo info, DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
Moritz Pflanzer572ade72017-07-21 17:36:33 +0100132 {
133 // Create reference
Michel Iwaniec66cc12f2017-12-07 17:26:40 +0000134 SimpleTensor<T> src{ shape, data_type, 1, fixed_point_position, quantization_info };
Moritz Pflanzer572ade72017-07-21 17:36:33 +0100135
136 // Fill reference
137 fill(src);
138
139 return reference::activation_layer<T>(src, info);
140 }
141
142 TensorType _target{};
143 SimpleTensor<T> _reference{};
144 int _fractional_bits{};
Michel Iwaniec66cc12f2017-12-07 17:26:40 +0000145 QuantizationInfo _quantization_info{};
Moritz Pflanzer572ade72017-07-21 17:36:33 +0100146 DataType _data_type{};
147 ActivationLayerInfo::ActivationFunction _function{};
148};
149
150template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Michel Iwaniec66cc12f2017-12-07 17:26:40 +0000151class ActivationValidationFixture : public ActivationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
Moritz Pflanzer572ade72017-07-21 17:36:33 +0100152{
153public:
154 template <typename...>
155 void setup(TensorShape shape, bool in_place, ActivationLayerInfo::ActivationFunction function, float alpha_beta, DataType data_type)
156 {
Michel Iwaniec66cc12f2017-12-07 17:26:40 +0000157 ActivationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, in_place, function, alpha_beta, data_type, 0, QuantizationInfo());
Moritz Pflanzer572ade72017-07-21 17:36:33 +0100158 }
159};
Michel Iwaniec66cc12f2017-12-07 17:26:40 +0000160
161template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
162class ActivationValidationFixedPointFixture : public ActivationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
163{
164public:
165 template <typename...>
166 void setup(TensorShape shape, bool in_place, ActivationLayerInfo::ActivationFunction function, float alpha_beta, DataType data_type, int fractional_bits)
167 {
168 ActivationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, in_place, function, alpha_beta, data_type, fractional_bits, QuantizationInfo());
169 }
170};
171
172template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
173class ActivationValidationQuantizedFixture : public ActivationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
174{
175public:
176 template <typename...>
177 void setup(TensorShape shape, bool in_place, ActivationLayerInfo::ActivationFunction function, float alpha_beta, DataType data_type, QuantizationInfo quantization_info)
178 {
179 ActivationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, in_place, function, alpha_beta, data_type, 0, quantization_info);
180 }
181};
182
Moritz Pflanzer572ade72017-07-21 17:36:33 +0100183} // namespace validation
184} // namespace test
185} // namespace arm_compute
186#endif /* ARM_COMPUTE_TEST_ACTIVATION_LAYER_FIXTURE */