blob: d437f4da7778b1214dc42d6edfbfe25122bf838f [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <reference/workloads/Decoders.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <fmt/format.h>
#include <doctest/doctest.h>
#include <chrono>
// Compares two vectors element by element and fails the current doctest test case if they differ.
//
// Every mismatching index is reported with MESSAGE before failing, so one run shows all of the
// differences rather than only the first. The tests in this file pass the values produced by the
// code under test as vec1 and the expected values as vec2.
//
// vec1: first vector to compare.
// vec2: second vector to compare.
template<typename T>
void CompareVector(const std::vector<T>& vec1, const std::vector<T>& vec2)
{
    CHECK(vec1.size() == vec2.size());

    // CHECK records a failure but does not stop the test, so only compare the positions that exist
    // in both vectors: indexing past the end of the shorter one would be undefined behaviour.
    // A size difference on its own already counts as a mismatch.
    const typename std::vector<T>::size_type commonSize = vec1.size() < vec2.size() ? vec1.size() : vec2.size();
    bool mismatch = vec1.size() != vec2.size();
    for (typename std::vector<T>::size_type i = 0; i < commonSize; ++i)
    {
        if (vec1[i] != vec2[i])
        {
            MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
                                i,
                                vec1[i],
                                vec2[i]));
            mismatch = true;
        }
    }
    if (mismatch)
    {
        FAIL("Error in CompareVector. Vectors don't match.");
    }
}
using namespace armnn;
// Test double for PerAxisIterator: a per-axis int8_t "decoder" that performs no decoding or
// dequantization. Get() returns the raw element at the current position, which lets the tests check
// both where the iterator points and the axis index the base class tracks for that position.
class MockPerAxisIterator : public PerAxisIterator<const int8_t, Decoder<int8_t>>
{
public:
    // data:        first element of the tensor buffer. This class does not copy it; presumably the
    //              base iterator keeps the pointer, so the buffer should outlive the iterator — TODO confirm.
    // tensorShape: shape handed to the base iterator; its element count also bounds Loop().
    // axis:        dimension handed to PerAxisIterator; GetAxisIndex() exposes the index it tracks.
    MockPerAxisIterator(const int8_t* data, const armnn::TensorShape& tensorShape, const unsigned int axis)
        : PerAxisIterator(data, tensorShape, axis), m_NumElements(tensorShape.GetNumElements())
    {}

    // Returns the raw (undecoded) element the iterator currently points at.
    int8_t Get() const override
    {
        return *m_Iterator;
    }

    // Stub required by the Decoder interface; not called by these tests. Ignores its arguments and
    // returns an empty vector.
    std::vector<float> DecodeTensor(const TensorShape& tensorShape,
                                    bool isDepthwise = false) override
    {
        IgnoreUnused(tensorShape, isDepthwise);
        return std::vector<float>{};
    }

    // Positions the iterator on each index 0..m_NumElements-1 via operator[] and collects the value
    // read there — presumably the tensor contents in flat-index order.
    std::vector<int8_t> Loop()
    {
        std::vector<int8_t> values;
        values.reserve(m_NumElements);
        for (unsigned int i = 0; i < m_NumElements; ++i)
        {
            this->operator[](i);
            values.emplace_back(Get());
        }
        return values;
    }

    // Axis index held by the base iterator (m_AxisIndex) for its current position.
    unsigned int GetAxisIndex() const
    {
        return m_AxisIndex;
    }

    // Total element count of the shape given at construction; bounds Loop().
    unsigned int m_NumElements;
};
TEST_SUITE("RefPerAxisIterator")
{
// In every case below the expected axis index equals
//   (flat index / product of the dimensions after the axis) % size of the axis dimension.
// Fixture values equal their flat index, so reading the tensor back in order must return it unchanged.

// Test Loop() (a stand-in for a decoder's element-by-element walk) with axis = 0.
TEST_CASE("PerAxisIteratorTest1")
{
    const std::vector<int8_t> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
    const TensorInfo info({3, 1, 2, 2}, DataType::QSymmS8);

    MockPerAxisIterator iterator(data.data(), info.GetShape(), 0);
    CompareVector(iterator.Loop(), data);

    // {flat index, expected axis index}, visited in this order with operator[].
    const unsigned int probes[][2] = {{5, 1}, {1, 0}, {10, 2}};
    for (const auto& probe : probes)
    {
        iterator[probe[0]];
        CHECK(iterator.GetAxisIndex() == probe[1]);
    }
}

// Axis = 1, a dimension of size 1, so the axis index is always 0.
TEST_CASE("PerAxisIteratorTest2")
{
    const std::vector<int8_t> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
    const TensorInfo info({3, 1, 2, 2}, DataType::QSymmS8);

    MockPerAxisIterator iterator(data.data(), info.GetShape(), 1);
    CompareVector(iterator.Loop(), data);

    // {flat index, expected axis index}, visited in this order with operator[].
    const unsigned int probes[][2] = {{5, 0}, {1, 0}, {10, 0}};
    for (const auto& probe : probes)
    {
        iterator[probe[0]];
        CHECK(iterator.GetAxisIndex() == probe[1]);
    }
}

// Axis = 2.
TEST_CASE("PerAxisIteratorTest3")
{
    const std::vector<int8_t> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
    const TensorInfo info({3, 1, 2, 2}, DataType::QSymmS8);

    MockPerAxisIterator iterator(data.data(), info.GetShape(), 2);
    CompareVector(iterator.Loop(), data);

    // {flat index, expected axis index}, visited in this order with operator[].
    const unsigned int probes[][2] = {{5, 0}, {1, 0}, {10, 1}};
    for (const auto& probe : probes)
    {
        iterator[probe[0]];
        CHECK(iterator.GetAxisIndex() == probe[1]);
    }
}

// Axis = 3, the innermost dimension.
TEST_CASE("PerAxisIteratorTest4")
{
    const std::vector<int8_t> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
    const TensorInfo info({3, 1, 2, 2}, DataType::QSymmS8);

    MockPerAxisIterator iterator(data.data(), info.GetShape(), 3);
    CompareVector(iterator.Loop(), data);

    // {flat index, expected axis index}, visited in this order with operator[].
    const unsigned int probes[][2] = {{5, 1}, {1, 1}, {10, 0}};
    for (const auto& probe : probes)
    {
        iterator[probe[0]];
        CHECK(iterator.GetAxisIndex() == probe[1]);
    }
}

// Axis = 1 on a different tensor shape, {2,2,2,2}.
TEST_CASE("PerAxisIteratorTest5")
{
    const std::vector<int8_t> data =
    {
         0,  1,  2,  3,
         4,  5,  6,  7,
         8,  9, 10, 11,
        12, 13, 14, 15
    };
    const TensorInfo info({2, 2, 2, 2}, DataType::QSymmS8);

    MockPerAxisIterator iterator(data.data(), info.GetShape(), 1);
    CompareVector(iterator.Loop(), data);

    // {flat index, expected axis index}, visited in this order with operator[].
    const unsigned int probes[][2] = {{5, 1}, {1, 0}, {10, 0}};
    for (const auto& probe : probes)
    {
        iterator[probe[0]];
        CHECK(iterator.GetAxisIndex() == probe[1]);
    }
}

// Relative movement with operator+= and operator-=, axis = 2.
TEST_CASE("PerAxisIteratorTest7")
{
    const std::vector<int8_t> data =
    {
        0, 1,  2,  3,
        4, 5,  6,  7,
        8, 9, 10, 11
    };
    const TensorInfo info({3, 1, 2, 2}, DataType::QSymmS8);

    MockPerAxisIterator iterator(data.data(), info.GetShape(), 2);

    // One move of the iterator, followed by the flat position it should land on and the axis index
    // expected there. Walk: 0 -> 3 -> 6 -> 4 -> 3.
    struct Step
    {
        bool forward;
        unsigned int distance;
        unsigned int element;
        unsigned int axisIndex;
    };
    const Step steps[] =
    {
        {true,  3, 3, 1},
        {true,  3, 6, 1},
        {false, 2, 4, 0},
        {false, 1, 3, 1},
    };
    for (const Step& step : steps)
    {
        if (step.forward)
        {
            iterator += step.distance;
        }
        else
        {
            iterator -= step.distance;
        }
        CHECK(iterator.Get() == data[step.element]);
        CHECK(iterator.GetAxisIndex() == step.axisIndex);
    }
}
}