blob: 9c1c5fb05dfa67615714e6222d404f78d7e1f1f2 [file] [log] [blame]
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +00001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Gunes Bayir8918b232023-03-17 13:52:21 +000024#ifndef ACL_TESTS_DATASETS_MATMULDATASET
25#define ACL_TESTS_DATASETS_MATMULDATASET
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000026
27#include "arm_compute/core/TensorShape.h"
28#include "utils/TypePrinter.h"
29
30namespace arm_compute
31{
32namespace test
33{
34namespace datasets
35{
Gunes Bayir8918b232023-03-17 13:52:21 +000036class MatMulDataset
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000037{
38public:
39 using type = std::tuple<TensorShape, TensorShape, TensorShape>;
40
41 struct iterator
42 {
43 iterator(std::vector<TensorShape>::const_iterator a_it,
44 std::vector<TensorShape>::const_iterator b_it,
45 std::vector<TensorShape>::const_iterator dst_it)
46 : _a_it{ std::move(a_it) },
47 _b_it{ std::move(b_it) },
48 _dst_it{ std::move(dst_it) }
49 {
50 }
51
52 std::string description() const
53 {
54 std::stringstream description;
55 description << "A=" << *_a_it << ":";
56 description << "B=" << *_b_it << ":";
57 description << "Out=" << *_dst_it << ":";
58 return description.str();
59 }
60
Gunes Bayir8918b232023-03-17 13:52:21 +000061 MatMulDataset::type operator*() const
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000062 {
63 return std::make_tuple(*_a_it, *_b_it, *_dst_it);
64 }
65
66 iterator &operator++()
67 {
68 ++_a_it;
69 ++_b_it;
70 ++_dst_it;
71
72 return *this;
73 }
74
75 private:
76 std::vector<TensorShape>::const_iterator _a_it;
77 std::vector<TensorShape>::const_iterator _b_it;
78 std::vector<TensorShape>::const_iterator _dst_it;
79 };
80
81 iterator begin() const
82 {
83 return iterator(_a_shapes.begin(), _b_shapes.begin(), _dst_shapes.begin());
84 }
85
86 int size() const
87 {
88 return std::min(_a_shapes.size(), std::min(_b_shapes.size(), _dst_shapes.size()));
89 }
90
91 void add_config(TensorShape a, TensorShape b, TensorShape dst)
92 {
93 _a_shapes.emplace_back(std::move(a));
94 _b_shapes.emplace_back(std::move(b));
95 _dst_shapes.emplace_back(std::move(dst));
96 }
97
98protected:
Gunes Bayir8918b232023-03-17 13:52:21 +000099 MatMulDataset() = default;
100 MatMulDataset(MatMulDataset &&) = default;
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000101
102private:
103 std::vector<TensorShape> _a_shapes{};
104 std::vector<TensorShape> _b_shapes{};
105 std::vector<TensorShape> _dst_shapes{};
106};
107} // namespace datasets
108} // namespace test
109} // namespace arm_compute
Gunes Bayir8918b232023-03-17 13:52:21 +0000110#endif /* ACL_TESTS_DATASETS_MATMULDATASET */