blob: 395038d9593ddd437c1338c08948281e96b7a111 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "../TfLiteParser.hpp"
7#include <iostream>
8#include <string>
9
Sadik Armagan1625efc2021-06-10 18:24:34 +010010#include <doctest/doctest.h>
11
12TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
13{
14
telsoa01c577f2c2018-08-31 09:22:23 +010015struct TfLiteParserFixture
16{
17
Kevin May7d96b162021-02-03 17:38:41 +000018 armnnTfLiteParser::TfLiteParserImpl m_Parser;
telsoa01c577f2c2018-08-31 09:22:23 +010019 unsigned int m_InputShape[4];
20
Kevin May7d96b162021-02-03 17:38:41 +000021 TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {}
telsoa01c577f2c2018-08-31 09:22:23 +010022 ~TfLiteParserFixture() { }
23
24};
25
Sadik Armagan1625efc2021-06-10 18:24:34 +010026TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed")
telsoa01c577f2c2018-08-31 09:22:23 +010027{
28
29 std::vector<uint32_t> squeezeDims = { };
30
31 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
32 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
Sadik Armagan1625efc2021-06-10 18:24:34 +010033 CHECK(outputTensorInfo.GetNumElements() == 4);
34 CHECK(outputTensorInfo.GetNumDimensions() == 2);
35 CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
telsoa01c577f2c2018-08-31 09:22:23 +010036};
37
Sadik Armagan1625efc2021-06-10 18:24:34 +010038TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput")
telsoa01c577f2c2018-08-31 09:22:23 +010039{
40 std::vector<uint32_t> squeezeDims = { 1, 2 };
41
42 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
43 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
Sadik Armagan1625efc2021-06-10 18:24:34 +010044 CHECK(outputTensorInfo.GetNumElements() == 4);
45 CHECK(outputTensorInfo.GetNumDimensions() == 4);
46 CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
telsoa01c577f2c2018-08-31 09:22:23 +010047};
48
Sadik Armagan1625efc2021-06-10 18:24:34 +010049TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed")
telsoa01c577f2c2018-08-31 09:22:23 +010050{
51 std::vector<uint32_t> squeezeDims = { 1, 3 };
52
53 armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
54 armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
Sadik Armagan1625efc2021-06-10 18:24:34 +010055 CHECK(outputTensorInfo.GetNumElements() == 4);
56 CHECK(outputTensorInfo.GetNumDimensions() == 3);
57 CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
telsoa01c577f2c2018-08-31 09:22:23 +010058};
59
Sadik Armagan1625efc2021-06-10 18:24:34 +010060}