blob: 40f93ce420a27b033311eab9738bf63ef3ca45d1 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ParserFlatbuffersSerializeFixture.hpp"
7#include <armnnDeserializer/IDeserializer.hpp>
8
9#include <doctest/doctest.h>
10
11#include <string>
12
13TEST_SUITE("Deserializer_BatchMatMul")
14{
15struct BatchMatMulFixture : public ParserFlatbuffersSerializeFixture
16{
17 explicit BatchMatMulFixture(const std::string& inputXShape,
18 const std::string& inputYShape,
19 const std::string& outputShape,
20 const std::string& dataType)
21 {
22 m_JsonString = R"(
23 {
24 inputIds:[
25 0,
26 1
27 ],
28 outputIds:[
29 3
30 ],
31 layers:[
32 {
33 layer_type:"InputLayer",
34 layer:{
35 base:{
36 layerBindingId:0,
37 base:{
38 index:0,
39 layerName:"InputXLayer",
40 layerType:"Input",
41 inputSlots:[
42 {
43 index:0,
44 connection:{
45 sourceLayerIndex:0,
46 outputSlotIndex:0
47 },
48
49 }
50 ],
51 outputSlots:[
52 {
53 index:0,
54 tensorInfo:{
55 dimensions:)" + inputXShape + R"(,
56 dataType:)" + dataType + R"(
57 },
58
59 }
60 ],
61
62 },
63
64 }
65 },
66
67 },
68 {
69 layer_type:"InputLayer",
70 layer:{
71 base:{
72 layerBindingId:1,
73 base:{
74 index:1,
75 layerName:"InputYLayer",
76 layerType:"Input",
77 inputSlots:[
78 {
79 index:0,
80 connection:{
81 sourceLayerIndex:0,
82 outputSlotIndex:0
83 },
84
85 }
86 ],
87 outputSlots:[
88 {
89 index:0,
90 tensorInfo:{
91 dimensions:)" + inputYShape + R"(,
92 dataType:)" + dataType + R"(
93 },
94
95 }
96 ],
97
98 },
99
100 }
101 },
102
103 },
104 {
105 layer_type:"BatchMatMulLayer",
106 layer:{
107 base:{
108 index:2,
109 layerName:"BatchMatMulLayer",
110 layerType:"BatchMatMul",
111 inputSlots:[
112 {
113 index:0,
114 connection:{
115 sourceLayerIndex:0,
116 outputSlotIndex:0
117 },
118
119 },
120 {
121 index:1,
122 connection:{
123 sourceLayerIndex:1,
124 outputSlotIndex:0
125 },
126
127 }
128 ],
129 outputSlots:[
130 {
131 index:0,
132 tensorInfo:{
133 dimensions:)" + outputShape + R"(,
134 dataType:)" + dataType + R"(
135 },
136
137 }
138 ],
139
140 },
141 descriptor:{
142 transposeX:false,
143 transposeY:false,
144 adjointX:false,
145 adjointY:false,
146 dataLayoutX:NHWC,
147 dataLayoutY:NHWC
148 }
149 },
150
151 },
152 {
153 layer_type:"OutputLayer",
154 layer:{
155 base:{
156 layerBindingId:0,
157 base:{
158 index:3,
159 layerName:"OutputLayer",
160 layerType:"Output",
161 inputSlots:[
162 {
163 index:0,
164 connection:{
165 sourceLayerIndex:2,
166 outputSlotIndex:0
167 },
168
169 }
170 ],
171 outputSlots:[
172 {
173 index:0,
174 tensorInfo:{
175 dimensions:)" + outputShape + R"(,
176 dataType:)" + dataType + R"(
177 },
178
179 }
180 ],
181
182 }
183 }
184 },
185
186 }
187 ]
188 }
189 )";
190 Setup();
191 }
192};
193
194struct SimpleBatchMatMulFixture : BatchMatMulFixture
195{
196 SimpleBatchMatMulFixture()
197 : BatchMatMulFixture("[ 1, 2, 2, 1 ]",
198 "[ 1, 2, 2, 1 ]",
199 "[ 1, 2, 2, 1 ]",
200 "Float32")
201 {}
202};
203
204TEST_CASE_FIXTURE(SimpleBatchMatMulFixture, "SimpleBatchMatMulTest")
205{
206 RunTest<4, armnn::DataType::Float32>(
207 0,
208 {{"InputXLayer", { 1.0f, 2.0f, 3.0f, 4.0f }},
209 {"InputYLayer", { 5.0f, 6.0f, 7.0f, 8.0f }}},
210 {{"OutputLayer", { 19.0f, 22.0f, 43.0f, 50.0f }}});
211}
212
213}