blob: 590675b46c0de44689ff85b075c59018bb493bad [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "../TfLiteParser.hpp"
8#include <iostream>
9#include <string>
10
11struct TfLiteParserFixture
12{
13
14 armnnTfLiteParser::TfLiteParser m_Parser;
15 unsigned int m_InputShape[4];
16
17 TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {
18 m_Parser.Create();
19 }
20 ~TfLiteParserFixture() { }
21
22};
23
24BOOST_AUTO_TEST_SUITE(TensorflowLiteParser);
25
26
27BOOST_FIXTURE_TEST_CASE( EmptySqueezeDims_OutputWithAllDimensionsSqueezed, TfLiteParserFixture )
28{
29
30 std::vector<uint32_t> squeezeDims = { };
31
32 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
33 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
34 BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
35 BOOST_TEST(outputTensorInfo.GetNumDimensions() == 2);
36 BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
37};
38
39BOOST_FIXTURE_TEST_CASE( SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput, TfLiteParserFixture )
40{
41 std::vector<uint32_t> squeezeDims = { 1, 2 };
42
43 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
44 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
45 BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
46 BOOST_TEST(outputTensorInfo.GetNumDimensions() == 4);
47 BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
48};
49
50BOOST_FIXTURE_TEST_CASE( SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed, TfLiteParserFixture )
51{
52 std::vector<uint32_t> squeezeDims = { 1, 3 };
53
54 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
55 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
56 BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
57 BOOST_TEST(outputTensorInfo.GetNumDimensions() == 3);
58 BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
59};
60
61BOOST_AUTO_TEST_SUITE_END();