blob: f4cdd67fb9a147122ffeb3e9c0e4b1e209a2841a [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ParserFlatbuffersFixture.hpp"
7
8TEST_SUITE("TensorflowLiteParser_BatchMatMul")
9{
10struct BatchMatMulFixture : public ParserFlatbuffersFixture
11{
12 explicit BatchMatMulFixture(const std::string &inputXShape,
13 const std::string &inputYShape,
14 const std::string &outputShape,
15 const std::string &adjX,
16 const std::string &adjY)
17 {
18 m_JsonString = R"(
19 {
20 "version": 3,
21 "operator_codes": [ { "builtin_code": "BATCH_MATMUL" } ],
22 "subgraphs": [
23 {
24 "tensors": [
25 {
26 "shape": )" + inputXShape + R"(,
27 "type": "FLOAT32",
28 "buffer": 0,
29 "name": "inputXTensor",
30 "quantization": {
31 "min": [ 0.0 ],
32 "max": [ 255.0 ],
33 "scale": [ 1.0 ],
34 "zero_point": [ 0 ],
35 }
36 },
37 {
38 "shape": )" + inputYShape + R"(,
39 "type": "FLOAT32",
40 "buffer": 1,
41 "name": "inputYTensor",
42 "quantization": {
43 "min": [ 0.0 ],
44 "max": [ 255.0 ],
45 "scale": [ 1.0 ],
46 "zero_point": [ 0 ],
47 }
48 },
49 {
50 "shape": )" + outputShape + R"(,
51 "type": "FLOAT32",
52 "buffer": 2,
53 "name": "outputTensor",
54 "quantization": {
55 "min": [ 0.0 ],
56 "max": [ 255.0 ],
57 "scale": [ 1.0 ],
58 "zero_point": [ 0 ],
59 }
60 }
61 ],
62 "inputs": [ 0, 1 ],
63 "outputs": [ 2 ],
64 "operators": [
65 {
66 "opcode_index": 0,
67 "inputs": [ 0 , 1 ],
68 "outputs": [ 2 ],
69 "builtin_options_type": "BatchMatMulOptions",
70 "builtin_options": {
71 adj_x: )" + adjX + R"(,
72 adj_y: )" + adjY + R"(,
73 "asymmetric_quantize_inputs": false
74 },
75 "custom_options_format": "FLEXBUFFERS"
76 }
77 ]
78 }
79 ],
80 "buffers": [{},{}]
81 }
82 )";
83 Setup();
84 }
85};
86
87struct BatchMatMulParamsFixture : BatchMatMulFixture
88{
89 BatchMatMulParamsFixture()
90 : BatchMatMulFixture("[ 1, 3, 3 ]",
91 "[ 1, 3, 3 ]",
92 "[ 1, 3, 3 ]",
93 "false",
94 "true")
95 {}
96};
97
98TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams")
99{
100 RunTest<3, armnn::DataType::Float32>(
101 0,
102 {{"inputXTensor", {2.0f, 3.0f, 5.0f,
103 8.0f, 13.0f, 21.0f,
104 34.0f, 55.0f, 89.0f}},
105 {"inputYTensor", {0.0f, 1.0f, 1.0f,
106 1.0f, 0.0f, 1.0f,
107 1.0f, 1.0f, 0.0f}}},
108 {{"outputTensor", {6.0f, 4.0f, 0.0f,
109 26.0f, 16.0f, 0.0f,
110 110.0f, 68.0f, 0.0f}}}
111 );
112}
113
114}