blob: 3124498d66fedd400377c8dece6bb05c08aa8f00 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "../TfLiteParser.hpp"
#include <doctest/doctest.h>
TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
{
// Shared fixture for the OutputShapeOfSqueeze test cases.
//
// Every test case gets its own parser instance plus the 4-D input shape
// { 1, 2, 2, 1 }, whose first and last dimensions have size one.
//
// Rule of Zero: no constructor or destructor is declared. The implicitly
// generated default constructor default-constructs m_Parser and applies the
// in-class initializer to m_InputShape. Not declaring a destructor also keeps
// the implicit move operations available.
struct TfLiteParserFixture
{
    armnnTfLiteParser::TfLiteParserImpl m_Parser;
    unsigned int m_InputShape[4] = { 1, 2, 2, 1 };
};
// With an empty squeezeDims list, the expected result is that every size-one
// dimension is removed: { 1, 2, 2, 1 } becomes { 2, 2 }. The element count of 4
// is unchanged.
TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed")
{
    std::vector<uint32_t> squeezeDims = { };
    const armnn::TensorInfo inputTensorInfo(4, m_InputShape, armnn::DataType::Float32);
    const armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);

    CHECK(outputTensorInfo.GetNumElements() == 4);
    CHECK(outputTensorInfo.GetNumDimensions() == 2);
    CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
}
// The requested dimensions 1 and 2 both have size 2 in { 1, 2, 2, 1 }.
// None of them can be squeezed, so the output shape is expected to equal
// the input shape.
TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput")
{
    std::vector<uint32_t> squeezeDims = { 1, 2 };
    const armnn::TensorInfo inputTensorInfo(4, m_InputShape, armnn::DataType::Float32);
    const armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);

    CHECK(outputTensorInfo.GetNumElements() == 4);
    CHECK(outputTensorInfo.GetNumDimensions() == 4);
    CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
}
// Only requested size-one dimensions are removed. Of { 1, 3 }:
// - dimension 3 (size 1) is squeezed;
// - dimension 1 (size 2) is kept.
// Dimension 0 also has size 1 but is kept because it was not requested,
// so the expected output shape is { 1, 2, 2 }.
TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed")
{
    std::vector<uint32_t> squeezeDims = { 1, 3 };
    const armnn::TensorInfo inputTensorInfo(4, m_InputShape, armnn::DataType::Float32);
    const armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);

    CHECK(outputTensorInfo.GetNumElements() == 4);
    CHECK(outputTensorInfo.GetNumDimensions() == 3);
    CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
}
}