blob: fa933a0ec3a6c3ef3caac8b07de4bd71d7408c56 [file] [log] [blame]
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +00001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <aclCommon/ArmComputeTensorUtils.hpp>
7
Sadik Armagan1625efc2021-06-10 18:24:34 +01008#include <doctest/doctest.h>
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +00009
10using namespace armnn::armcomputetensorutils;
11
Sadik Armagan1625efc2021-06-10 18:24:34 +010012TEST_SUITE("ArmComputeTensorUtils")
13{
14TEST_CASE("BuildArmComputeTensorInfoTest")
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +000015{
16
17 const armnn::TensorShape tensorShape = { 1, 2, 3, 4 };
Derek Lambertif90c56d2020-01-10 17:14:08 +000018 const armnn::DataType dataType = armnn::DataType::QAsymmU8;
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +000019
20 const std::vector<float> quantScales = { 1.5f, 2.5f, 3.5f, 4.5f };
21 const float quantScale = quantScales[0];
22 const int32_t quantOffset = 128;
23
24 // Tensor info with per-tensor quantization
25 const armnn::TensorInfo tensorInfo0(tensorShape, dataType, quantScale, quantOffset);
26 const arm_compute::TensorInfo aclTensorInfo0 = BuildArmComputeTensorInfo(tensorInfo0);
27
28 const arm_compute::TensorShape& aclTensorShape = aclTensorInfo0.tensor_shape();
Sadik Armagan1625efc2021-06-10 18:24:34 +010029 CHECK(aclTensorShape.num_dimensions() == tensorShape.GetNumDimensions());
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +000030 for(unsigned int i = 0u; i < tensorShape.GetNumDimensions(); ++i)
31 {
32 // NOTE: arm_compute tensor dimensions are stored in the opposite order
Sadik Armagan1625efc2021-06-10 18:24:34 +010033 CHECK(aclTensorShape[i] == tensorShape[tensorShape.GetNumDimensions() - i - 1]);
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +000034 }
35
Sadik Armagan1625efc2021-06-10 18:24:34 +010036 CHECK(aclTensorInfo0.data_type() == arm_compute::DataType::QASYMM8);
37 CHECK(aclTensorInfo0.quantization_info().scale()[0] == quantScale);
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +000038
39 // Tensor info with per-axis quantization
40 const armnn::TensorInfo tensorInfo1(tensorShape, dataType, quantScales, 0);
41 const arm_compute::TensorInfo aclTensorInfo1 = BuildArmComputeTensorInfo(tensorInfo1);
42
Sadik Armagan1625efc2021-06-10 18:24:34 +010043 CHECK(aclTensorInfo1.quantization_info().scale() == quantScales);
Aron Virginas-Tar13b653f2019-11-01 11:40:39 +000044}
45
Sadik Armagan1625efc2021-06-10 18:24:34 +010046}