blob: 3124498d66fedd400377c8dece6bb05c08aa8f00 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Finn Williamsb49ed182021-06-29 15:50:08 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "../TfLiteParser.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01007
Sadik Armagan1625efc2021-06-10 18:24:34 +01008#include <doctest/doctest.h>
9
10TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
11{
12
telsoa01c577f2c2018-08-31 09:22:23 +010013struct TfLiteParserFixture
14{
15
Kevin May7d96b162021-02-03 17:38:41 +000016 armnnTfLiteParser::TfLiteParserImpl m_Parser;
telsoa01c577f2c2018-08-31 09:22:23 +010017 unsigned int m_InputShape[4];
18
Kevin May7d96b162021-02-03 17:38:41 +000019 TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {}
telsoa01c577f2c2018-08-31 09:22:23 +010020 ~TfLiteParserFixture() { }
21
22};
23
Sadik Armagan1625efc2021-06-10 18:24:34 +010024TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed")
telsoa01c577f2c2018-08-31 09:22:23 +010025{
26
27 std::vector<uint32_t> squeezeDims = { };
28
29 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
30 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
Sadik Armagan1625efc2021-06-10 18:24:34 +010031 CHECK(outputTensorInfo.GetNumElements() == 4);
32 CHECK(outputTensorInfo.GetNumDimensions() == 2);
33 CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
telsoa01c577f2c2018-08-31 09:22:23 +010034};
35
Sadik Armagan1625efc2021-06-10 18:24:34 +010036TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput")
telsoa01c577f2c2018-08-31 09:22:23 +010037{
38 std::vector<uint32_t> squeezeDims = { 1, 2 };
39
40 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
41 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
Sadik Armagan1625efc2021-06-10 18:24:34 +010042 CHECK(outputTensorInfo.GetNumElements() == 4);
43 CHECK(outputTensorInfo.GetNumDimensions() == 4);
44 CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
telsoa01c577f2c2018-08-31 09:22:23 +010045};
46
Sadik Armagan1625efc2021-06-10 18:24:34 +010047TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed")
telsoa01c577f2c2018-08-31 09:22:23 +010048{
49 std::vector<uint32_t> squeezeDims = { 1, 3 };
50
51 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
52 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
Sadik Armagan1625efc2021-06-10 18:24:34 +010053 CHECK(outputTensorInfo.GetNumElements() == 4);
54 CHECK(outputTensorInfo.GetNumDimensions() == 3);
55 CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
telsoa01c577f2c2018-08-31 09:22:23 +010056};
57
Sadik Armagan1625efc2021-06-10 18:24:34 +010058}