blob: ce055572310f605227282eb511b32a3412c2cc6f [file] [log] [blame]
/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_TEST_DATASET_GEMM_DATASET_H__
25#define __ARM_COMPUTE_TEST_DATASET_GEMM_DATASET_H__
26
27#include "TypePrinter.h"
28
29#include "arm_compute/core/TensorShape.h"
30#include "arm_compute/core/Types.h"
31#include "dataset/GenericDataset.h"
32
33#include <ostream>
34#include <sstream>
35
36#include <tuple>
37#include <type_traits>
38
39#ifdef BOOST
40#include "boost_wrapper.h"
Anthony Barbierac69aa12017-07-03 17:39:37 +010041#endif /* BOOST */
Anthony Barbier6ff3b192017-09-04 18:44:23 +010042
43namespace arm_compute
44{
45namespace test
46{
47class GEMMDataObject
48{
49public:
50 //Data object used for matrix multiple
51 //D = alpha * A * B + beta * C;
52 TensorShape shape_a;
53 TensorShape shape_b;
54 TensorShape shape_c;
55 TensorShape shape_d;
56 float alpha;
57 float beta;
58
59 operator std::string() const
60 {
61 std::stringstream ss;
62 ss << "GEMM";
63 ss << "_A" << shape_a;
64 ss << "_B" << shape_b;
65 ss << "_C" << shape_c;
66 ss << "_D" << shape_d;
67 ss << "_alpha" << alpha;
68 ss << "_beta" << beta;
69 return ss.str();
70 }
71
72 friend std::ostream &operator<<(std::ostream &os, const GEMMDataObject &obj)
73 {
74 os << static_cast<std::string>(obj);
75 return os;
76 }
77};
78
79class SmallGEMMDataset : public GenericDataset<GEMMDataObject, 4>
80{
81public:
82 SmallGEMMDataset()
83 : GenericDataset
84 {
SiCong Li10c672c2017-06-22 15:46:40 +010085 GEMMDataObject{ TensorShape(21U, 13U), TensorShape(33U, 21U), TensorShape(33U, 13U), TensorShape(33U, 13U), 1.0f, 0.0f },
86 GEMMDataObject{ TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 0.0f },
87 GEMMDataObject{ TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 12U), TensorShape(21U, 12U), 0.2f, 1.2f },
88 GEMMDataObject{ TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U), TensorShape(17U, 1U), 0.4f, 0.7f },
Anthony Barbier6ff3b192017-09-04 18:44:23 +010089 }
90 {
91 }
92
93 ~SmallGEMMDataset() = default;
94};
95
96class LargeGEMMDataset : public GenericDataset<GEMMDataObject, 4>
97{
98public:
99 LargeGEMMDataset()
100 : GenericDataset
101 {
SiCong Li10c672c2017-06-22 15:46:40 +0100102 GEMMDataObject{ TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 429U), TensorShape(871U, 429U), 1.0f, 0.0f },
103 GEMMDataObject{ TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f },
104 GEMMDataObject{ TensorShape(681U, 1023U), TensorShape(213U, 681U), TensorShape(213U, 1023U), TensorShape(213U, 1023U), 0.2f, 1.2f },
105 GEMMDataObject{ TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 0.4f, 0.7f },
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100106 }
107 {
108 }
109
110 ~LargeGEMMDataset() = default;
111};
112
113class GoogLeNetGEMMDataset1 : public GenericDataset<GEMMDataObject, 32>
114{
115public:
116 GoogLeNetGEMMDataset1()
117 : GenericDataset
118 {
119 GEMMDataObject{ TensorShape(147U, 12544U), TensorShape(64U, 147U), TensorShape(64U, 12544U), TensorShape(64U, 12544U), 1.0f, 0.0f },
120 GEMMDataObject{ TensorShape(64U, 3136U), TensorShape(64U, 64U), TensorShape(64U, 3136U), TensorShape(64U, 3136U), 1.0f, 0.0f },
121 GEMMDataObject{ TensorShape(576U, 3136U), TensorShape(192U, 576U), TensorShape(192U, 3136U), TensorShape(192U, 3136U), 1.0f, 0.0f },
122 GEMMDataObject{ TensorShape(192U, 784U), TensorShape(64U, 192U), TensorShape(64U, 784U), TensorShape(64U, 784U), 1.0f, 0.0f },
123 GEMMDataObject{ TensorShape(192U, 784U), TensorShape(96U, 192U), TensorShape(96U, 784U), TensorShape(96U, 784U), 1.0f, 0.0f },
124 GEMMDataObject{ TensorShape(864U, 784U), TensorShape(128U, 864U), TensorShape(128U, 784U), TensorShape(128U, 784U), 1.0f, 0.0f },
125 GEMMDataObject{ TensorShape(192U, 784U), TensorShape(16U, 192U), TensorShape(16U, 784U), TensorShape(16U, 784U), 1.0f, 0.0f },
126 GEMMDataObject{ TensorShape(400U, 784U), TensorShape(32U, 400U), TensorShape(32U, 784U), TensorShape(32U, 784U), 1.0f, 0.0f },
127 GEMMDataObject{ TensorShape(192U, 784U), TensorShape(32U, 192U), TensorShape(32U, 784U), TensorShape(32U, 784U), 1.0f, 0.0f },
128 GEMMDataObject{ TensorShape(256U, 784U), TensorShape(128U, 256U), TensorShape(128U, 784U), TensorShape(128U, 784U), 1.0f, 0.0f },
129 GEMMDataObject{ TensorShape(256U, 784U), TensorShape(128U, 256U), TensorShape(128U, 784U), TensorShape(128U, 784U), 1.0f, 0.0f },
130 GEMMDataObject{ TensorShape(1152U, 784U), TensorShape(192U, 1152U), TensorShape(192U, 784U), TensorShape(192U, 784U), 1.0f, 0.0f },
131 GEMMDataObject{ TensorShape(256U, 784U), TensorShape(32U, 256U), TensorShape(32U, 784U), TensorShape(32U, 784U), 1.0f, 0.0f },
132 GEMMDataObject{ TensorShape(800U, 784U), TensorShape(96U, 800U), TensorShape(96U, 784U), TensorShape(96U, 784U), 1.0f, 0.0f },
133 GEMMDataObject{ TensorShape(256U, 784U), TensorShape(64U, 256U), TensorShape(64U, 784U), TensorShape(64U, 784U), 1.0f, 0.0f },
134 GEMMDataObject{ TensorShape(480U, 196U), TensorShape(192U, 480U), TensorShape(192U, 196U), TensorShape(192U, 196U), 1.0f, 0.0f },
135 GEMMDataObject{ TensorShape(480U, 196U), TensorShape(96U, 480U), TensorShape(96U, 196U), TensorShape(96U, 196U), 1.0f, 0.0f },
136 GEMMDataObject{ TensorShape(864U, 196U), TensorShape(204U, 864U), TensorShape(204U, 196U), TensorShape(204U, 196U), 1.0f, 0.0f },
137 GEMMDataObject{ TensorShape(480U, 196U), TensorShape(16U, 480U), TensorShape(16U, 196U), TensorShape(16U, 196U), 1.0f, 0.0f },
138 GEMMDataObject{ TensorShape(400U, 196U), TensorShape(48U, 400U), TensorShape(48U, 196U), TensorShape(48U, 196U), 1.0f, 0.0f },
139 GEMMDataObject{ TensorShape(480U, 196U), TensorShape(64U, 480U), TensorShape(64U, 196U), TensorShape(64U, 196U), 1.0f, 0.0f },
140 GEMMDataObject{ TensorShape(508U, 196U), TensorShape(160U, 508U), TensorShape(160U, 196U), TensorShape(160U, 196U), 1.0f, 0.0f },
141 GEMMDataObject{ TensorShape(508U, 196U), TensorShape(112U, 508U), TensorShape(112U, 196U), TensorShape(112U, 196U), 1.0f, 0.0f },
142 GEMMDataObject{ TensorShape(1008U, 196U), TensorShape(224U, 1008U), TensorShape(224U, 196U), TensorShape(224U, 196U), 1.0f, 0.0f },
143 GEMMDataObject{ TensorShape(508U, 196U), TensorShape(24U, 508U), TensorShape(24U, 196U), TensorShape(24U, 196U), 1.0f, 0.0f },
144 GEMMDataObject{ TensorShape(600U, 196U), TensorShape(64U, 600U), TensorShape(64U, 196U), TensorShape(64U, 196U), 1.0f, 0.0f },
145 GEMMDataObject{ TensorShape(508U, 196U), TensorShape(64U, 508U), TensorShape(64U, 196U), TensorShape(64U, 196U), 1.0f, 0.0f },
146 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(128U, 512U), TensorShape(128U, 196U), TensorShape(128U, 196U), 1.0f, 0.0f },
147 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(128U, 512U), TensorShape(128U, 196U), TensorShape(128U, 196U), 1.0f, 0.0f },
148 GEMMDataObject{ TensorShape(1152U, 196U), TensorShape(256U, 1152U), TensorShape(256U, 196U), TensorShape(256U, 196U), 1.0f, 0.0f },
149 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(24U, 512U), TensorShape(24U, 196U), TensorShape(24U, 196U), 1.0f, 0.0f },
150 GEMMDataObject{ TensorShape(600U, 196U), TensorShape(64U, 600U), TensorShape(64U, 196U), TensorShape(64U, 196U), 1.0f, 0.0f }
151 }
152 {
153 }
154
155 ~GoogLeNetGEMMDataset1() = default;
156};
157
158class GoogLeNetGEMMDataset2 : public GenericDataset<GEMMDataObject, 32>
159{
160public:
161 GoogLeNetGEMMDataset2()
162 : GenericDataset
163 {
164 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(64U, 512U), TensorShape(64U, 196U), TensorShape(64U, 196U), 1.0f, 0.0f },
165 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(112U, 512U), TensorShape(112U, 196U), TensorShape(112U, 196U), 1.0f, 0.0f },
166 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(144U, 512U), TensorShape(144U, 196U), TensorShape(144U, 196U), 1.0f, 0.0f },
167 GEMMDataObject{ TensorShape(1296U, 196U), TensorShape(288U, 1296U), TensorShape(288U, 196U), TensorShape(288U, 196U), 1.0f, 0.0f },
168 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(32U, 512U), TensorShape(32U, 196U), TensorShape(32U, 196U), 1.0f, 0.0f },
169 GEMMDataObject{ TensorShape(800U, 196U), TensorShape(64U, 800U), TensorShape(64U, 196U), TensorShape(64U, 196U), 1.0f, 0.0f },
170 GEMMDataObject{ TensorShape(512U, 196U), TensorShape(64U, 512U), TensorShape(64U, 196U), TensorShape(64U, 196U), 1.0f, 0.0f },
171 GEMMDataObject{ TensorShape(528U, 196U), TensorShape(256U, 528U), TensorShape(256U, 196U), TensorShape(256U, 196U), 1.0f, 0.0f },
172 GEMMDataObject{ TensorShape(528U, 196U), TensorShape(160U, 528U), TensorShape(160U, 196U), TensorShape(160U, 196U), 1.0f, 0.0f },
173 GEMMDataObject{ TensorShape(1440U, 196U), TensorShape(320U, 1440U), TensorShape(320U, 196U), TensorShape(320U, 196U), 1.0f, 0.0f },
174 GEMMDataObject{ TensorShape(528U, 196U), TensorShape(32U, 528U), TensorShape(32U, 196U), TensorShape(32U, 196U), 1.0f, 0.0f },
175 GEMMDataObject{ TensorShape(800U, 196U), TensorShape(128U, 800U), TensorShape(128U, 196U), TensorShape(128U, 196U), 1.0f, 0.0f },
176 GEMMDataObject{ TensorShape(528U, 196U), TensorShape(128U, 528U), TensorShape(128U, 196U), TensorShape(128U, 196U), 1.0f, 0.0f },
177 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(256U, 832U), TensorShape(256U, 49U), TensorShape(256U, 49U), 1.0f, 0.0f },
178 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(160U, 832U), TensorShape(160U, 49U), TensorShape(160U, 49U), 1.0f, 0.0f },
179 GEMMDataObject{ TensorShape(1440U, 49U), TensorShape(320U, 1440U), TensorShape(320U, 49U), TensorShape(320U, 49U), 1.0f, 0.0f },
180 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(48U, 832U), TensorShape(48U, 49U), TensorShape(48U, 49U), 1.0f, 0.0f },
181 GEMMDataObject{ TensorShape(1200U, 49U), TensorShape(128U, 1200U), TensorShape(128U, 49U), TensorShape(128U, 49U), 1.0f, 0.0f },
182 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(128U, 832U), TensorShape(128U, 49U), TensorShape(128U, 49U), 1.0f, 0.0f },
183 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(384U, 832U), TensorShape(384U, 49U), TensorShape(384U, 49U), 1.0f, 0.0f },
184 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(192U, 832U), TensorShape(192U, 49U), TensorShape(192U, 49U), 1.0f, 0.0f },
185 GEMMDataObject{ TensorShape(1728U, 49U), TensorShape(384U, 1728U), TensorShape(384U, 49U), TensorShape(384U, 49U), 1.0f, 0.0f },
186 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(48U, 832U), TensorShape(48U, 49U), TensorShape(48U, 49U), 1.0f, 0.0f },
187 GEMMDataObject{ TensorShape(1200U, 49U), TensorShape(128U, 1200U), TensorShape(128U, 49U), TensorShape(128U, 49U), 1.0f, 0.0f },
188 GEMMDataObject{ TensorShape(832U, 49U), TensorShape(128U, 832U), TensorShape(128U, 49U), TensorShape(128U, 49U), 1.0f, 0.0f },
189 GEMMDataObject{ TensorShape(508U, 16U), TensorShape(128U, 508U), TensorShape(128U, 16U), TensorShape(128U, 16U), 1.0f, 0.0f },
190 GEMMDataObject{ TensorShape(2048U, 1U), TensorShape(1024U, 2048U), TensorShape(1024U, 1U), TensorShape(1024U, 1U), 1.0f, 0.0f },
191 GEMMDataObject{ TensorShape(1024U, 1U), TensorShape(1008U, 1024U), TensorShape(1008U, 1U), TensorShape(1008U, 1U), 1.0f, 0.0f },
192 GEMMDataObject{ TensorShape(528U, 16U), TensorShape(128U, 528U), TensorShape(128U, 16U), TensorShape(128U, 16U), 1.0f, 0.0f },
193 GEMMDataObject{ TensorShape(2048U, 1U), TensorShape(1024U, 2048U), TensorShape(1024U, 1U), TensorShape(1024U, 1U), 1.0f, 0.0f },
194 GEMMDataObject{ TensorShape(1024U, 1U), TensorShape(1008U, 1024U), TensorShape(1008U, 1U), TensorShape(1008U, 1U), 1.0f, 0.0f },
195 GEMMDataObject{ TensorShape(1024U, 1U), TensorShape(1008U, 1024U), TensorShape(1008U, 1U), TensorShape(1008U, 1U), 1.0f, 0.0f }
196 }
197 {
198 }
199
200 ~GoogLeNetGEMMDataset2() = default;
201};
SiCong Li10c672c2017-06-22 15:46:40 +0100202
203class MatrixMultiplyDataset : public GenericDataset<GEMMDataObject, 3>
204{
205public:
206 MatrixMultiplyDataset()
207 : GenericDataset
208 {
209 GEMMDataObject{ TensorShape(1024U, 1U), TensorShape(1000U, 1024U), TensorShape(1000U, 1U), TensorShape(1000U, 1U), 1.0f, 0.0f },
210 GEMMDataObject{ TensorShape(256U, 784U), TensorShape(64U, 256U), TensorShape(64U, 784U), TensorShape(64U, 784U), 1.0f, 0.0f },
211 GEMMDataObject{ TensorShape(1152U, 2704U), TensorShape(256U, 1152U), TensorShape(256U, 2704U), TensorShape(256U, 2704U), 1.0f, 0.0f },
212 }
213 {
214 }
215
216 ~MatrixMultiplyDataset() = default;
217};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100218} // namespace test
219} // namespace arm_compute
220#endif //__ARM_COMPUTE_TEST_DATASET_GEMM_DATASET_H__