blob: e616158f2939d83c3bb5bb862d82bf49da170340 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include <boost/test/unit_test.hpp>

#include "../TfLiteParser.hpp"

11struct TfLiteParserFixture
12{
13
Kevin May7d96b162021-02-03 17:38:41 +000014 armnnTfLiteParser::TfLiteParserImpl m_Parser;
telsoa01c577f2c2018-08-31 09:22:23 +010015 unsigned int m_InputShape[4];
16
Kevin May7d96b162021-02-03 17:38:41 +000017 TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {}
telsoa01c577f2c2018-08-31 09:22:23 +010018 ~TfLiteParserFixture() { }
19
20};
21
22BOOST_AUTO_TEST_SUITE(TensorflowLiteParser);
23
24
25BOOST_FIXTURE_TEST_CASE( EmptySqueezeDims_OutputWithAllDimensionsSqueezed, TfLiteParserFixture )
26{
27
28 std::vector<uint32_t> squeezeDims = { };
29
30 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
31 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
32 BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
33 BOOST_TEST(outputTensorInfo.GetNumDimensions() == 2);
34 BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
35};
36
37BOOST_FIXTURE_TEST_CASE( SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput, TfLiteParserFixture )
38{
39 std::vector<uint32_t> squeezeDims = { 1, 2 };
40
41 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
42 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
43 BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
44 BOOST_TEST(outputTensorInfo.GetNumDimensions() == 4);
45 BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
46};
47
48BOOST_FIXTURE_TEST_CASE( SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed, TfLiteParserFixture )
49{
50 std::vector<uint32_t> squeezeDims = { 1, 3 };
51
52 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
53 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
54 BOOST_TEST(outputTensorInfo.GetNumElements() == 4);
55 BOOST_TEST(outputTensorInfo.GetNumDimensions() == 3);
56 BOOST_TEST((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
57};
58
59BOOST_AUTO_TEST_SUITE_END();