/*
 * Copyright (c) 2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE
#define ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE

#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#ifdef ARM_COMPUTE_CL
#include "arm_compute/runtime/CL/CLScheduler.h"
#endif /* ARM_COMPUTE_CL */
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
#include "tests/IAccessor.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Helpers.h"
#include "tests/validation/reference/QLSTMLayerNormalization.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
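/** Validation fixture for QLSTM layer normalization.
 *
 * Configures and runs the given function on randomly generated QSYMM16 inputs and
 * compares the result against the reference implementation. Backend-specific
 * fixtures derive from this class and implement run_target() to dispatch the
 * configured function on the appropriate scheduler.
 */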
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class QLSTMLayerNormalizationValidationFixture : public framework::Fixture
{
public:
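    /** Set up the fixture: store the data type and weight quantization info, then compute the target and reference outputs.
     *
     * @param[in] input_shape  Shape of the input tensor
     * @param[in] weight_shape Shape of the weight tensor
     * @param[in] bias_shape   Shape of the bias tensor
     * @param[in] data_type    Data type of input, weight and output (only QSYMM16 is supported)
     * @param[in] weight_qinfo Quantization info of the weight tensor
     */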
    template <typename...>
    void setup(TensorShape input_shape, TensorShape weight_shape, TensorShape bias_shape, DataType data_type, QuantizationInfo weight_qinfo)
    {
        ARM_COMPUTE_ERROR_ON(data_type != DataType::QSYMM16);

        _data_type = data_type;
        _qinfo     = weight_qinfo;

        _target    = compute_target(input_shape, weight_shape, bias_shape);
        _reference = compute_reference(input_shape, weight_shape, bias_shape);
    }

protected:
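    /** Fill the input, weight and bias tensors with uniformly distributed random values drawn from fixed per-type ranges. */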
    template <typename InputType, typename BiasType>
    void fill(InputType &&input_tensor, InputType &&weight_tensor, BiasType &&bias_tensor)
    {
        switch(_data_type)
        {
            case DataType::QSYMM16:
            {
                // Value ranges are based on the reference implementation's test case.
                constexpr int16_t input_min  = -1000;
                constexpr int16_t input_max  = 1000;
                constexpr int16_t weight_min = 19000;
                constexpr int16_t weight_max = 27000;
                constexpr int32_t bias_min   = -16000000;
                constexpr int32_t bias_max   = -13000000;

                std::uniform_int_distribution<> input_distribution(input_min, input_max);
                std::uniform_int_distribution<> weight_distribution(weight_min, weight_max);
                std::uniform_int_distribution<> bias_distribution(bias_min, bias_max);

                library->fill(input_tensor, input_distribution, 0);
                library->fill(weight_tensor, weight_distribution, 0);
                library->fill(bias_tensor, bias_distribution, 0);
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
                break;
        }
    }

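    /** Allocate the given tensors, checking that each one is resizable before allocation and no longer resizable afterwards. */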
    void allocate_tensors(const std::vector<TensorType *> &tensors)
    {
        for(auto t : tensors)
        {
            ARM_COMPUTE_EXPECT(t->info()->is_resizable(), framework::LogLevel::ERRORS);
            t->allocator()->allocate();
            ARM_COMPUTE_EXPECT(!t->info()->is_resizable(), framework::LogLevel::ERRORS);
        }
    }

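    /** Run the configured function; implemented by backend-specific derived fixtures. */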
    virtual void run_target(FunctionType &fn) = 0;

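    /** Create, configure, allocate and fill the target tensors, run the function and return the output tensor. */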
    TensorType compute_target(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
    {
        TensorType input  = create_tensor<TensorType>(input_shape, _data_type, 1);
        TensorType weight = create_tensor<TensorType>(weight_shape, _data_type, 1, _qinfo);
        TensorType bias   = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
        TensorType output = create_tensor<TensorType>(input_shape, _data_type, 1);

        FunctionType fn;
        fn.configure(&input, &output, &weight, &bias);
        allocate_tensors({ &input, &weight, &bias, &output });
        fill(AccessorType(input), AccessorType(weight), AccessorType(bias));

        run_target(fn);

        return output;
    }

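    /** Fill reference tensors with the same values as the target (same fill seeds) and compute the reference output. */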
    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
    {
        // Create reference
        SimpleTensor<T>       input{ input_shape, _data_type, 1 };
        SimpleTensor<T>       weight{ weight_shape, _data_type, 1, _qinfo };
        SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };

        // Fill reference
        fill(input, weight, bias);

        return reference::qlstm_layer_normalization(input, weight, bias);
    }

    TensorType       _target{};
    SimpleTensor<T>  _reference{};
    DataType         _data_type{};
    QuantizationInfo _qinfo{};
};

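/** NEON specialization: runs the function over its full window using thread info taken from NEScheduler. */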
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class NEQLSTMLayerNormalizationValidationFixture : public QLSTMLayerNormalizationValidationFixture<TensorType, AccessorType, FunctionType, T>
{
protected:
    void run_target(FunctionType &fn) override
    {
        ThreadInfo tinfo;
        tinfo.cpu_info = &NEScheduler::get().cpu_info();
        fn.run(fn.window(), tinfo);
    }
};

#ifdef ARM_COMPUTE_CL
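/** OpenCL specialization: initialises the default CLScheduler and runs the function over its window on the scheduler's command queue. */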
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class CLQLSTMLayerNormalizationValidationFixture : public QLSTMLayerNormalizationValidationFixture<TensorType, AccessorType, FunctionType, T>
{
protected:
    void run_target(FunctionType &fn) override
    {
        CLScheduler::get().default_init();
        fn.run(fn.window(), CLScheduler::get().queue());
    }
};
#endif /* ARM_COMPUTE_CL */

} // namespace validation
} // namespace test
} // namespace arm_compute

#endif /* ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE */