blob: c21f3208cefa3e7b41344bdc9a40896dbb0df2cd [file] [log] [blame]
/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef ARM_COMPUTE_TEST_LSTM_LAYER_DATASET
25#define ARM_COMPUTE_TEST_LSTM_LAYER_DATASET
26
#include "utils/TypePrinter.h"

#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"

#include <algorithm>
#include <cstddef>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
31
32namespace arm_compute
33{
34namespace test
35{
36namespace datasets
37{
38class LSTMLayerDataset
39{
40public:
41 using type = std::tuple<TensorShape, TensorShape, TensorShape, TensorShape, TensorShape, TensorShape, TensorShape, ActivationLayerInfo, float, float>;
42
43 struct iterator
44 {
45 iterator(std::vector<TensorShape>::const_iterator src_it,
46 std::vector<TensorShape>::const_iterator input_weights_it,
47 std::vector<TensorShape>::const_iterator recurrent_weights_it,
48 std::vector<TensorShape>::const_iterator cells_bias_it,
49 std::vector<TensorShape>::const_iterator output_cell_it,
50 std::vector<TensorShape>::const_iterator dst_it,
51 std::vector<TensorShape>::const_iterator scratch_it,
52 std::vector<ActivationLayerInfo>::const_iterator infos_it,
53 std::vector<float>::const_iterator cell_threshold_it,
54 std::vector<float>::const_iterator projection_threshold_it)
55 : _src_it{ std::move(src_it) },
56 _input_weights_it{ std::move(input_weights_it) },
57 _recurrent_weights_it{ std::move(recurrent_weights_it) },
58 _cells_bias_it{ std::move(cells_bias_it) },
59 _output_cell_it{ std::move(output_cell_it) },
60 _dst_it{ std::move(dst_it) },
61 _scratch_it{ std::move(scratch_it) },
62 _infos_it{ std::move(infos_it) },
63 _cell_threshold_it{ std::move(cell_threshold_it) },
64 _projection_threshold_it{ std::move(projection_threshold_it) }
65 {
66 }
67
68 std::string description() const
69 {
70 std::stringstream description;
71 description << "In=" << *_src_it << ":";
72 description << "InputWeights=" << *_input_weights_it << ":";
73 description << "RecurrentWeights=" << *_recurrent_weights_it << ":";
74 description << "Biases=" << *_cells_bias_it << ":";
75 description << "Scratch=" << *_scratch_it << ":";
76 description << "Out=" << *_dst_it;
77 return description.str();
78 }
79
80 LSTMLayerDataset::type operator*() const
81 {
82 return std::make_tuple(*_src_it, *_input_weights_it, *_recurrent_weights_it, *_cells_bias_it, *_output_cell_it, *_dst_it, *_scratch_it, *_infos_it, *_cell_threshold_it, *_projection_threshold_it);
83 }
84
85 iterator &operator++()
86 {
87 ++_src_it;
88 ++_input_weights_it;
89 ++_recurrent_weights_it;
90 ++_cells_bias_it;
91 ++_output_cell_it;
92 ++_dst_it;
93 ++_scratch_it;
94 ++_infos_it;
95 ++_cell_threshold_it;
96 ++_projection_threshold_it;
97
98 return *this;
99 }
100
101 private:
102 std::vector<TensorShape>::const_iterator _src_it;
103 std::vector<TensorShape>::const_iterator _input_weights_it;
104 std::vector<TensorShape>::const_iterator _recurrent_weights_it;
105 std::vector<TensorShape>::const_iterator _cells_bias_it;
106 std::vector<TensorShape>::const_iterator _output_cell_it;
107 std::vector<TensorShape>::const_iterator _dst_it;
108 std::vector<TensorShape>::const_iterator _scratch_it;
109 std::vector<ActivationLayerInfo>::const_iterator _infos_it;
110 std::vector<float>::const_iterator _cell_threshold_it;
111 std::vector<float>::const_iterator _projection_threshold_it;
112 };
113
114 iterator begin() const
115 {
116 return iterator(_src_shapes.begin(), _input_weights_shapes.begin(), _recurrent_weights_shapes.begin(), _cell_bias_shapes.begin(), _output_cell_shapes.begin(), _dst_shapes.begin(),
117 _scratch_shapes.begin(), _infos.begin(), _cell_threshold.begin(), _projection_threshold.begin());
118 }
119
120 int size() const
121 {
122 return std::min(_src_shapes.size(), std::min(_input_weights_shapes.size(), std::min(_recurrent_weights_shapes.size(), std::min(_cell_bias_shapes.size(), std::min(_output_cell_shapes.size(),
123 std::min(_dst_shapes.size(), std::min(_scratch_shapes.size(), std::min(_cell_threshold.size(), std::min(_projection_threshold.size(), _infos.size())))))))));
124 }
125
126 void add_config(TensorShape src, TensorShape input_weights, TensorShape recurrent_weights, TensorShape cell_bias_weights, TensorShape output_cell_state, TensorShape dst, TensorShape scratch,
127 ActivationLayerInfo info, float cell_threshold, float projection_threshold)
128 {
129 _src_shapes.emplace_back(std::move(src));
130 _input_weights_shapes.emplace_back(std::move(input_weights));
131 _recurrent_weights_shapes.emplace_back(std::move(recurrent_weights));
132 _cell_bias_shapes.emplace_back(std::move(cell_bias_weights));
133 _output_cell_shapes.emplace_back(std::move(output_cell_state));
134 _dst_shapes.emplace_back(std::move(dst));
135 _scratch_shapes.emplace_back(std::move(scratch));
136 _infos.emplace_back(std::move(info));
137 _cell_threshold.emplace_back(std::move(cell_threshold));
138 _projection_threshold.emplace_back(std::move(projection_threshold));
139 }
140
141protected:
142 LSTMLayerDataset() = default;
143 LSTMLayerDataset(LSTMLayerDataset &&) = default;
144
145private:
146 std::vector<TensorShape> _src_shapes{};
147 std::vector<TensorShape> _input_weights_shapes{};
148 std::vector<TensorShape> _recurrent_weights_shapes{};
149 std::vector<TensorShape> _cell_bias_shapes{};
150 std::vector<TensorShape> _output_cell_shapes{};
151 std::vector<TensorShape> _dst_shapes{};
152 std::vector<TensorShape> _scratch_shapes{};
153 std::vector<ActivationLayerInfo> _infos{};
154 std::vector<float> _cell_threshold{};
155 std::vector<float> _projection_threshold{};
156};
157
158class SmallLSTMLayerDataset final : public LSTMLayerDataset
159{
160public:
161 SmallLSTMLayerDataset()
162 {
Georgios Pinitas3ada2b72018-08-23 15:54:36 +0100163 add_config(TensorShape(8U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U), TensorShape(16U), TensorShape(64U),
164 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), 0.05f, 0.93f);
165 add_config(TensorShape(8U, 2U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U, 2U), TensorShape(16U, 2U), TensorShape(64U, 2U),
166 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), 0.05f, 0.93f);
167 add_config(TensorShape(8U, 2U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U, 2U), TensorShape(16U, 2U), TensorShape(48U, 2U),
168 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), 0.05f, 0.93f);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000169 }
170};
171
172} // namespace datasets
173} // namespace test
174} // namespace arm_compute
175#endif /* ARM_COMPUTE_TEST_LSTM_LAYER_DATASET */