blob: cfb62637d1fc3e413f8d7e4412da7b4373e3b5b4 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ConstantTestImpl.hpp"
7
Aron Virginas-Tar48623a02019-10-22 10:00:28 +01008#include <QuantizeHelper.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01009#include <ResolveType.hpp>
10
11#include <armnn/ArmNN.hpp>
12
Matteo Martincighe011d202019-11-28 11:35:47 +000013#include <armnnUtils/Permute.hpp>
14
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010015#include <backendsCommon/CpuTensorHandle.hpp>
16
17#include <backendsCommon/test/TensorCopyUtils.hpp>
18#include <backendsCommon/test/WorkloadTestUtils.hpp>
19
20#include <test/TensorHelpers.hpp>
21
22namespace
23{
24
25template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
26LayerTestResult<T, 4> ConstantTestImpl(
27 armnn::IWorkloadFactory& workloadFactory,
28 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
29 float qScale,
30 int32_t qOffset)
31{
Derek Lambertic374ff02019-12-10 21:57:35 +000032 boost::ignore_unused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010033 constexpr unsigned int inputWidth = 3;
34 constexpr unsigned int inputHeight = 4;
35 constexpr unsigned int inputChannels = 3;
36 constexpr unsigned int inputBatchSize = 2;
37
38 constexpr unsigned int outputWidth = inputWidth;
39 constexpr unsigned int outputHeight = inputHeight;
40 constexpr unsigned int outputChannels = inputChannels;
41 constexpr unsigned int outputBatchSize = inputBatchSize;
42
43 armnn::TensorInfo inputTensorInfo({ inputBatchSize, inputChannels, inputHeight, inputWidth },
44 ArmnnType, qScale, qOffset);
45
46 armnn::TensorInfo outputTensorInfo({ outputBatchSize, outputChannels, outputHeight, outputWidth },
47 ArmnnType, qScale, qOffset);
48
49 // Set quantization parameters if the requested type is a quantized type.
50 if(armnn::IsQuantizedType<T>())
51 {
52 inputTensorInfo.SetQuantizationScale(qScale);
53 inputTensorInfo.SetQuantizationOffset(qOffset);
54 outputTensorInfo.SetQuantizationScale(qScale);
55 outputTensorInfo.SetQuantizationOffset(qOffset);
56 }
57
58 auto input = MakeTensor<T, 4>(inputTensorInfo, std::vector<T>(
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010059 armnnUtils::QuantizedVector<T>(
60 {
61 // Batch 0, Channel 0
62 235.0f, 46.0f, 178.0f,
63 100.0f, 123.0f, 19.0f,
64 172.0f, 74.0f, 250.0f,
65 6.0f, 195.0f, 80.0f,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010066
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010067 // Batch 0, Channel 1
68 113.0f, 95.0f, 202.0f,
69 77.0f, 114.0f, 71.0f,
70 122.0f, 246.0f, 166.0f,
71 82.0f, 28.0f, 37.0f,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010072
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010073 // Batch 0, Channel 2
74 56.0f, 170.0f, 162.0f,
75 194.0f, 89.0f, 254.0f,
76 12.0f, 209.0f, 200.0f,
77 1.0f, 64.0f, 54.0f,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010078
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010079 // Batch 1, Channel 0
80 67.0f, 90.0f, 49.0f,
81 7.0f, 163.0f, 18.0f,
82 25.0f, 117.0f, 103.0f,
83 247.0f, 59.0f, 189.0f,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010084
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010085 // Batch 1, Channel 1
86 239.0f, 104.0f, 199.0f,
87 17.0f, 124.0f, 153.0f,
88 222.0f, 217.0f, 75.0f,
89 32.0f, 126.0f, 21.0f,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010090
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010091 // Batch 1, Channel 2
92 97.0f, 145.0f, 215.0f,
93 115.0f, 116.0f, 238.0f,
94 226.0f, 16.0f, 132.0f,
95 92.0f, 125.0f, 88.0f,
96 },
97 qScale, qOffset)));
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010098
99 LayerTestResult<T, 4> result(outputTensorInfo);
100 result.outputExpected = input;
101
102 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
103
104 armnn::ScopedCpuTensorHandle constantTensor(inputTensorInfo);
105 AllocateAndCopyDataToITensorHandle(&constantTensor, &input[0][0][0][0]);
106
107 armnn::ConstantQueueDescriptor descriptor;
108 descriptor.m_LayerOutput = &constantTensor;
109
110 armnn::WorkloadInfo info;
111 AddOutputToWorkload(descriptor, info, outputTensorInfo, outputHandle.get());
112
113 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateConstant(descriptor, info);
114
115 outputHandle->Allocate();
116
117 workload->PostAllocationConfigure();
118 workload->Execute();
119
120 CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
121 return result;
122}
123
124} // anonymous namespace
125
126LayerTestResult<float, 4> ConstantTest(
127 armnn::IWorkloadFactory& workloadFactory,
128 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
129{
130 return ConstantTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, 0.0f, 0);
131}
132
133LayerTestResult<int16_t, 4> ConstantInt16SimpleQuantizationScaleNoOffsetTest(
134 armnn::IWorkloadFactory& workloadFactory,
135 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
136{
137 return ConstantTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, 1.0f, 0);
138}
139
140LayerTestResult<uint8_t, 4> ConstantUint8SimpleQuantizationScaleNoOffsetTest(
141 armnn::IWorkloadFactory& workloadFactory,
142 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
143{
144 return ConstantTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, 1.0f, 0);
145}
146
147LayerTestResult<uint8_t, 4> ConstantUint8CustomQuantizationScaleAndOffsetTest(
148 armnn::IWorkloadFactory& workloadFactory,
149 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
150{
151 return ConstantTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, 2e-6f, 1);
152}
153
154LayerTestResult<int16_t, 4> ConstantInt16CustomQuantizationScaleAndOffsetTest(
155 armnn::IWorkloadFactory& workloadFactory,
156 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
157{
158 return ConstantTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, 2e-6f, 1);
159}