blob: a9f021f76e0386d60a3a858a13f5a67b32da85bb [file] [log] [blame]
Teresa Charlin3ab85482021-06-08 16:59:29 +01001//
2// Copyright © 2021 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include "ParserFlatbuffersFixture.hpp"
6#include "../TfLiteParser.hpp"
7
8#include <string>
9#include <iostream>
10
11TEST_SUITE("TensorflowLiteParser_ExpandDims")
12{
13struct ExpandDimsFixture : public ParserFlatbuffersFixture
14{
15 explicit ExpandDimsFixture(const std::string& inputShape,
16 const std::string& outputShape,
17 const std::string& axis)
18 {
19 m_JsonString = R"(
20 {
21 "version": 3,
22 "operator_codes": [ { "builtin_code": "EXPAND_DIMS" } ],
23 "subgraphs": [ {
24 "tensors": [
25 {
26 "shape": )" + inputShape + R"(,
27 "type": "UINT8",
28 "buffer": 0,
29 "name": "inputTensor",
30 "quantization": {
31 "min": [ 0.0 ],
32 "max": [ 255.0 ],
33 "scale": [ 1.0 ],
34 "zero_point": [ 0 ],
35 }
36 },
37 {
38 "shape": )" + outputShape + R"( ,
39 "type": "UINT8",
40 "buffer": 1,
41 "name": "outputTensor",
42 "quantization": {
43 "min": [ 0.0 ],
44 "max": [ 255.0 ],
45 "scale": [ 1.0 ],
46 "zero_point": [ 0 ],
47 }
48 },
49 {
50 "shape": [ 1 ],
51 "type": "UINT8",
52 "buffer": 2,
53 "name": "expand_dims",
54 "quantization": {
55 "min": [ 0.0 ],
56 "max": [ 255.0 ],
57 "scale": [ 1.0 ],
58 "zero_point": [ 0 ],
59 }
60 },
61 ],
62 "inputs": [ 0 ],
63 "outputs": [ 1 ],
64 "operators": [
65 {
66 "opcode_index": 0,
67 "inputs": [ 0 , 2 ],
68 "outputs": [ 1 ],
69 "custom_options_format": "FLEXBUFFERS"
70 }
71 ],
72 } ],
73 "buffers" : [
74 { },
75 { },
76 { "data": )" + axis + R"(, },
77 ]
78 }
79 )";
80 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
81 }
82};
83
84struct ExpandDimsFixture3dto4Daxis0 : ExpandDimsFixture
85{
86 ExpandDimsFixture3dto4Daxis0() : ExpandDimsFixture("[ 2, 2, 1 ]", "[ 1, 2, 2, 1 ]", "[ 0, 0, 0, 0 ]") {}
87};
88
89TEST_CASE_FIXTURE(ExpandDimsFixture3dto4Daxis0, "ParseExpandDims3Dto4Daxis0")
90{
91 RunTest<4, armnn::DataType::QAsymmU8>(0, {{ "inputTensor", { 1, 2, 3, 4 } } },
92 {{ "outputTensor", { 1, 2, 3, 4 } } });
93}
94
95struct ExpandDimsFixture3dto4Daxis3 : ExpandDimsFixture
96{
97 ExpandDimsFixture3dto4Daxis3() : ExpandDimsFixture("[ 1, 2, 2 ]", "[ 1, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]") {}
98};
99
100TEST_CASE_FIXTURE(ExpandDimsFixture3dto4Daxis3, "ParseExpandDims3Dto4Daxis3")
101{
102 RunTest<4, armnn::DataType::QAsymmU8>(0, {{ "inputTensor", { 1, 2, 3, 4 } } },
103 {{ "outputTensor", { 1, 2, 3, 4 } } });
104}
105
106}