blob: f18397ea30d28b00ce73e0e6e45ccea8b892fef4 [file] [log] [blame]
Moritz Pflanzeree493ae2017-07-05 10:52:21 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017 Arm Limited.
Moritz Pflanzeree493ae2017-07-05 10:52:21 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_GEMM_DATASET
25#define ARM_COMPUTE_TEST_GEMM_DATASET
26
#include "utils/TypePrinter.h"

#include "arm_compute/core/TensorShape.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
30
31namespace arm_compute
32{
33namespace test
34{
35namespace datasets
36{
37class GEMMDataset
38{
39public:
40 using type = std::tuple<TensorShape, TensorShape, TensorShape, TensorShape, float, float>;
41
42 struct iterator
43 {
44 iterator(std::vector<TensorShape>::const_iterator a_it,
45 std::vector<TensorShape>::const_iterator b_it,
46 std::vector<TensorShape>::const_iterator c_it,
47 std::vector<TensorShape>::const_iterator dst_it,
48 std::vector<float>::const_iterator alpha_it,
49 std::vector<float>::const_iterator beta_it)
50 : _a_it{ std::move(a_it) },
51 _b_it{ std::move(b_it) },
52 _c_it{ std::move(c_it) },
53 _dst_it{ std::move(dst_it) },
54 _alpha_it{ std::move(alpha_it) },
55 _beta_it{ std::move(beta_it) }
56 {
57 }
58
59 std::string description() const
60 {
61 std::stringstream description;
62 description << "A=" << *_a_it << ":";
63 description << "B=" << *_b_it << ":";
64 description << "C=" << *_c_it << ":";
65 description << "Out=" << *_dst_it << ":";
Moritz Pflanzerec69f932017-07-19 12:03:33 +010066 description << "Alpha=" << *_alpha_it << ":";
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010067 description << "Beta=" << *_beta_it;
68 return description.str();
69 }
70
71 GEMMDataset::type operator*() const
72 {
73 return std::make_tuple(*_a_it, *_b_it, *_c_it, *_dst_it, *_alpha_it, *_beta_it);
74 }
75
76 iterator &operator++()
77 {
78 ++_a_it;
79 ++_b_it;
80 ++_c_it;
81 ++_dst_it;
82 ++_alpha_it;
83 ++_beta_it;
84
85 return *this;
86 }
87
88 private:
89 std::vector<TensorShape>::const_iterator _a_it;
90 std::vector<TensorShape>::const_iterator _b_it;
91 std::vector<TensorShape>::const_iterator _c_it;
92 std::vector<TensorShape>::const_iterator _dst_it;
93 std::vector<float>::const_iterator _alpha_it;
94 std::vector<float>::const_iterator _beta_it;
95 };
96
97 iterator begin() const
98 {
99 return iterator(_a_shapes.begin(), _b_shapes.begin(), _c_shapes.begin(), _dst_shapes.begin(), _alpha.begin(), _beta.begin());
100 }
101
102 int size() const
103 {
104 return std::min(_a_shapes.size(), std::min(_b_shapes.size(), std::min(_c_shapes.size(), std::min(_dst_shapes.size(), std::min(_alpha.size(), _beta.size())))));
105 }
106
107 void add_config(TensorShape a, TensorShape b, TensorShape c, TensorShape dst, float alpha, float beta)
108 {
109 _a_shapes.emplace_back(std::move(a));
110 _b_shapes.emplace_back(std::move(b));
111 _c_shapes.emplace_back(std::move(c));
112 _dst_shapes.emplace_back(std::move(dst));
113 _alpha.emplace_back(std::move(alpha));
114 _beta.emplace_back(std::move(beta));
115 }
116
117protected:
118 GEMMDataset() = default;
119 GEMMDataset(GEMMDataset &&) = default;
120
121private:
122 std::vector<TensorShape> _a_shapes{};
123 std::vector<TensorShape> _b_shapes{};
124 std::vector<TensorShape> _c_shapes{};
125 std::vector<TensorShape> _dst_shapes{};
126 std::vector<float> _alpha{};
127 std::vector<float> _beta{};
128};
129} // namespace datasets
130} // namespace test
131} // namespace arm_compute
132#endif /* ARM_COMPUTE_TEST_GEMM_DATASET */