blob: 48a86dcefc4f5ca20526a2f12dbd43f142e9fd4a [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10#include <string>
11
12BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
13
/// Test fixture that builds a minimal TfLite model (in flatbuffer JSON form) holding a single
/// PRELU operator, stores it in m_JsonString (inherited from ParserFlatbuffersFixture) and then
/// calls Setup().
///
/// Generated subgraph (every tensor is FLOAT32 with quantization details_type NONE):
///   tensor 0 "input0" -> buffer 1 (always empty)
///   tensor 1 "input1" -> buffer 2 (contents taken from alphaData; the PRELU alpha operand)
///   tensor 2 "output" -> buffer 3 (always empty); always the single subgraph output
/// Operator 0 (PRELU, opcode_index 0) consumes tensors [0, 1] and produces tensor 2.
///
/// Every string parameter is spliced verbatim into the JSON text, so it must already be a
/// valid JSON fragment.
/// @param inputShape  JSON array for the shape of tensor 0, e.g. "[ 2, 3 ]".
/// @param alphaShape  JSON array for the shape of tensor 1 (alpha).
/// @param outputShape JSON array for the shape of tensor 2; may be "[]" (no static shape).
/// @param inputIndex  JSON array of tensor indices exposed as subgraph inputs, e.g. "[ 0, 1 ]"
///                    to feed alpha at runtime, or "[ 0 ]" when alpha comes from buffer data.
/// @param alphaData   Body of buffer 2: either "" (empty buffer) or a fragment such as
///                    "\"data\": [ 0, 0, 128, 62 ]" listing the raw bytes of the alpha tensor.
struct PreluFixture : public ParserFlatbuffersFixture
{
    explicit PreluFixture(const std::string& inputShape,
                          const std::string& alphaShape,
                          const std::string& outputShape,
                          const std::string& inputIndex,
                          const std::string& alphaData)
    {
        // NOTE: no comments may be placed inside the raw string below; they would become part of
        // the JSON. The parameters are concatenated between raw-string segments.
        m_JsonString = R"(
            {
              "version": 3,
              "operator_codes": [
                {
                  "builtin_code": "PRELU",
                  "version": 1
                }
              ],
              "subgraphs": [
                {
                  "tensors": [
                    {
                      "shape": )" + inputShape + R"(,
                      "type": "FLOAT32",
                      "buffer": 1,
                      "name": "input0",
                      "quantization": {
                        "details_type": "NONE",
                        "quantized_dimension": 0
                      },
                      "is_variable": false
                    },
                    {
                      "shape": )" + alphaShape + R"(,
                      "type": "FLOAT32",
                      "buffer": 2,
                      "name": "input1",
                      "quantization": {
                        "details_type": "NONE",
                        "quantized_dimension": 0
                      },
                      "is_variable": false
                    },
                    {
                      "shape": )" + outputShape + R"(,
                      "type": "FLOAT32",
                      "buffer": 3,
                      "name": "output",
                      "quantization": {
                        "details_type": "NONE",
                        "quantized_dimension": 0
                      },
                      "is_variable": false
                    }
                  ],
                  "inputs": )" + inputIndex + R"(,
                  "outputs": [
                    2
                  ],
                  "operators": [
                    {
                      "opcode_index": 0,
                      "inputs": [
                        0,
                        1
                      ],
                      "outputs": [
                        2
                      ],
                      "builtin_options_type": "NONE",
                      "custom_options_format": "FLEXBUFFERS"
                    }
                  ],
                  "name": "main"
                }
              ],
              "description": "MLIR Converted.",
              "buffers": [
                {
                },
                {
                },
                { )" + alphaData + R"(
                },
                {
                }
              ]
            }
        )";
        // Setup() is provided by ParserFlatbuffersFixture; presumably it converts m_JsonString
        // into a flatbuffer and loads it through the TfLite parser -- confirm in the base fixture.
        Setup();
    }
};
105
106struct SimplePreluFixture : PreluFixture
107{
108 SimplePreluFixture() : PreluFixture("[ 2, 3 ]",
Narumol Prangnawarat4a4af112021-05-25 14:26:24 +0100109 "[ 1 ]",
Narumol Prangnawaratbfaee6b2021-05-24 18:50:24 +0100110 "[ 2, 3 ]",
111 "[ 0, 1 ]",
112 "") {}
113};
114
115struct PreluConstAlphaFixture : PreluFixture
116{
117 PreluConstAlphaFixture() : PreluFixture(
Narumol Prangnawarat4a4af112021-05-25 14:26:24 +0100118 "[ 1, 2, 3 ]",
119 "[ 1, 2, 3 ]",
120 "[ 1, 2, 3 ]",
Narumol Prangnawaratbfaee6b2021-05-24 18:50:24 +0100121 "[ 0 ]",
122 "\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
123};
124
Narumol Prangnawarat4a4af112021-05-25 14:26:24 +0100125struct PreluBroadcastAlphaFixture : PreluFixture
126{
127 PreluBroadcastAlphaFixture() : PreluFixture(
128 "[ 1, 1, 2, 3 ]",
129 "[ 1, 3 ]",
130 "[ 1, 1, 2, 3 ]",
131 "[ 0 ]",
132 "\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
133};
134
Narumol Prangnawaratbfaee6b2021-05-24 18:50:24 +0100135struct PreluDynamicTensorFixture : PreluFixture
136{
137 PreluDynamicTensorFixture() : PreluFixture("[ 2, 3 ]",
138 "[ 1, 1 ]",
139 "[]",
140 "[ 0 ]",
141 "\"data\": [ 0, 0, 128, 62 ]") {}
142};
143
144BOOST_FIXTURE_TEST_CASE(SimplePrelu, SimplePreluFixture)
145{
146 RunTest<2, armnn::DataType::Float32>(
147 0,
148 {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }},{"input1", { 0.25f }}},
149 {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
150}
151
152BOOST_FIXTURE_TEST_CASE(PreluConstAlpha, PreluConstAlphaFixture)
153{
Narumol Prangnawarat4a4af112021-05-25 14:26:24 +0100154 RunTest<3, armnn::DataType::Float32>(
155 0,
156 {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
157 {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
158}
159
160BOOST_FIXTURE_TEST_CASE(PreluBroadcastAlpha, PreluBroadcastAlphaFixture)
161{
162 RunTest<4, armnn::DataType::Float32>(
Narumol Prangnawaratbfaee6b2021-05-24 18:50:24 +0100163 0,
164 {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
165 {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
166}
167
168BOOST_FIXTURE_TEST_CASE(PreluDynamicTensor, PreluDynamicTensorFixture)
169{
170 RunTest<2, armnn::DataType::Float32, armnn::DataType::Float32>(
171 0,
172 {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
173 {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}},
174 true);
175}
176
177BOOST_AUTO_TEST_SUITE_END()