blob: 10e682e36a958c25eb41eff18187e8a355400a32 [file] [log] [blame]
Nina Drozd200e3802019-04-15 09:47:39 +01001//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10#include <string>
11#include <iostream>
12
13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct UnpackFixture : public ParserFlatbuffersFixture
16{
17 explicit UnpackFixture(const std::string & inputShape,
18 const unsigned int numberOfOutputs,
19 const std::string & outputShape,
20 const std::string & axis,
21 const std::string & num)
22 {
23 // As input index is 0, output indexes start at 1
24 std::string outputIndexes = "1";
25 for(unsigned int i = 1; i < numberOfOutputs; i++)
26 {
27 outputIndexes += ", " + std::to_string(i+1);
28 }
29 m_JsonString = R"(
30 {
31 "version": 3,
32 "operator_codes": [ { "builtin_code": "UNPACK" } ],
33 "subgraphs": [ {
34 "tensors": [
35 {
36 "shape": )" + inputShape + R"(,
37 "type": "FLOAT32",
38 "buffer": 0,
39 "name": "inputTensor",
40 "quantization": {
41 "min": [ 0.0 ],
42 "max": [ 255.0 ],
43 "scale": [ 1.0 ],
44 "zero_point": [ 0 ],
45 }
46 },)";
47 // Append the required number of outputs for this UnpackFixture.
48 // As input index is 0, output indexes start at 1.
49 for(unsigned int i = 0; i < numberOfOutputs; i++)
50 {
51 m_JsonString += R"(
52 {
53 "shape": )" + outputShape + R"( ,
54 "type": "FLOAT32",
55 "buffer": )" + std::to_string(i + 1) + R"(,
56 "name": "outputTensor)" + std::to_string(i + 1) + R"(",
57 "quantization": {
58 "min": [ 0.0 ],
59 "max": [ 255.0 ],
60 "scale": [ 1.0 ],
61 "zero_point": [ 0 ],
62 }
63 },)";
64 }
65 m_JsonString += R"(
66 ],
67 "inputs": [ 0 ],
68 "outputs": [ )" + outputIndexes + R"( ],
69 "operators": [
70 {
71 "opcode_index": 0,
72 "inputs": [ 0 ],
73 "outputs": [ )" + outputIndexes + R"( ],
74 "builtin_options_type": "UnpackOptions",
75 "builtin_options": {
76 "axis": )" + axis;
77
78 if(!num.empty())
79 {
80 m_JsonString += R"(,
81 "num" : )" + num;
82 }
83
84 m_JsonString += R"(
85 },
86 "custom_options_format": "FLEXBUFFERS"
87 }
88 ],
89 } ],
90 "buffers" : [
91 { },
92 { }
93 ]
94 }
95 )";
96 Setup();
97 }
98};
99
100struct DefaultUnpackAxisZeroFixture : UnpackFixture
101{
102 DefaultUnpackAxisZeroFixture() : UnpackFixture("[ 4, 1, 6 ]", 4, "[ 1, 6 ]", "0", "") {}
103};
104
105BOOST_FIXTURE_TEST_CASE(UnpackAxisZeroNumIsDefaultNotSpecified, DefaultUnpackAxisZeroFixture)
106{
107 RunTest<2, armnn::DataType::Float32>(
108 0,
109 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
110 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
111 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
112 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f } } },
113 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
114 {"outputTensor2", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
115 {"outputTensor3", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }},
116 {"outputTensor4", { 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }} });
117}
118
119BOOST_AUTO_TEST_SUITE_END()