blob: 062c05b1d996e6d206dac920c72e40f2d2646522 [file] [log] [blame]
Gian Marcofa4cacd2017-10-18 17:05:02 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_GEMMLOWP_DATASET
25#define ARM_COMPUTE_TEST_GEMMLOWP_DATASET
26
#include "utils/TypePrinter.h"

#include "arm_compute/core/TensorShape.h"

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>
30
31namespace arm_compute
32{
33namespace test
34{
35namespace datasets
36{
37class GEMMLowpDataset
38{
39public:
Gian Marcoe75a02b2017-11-08 12:24:09 +000040 using type = std::tuple<TensorShape, TensorShape, TensorShape, int32_t, int32_t>;
Gian Marcofa4cacd2017-10-18 17:05:02 +010041
42 struct iterator
43 {
44 iterator(std::vector<TensorShape>::const_iterator a_it,
45 std::vector<TensorShape>::const_iterator b_it,
46 std::vector<TensorShape>::const_iterator c_it,
47 std::vector<int32_t>::const_iterator a_offset_it,
Gian Marcoe75a02b2017-11-08 12:24:09 +000048 std::vector<int32_t>::const_iterator b_offset_it)
Gian Marcofa4cacd2017-10-18 17:05:02 +010049 : _a_it{ std::move(a_it) },
50 _b_it{ std::move(b_it) },
51 _c_it{ std::move(c_it) },
52 _a_offset_it{ std::move(a_offset_it) },
Gian Marcoe75a02b2017-11-08 12:24:09 +000053 _b_offset_it{ std::move(b_offset_it) }
Gian Marcofa4cacd2017-10-18 17:05:02 +010054 {
55 }
56
57 std::string description() const
58 {
59 std::stringstream description;
60 description << "A=" << *_a_it << ":";
61 description << "B=" << *_b_it << ":";
62 description << "C=" << *_c_it << ":";
63 description << "a_offset=" << *_a_offset_it << ":";
64 description << "b_offset=" << *_b_offset_it << ":";
Gian Marcofa4cacd2017-10-18 17:05:02 +010065 return description.str();
66 }
67
68 GEMMLowpDataset::type operator*() const
69 {
Gian Marcoe75a02b2017-11-08 12:24:09 +000070 return std::make_tuple(*_a_it, *_b_it, *_c_it, *_a_offset_it, *_b_offset_it);
Gian Marcofa4cacd2017-10-18 17:05:02 +010071 }
72
73 iterator &operator++()
74 {
75 ++_a_it;
76 ++_b_it;
77 ++_c_it;
78 ++_a_offset_it;
79 ++_b_offset_it;
Gian Marcofa4cacd2017-10-18 17:05:02 +010080
81 return *this;
82 }
83
84 private:
85 std::vector<TensorShape>::const_iterator _a_it;
86 std::vector<TensorShape>::const_iterator _b_it;
87 std::vector<TensorShape>::const_iterator _c_it;
88 std::vector<int32_t>::const_iterator _a_offset_it;
89 std::vector<int32_t>::const_iterator _b_offset_it;
Gian Marcofa4cacd2017-10-18 17:05:02 +010090 };
91
92 iterator begin() const
93 {
Gian Marcoe75a02b2017-11-08 12:24:09 +000094 return iterator(_a_shapes.begin(), _b_shapes.begin(), _c_shapes.begin(), _a_offset.begin(), _b_offset.begin());
Gian Marcofa4cacd2017-10-18 17:05:02 +010095 }
96
97 int size() const
98 {
Gian Marcoe75a02b2017-11-08 12:24:09 +000099 return std::min(_a_shapes.size(), std::min(_b_shapes.size(), std::min(_c_shapes.size(), std::min(_a_offset.size(), _b_offset.size()))));
Gian Marcofa4cacd2017-10-18 17:05:02 +0100100 }
101
Gian Marcoe75a02b2017-11-08 12:24:09 +0000102 void add_config(TensorShape a, TensorShape b, TensorShape c, int32_t a_offset, int32_t b_offset)
Gian Marcofa4cacd2017-10-18 17:05:02 +0100103 {
104 _a_shapes.emplace_back(std::move(a));
105 _b_shapes.emplace_back(std::move(b));
106 _c_shapes.emplace_back(std::move(c));
107 _a_offset.emplace_back(std::move(a_offset));
108 _b_offset.emplace_back(std::move(b_offset));
Gian Marcofa4cacd2017-10-18 17:05:02 +0100109 }
110
111protected:
112 GEMMLowpDataset() = default;
113 GEMMLowpDataset(GEMMLowpDataset &&) = default;
114
115private:
116 std::vector<TensorShape> _a_shapes{};
117 std::vector<TensorShape> _b_shapes{};
118 std::vector<TensorShape> _c_shapes{};
119 std::vector<int32_t> _a_offset{};
120 std::vector<int32_t> _b_offset{};
Gian Marcofa4cacd2017-10-18 17:05:02 +0100121};
122} // namespace datasets
123} // namespace test
124} // namespace arm_compute
125#endif /* ARM_COMPUTE_TEST_GEMMLOWP_DATASET */