blob: cf43a6302af1056a8cb65ba411aa4215b069dfc9 [file] [log] [blame]
/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_TEST_DATASET_FULLY_CONNECTED_LAYER_DATASET_H__
25#define __ARM_COMPUTE_TEST_DATASET_FULLY_CONNECTED_LAYER_DATASET_H__
26
27#include "TypePrinter.h"
28
29#include "arm_compute/core/TensorShape.h"
30#include "dataset/GenericDataset.h"
31
32#include <sstream>
33#include <type_traits>
34
35#ifdef BOOST
36#include "boost_wrapper.h"
Anthony Barbierac69aa12017-07-03 17:39:37 +010037#endif /* BOOST */
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038
39namespace arm_compute
40{
41namespace test
42{
43class FullyConnectedLayerDataObject
44{
45public:
46 operator std::string() const
47 {
48 std::stringstream ss;
49 ss << "FullyConnectedLayer";
50 ss << "_I" << src_shape;
51 ss << "_K" << weights_shape;
52 return ss.str();
53 }
54
55 friend std::ostream &operator<<(std::ostream &os, const FullyConnectedLayerDataObject &obj)
56 {
57 os << static_cast<std::string>(obj);
58 return os;
59 }
60
61public:
62 TensorShape src_shape;
63 TensorShape weights_shape;
64 TensorShape bias_shape;
65 TensorShape dst_shape;
66 bool transpose_weights;
67 bool are_weights_reshaped;
68};
69
70template <unsigned int Size>
71using FullyConnectedLayerDataset = GenericDataset<FullyConnectedLayerDataObject, Size>;
72
73class SmallFullyConnectedLayerDataset final : public FullyConnectedLayerDataset<5>
74{
75public:
76 SmallFullyConnectedLayerDataset()
77 : GenericDataset
78 {
79 FullyConnectedLayerDataObject{ TensorShape(9U, 5U, 7U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U), true, false },
80 FullyConnectedLayerDataObject{ TensorShape(9U, 5U, 7U, 3U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U, 3U), true, false },
81 FullyConnectedLayerDataObject{ TensorShape(201U), TensorShape(201U, 529U), TensorShape(529U), TensorShape(529U), true, false },
82 FullyConnectedLayerDataObject{ TensorShape(9U, 5U, 7U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U), true, true },
83 FullyConnectedLayerDataObject{ TensorShape(201U), TensorShape(201U, 529U), TensorShape(529U), TensorShape(529U), true, true },
84 }
85 {
86 }
87
88 ~SmallFullyConnectedLayerDataset() = default;
89};
90
91class LargeFullyConnectedLayerDataset final : public FullyConnectedLayerDataset<5>
92{
93public:
94 LargeFullyConnectedLayerDataset()
95 : GenericDataset
96 {
97 FullyConnectedLayerDataObject{ TensorShape(9U, 5U, 257U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U), true, false },
98 FullyConnectedLayerDataObject{ TensorShape(9U, 5U, 257U, 2U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U, 2U), true, false },
99 FullyConnectedLayerDataObject{ TensorShape(3127U), TensorShape(3127U, 989U), TensorShape(989U), TensorShape(989U), true, false },
100 FullyConnectedLayerDataObject{ TensorShape(9U, 5U, 257U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U), true, true },
101 FullyConnectedLayerDataObject{ TensorShape(3127U), TensorShape(3127U, 989U), TensorShape(989U), TensorShape(989U), true, true },
102 }
103 {
104 }
105
106 ~LargeFullyConnectedLayerDataset() = default;
107};
108
109class AlexNetFullyConnectedLayerDataset final : public FullyConnectedLayerDataset<3>
110{
111public:
112 AlexNetFullyConnectedLayerDataset()
113 : GenericDataset
114 {
115 FullyConnectedLayerDataObject{ TensorShape(6U, 6U, 256U), TensorShape(9216U, 4096U), TensorShape(4096U), TensorShape(4096U), true },
116 FullyConnectedLayerDataObject{ TensorShape(4096U), TensorShape(4096U, 4096U), TensorShape(4096U), TensorShape(4096U), true },
117 FullyConnectedLayerDataObject{ TensorShape(4096U), TensorShape(4096U, 1000U), TensorShape(1000U), TensorShape(1000U), true },
118 }
119 {
120 }
121
122 ~AlexNetFullyConnectedLayerDataset() = default;
123};
124
125class LeNet5FullyConnectedLayerDataset final : public FullyConnectedLayerDataset<2>
126{
127public:
128 LeNet5FullyConnectedLayerDataset()
129 : GenericDataset
130 {
131 FullyConnectedLayerDataObject{ TensorShape(4U, 4U, 50U), TensorShape(800U, 500U), TensorShape(500U), TensorShape(500U) },
132 FullyConnectedLayerDataObject{ TensorShape(500U), TensorShape(500U, 10U), TensorShape(10U), TensorShape(10U) },
133 }
134 {
135 }
136
137 ~LeNet5FullyConnectedLayerDataset() = default;
138};
139
140class GoogLeNetFullyConnectedLayerDataset final : public FullyConnectedLayerDataset<1>
141{
142public:
143 GoogLeNetFullyConnectedLayerDataset()
144 : GenericDataset
145 {
146 FullyConnectedLayerDataObject{ TensorShape(1024U), TensorShape(1024U, 1000U), TensorShape(1000U), TensorShape(1000U), true },
147 }
148 {
149 }
150
151 ~GoogLeNetFullyConnectedLayerDataset() = default;
152};
153} // namespace test
154} // namespace arm_compute
Anthony Barbierac69aa12017-07-03 17:39:37 +0100155#endif /* __ARM_COMPUTE_TEST_DATASET_FULLY_CONNECTED_LAYER_DATASET_H__ */