//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "TypeUtils.hpp"

#include <armnn/ArmNN.hpp>
#include <armnn/INetwork.hpp>

#include <backendsCommon/test/QuantizeHelper.hpp>

#include <boost/test/unit_test.hpp>

#include <map>
#include <vector>

namespace
{

using namespace armnn;

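// Builds a minimal Input + Constant -> Addition -> Output network, runs it on the given
// backends, and returns true if the computed output matches expectedOutputData exactly.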
template<typename T>
bool ConstantUsageTest(const std::vector<BackendId>& computeDevice,
                       const TensorInfo& commonTensorInfo,
                       const std::vector<T>& inputData,
                       const std::vector<T>& constantData,
                       const std::vector<T>& expectedOutputData)
{
    // Creates the runtime in which the test will run.
    IRuntime::CreationOptions options;
    IRuntimePtr runtime(IRuntime::Create(options));

    // Builds up the structure of the network.
    INetworkPtr net(INetwork::Create());

    IConnectableLayer* input = net->AddInputLayer(0);
    IConnectableLayer* constant = net->AddConstantLayer(ConstTensor(commonTensorInfo, constantData));
    IConnectableLayer* add = net->AddAdditionLayer();
    IConnectableLayer* output = net->AddOutputLayer(0);

    input->GetOutputSlot(0).Connect(add->GetInputSlot(0));
    constant->GetOutputSlot(0).Connect(add->GetInputSlot(1));
    add->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    // Sets the tensor infos on the network's output slots.
    input->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
    constant->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
    add->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);

    // Optimizes the network for the requested backends.
    IOptimizedNetworkPtr optNet = Optimize(*net, computeDevice, runtime->GetDeviceSpec());

    // Loads the optimized network into the runtime.
    NetworkId netId;
    runtime->LoadNetwork(netId, std::move(optNet));

    // Creates structures for input and output data.
    std::vector<T> outputData(inputData.size());

    InputTensors inputTensors
    {
        {0, ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())}
    };
    OutputTensors outputTensors
    {
        {0, Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
    };

    // Runs the inference.
    runtime->EnqueueWorkload(netId, inputTensors, outputTensors);

    // Checks the results.
    return outputData == expectedOutputData;
}

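// Float32 variant of the constant-usage test: adds a constant tensor to the input element-wise.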
inline bool ConstantUsageFloat32Test(const std::vector<BackendId>& backends)
{
    const TensorInfo commonTensorInfo({ 2, 3 }, DataType::Float32);

    return ConstantUsageTest(backends,
                             commonTensorInfo,
                             std::vector<float>{ 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }, // Input.
                             std::vector<float>{ 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }, // Const input.
                             std::vector<float>{ 7.f, 7.f, 7.f, 7.f, 7.f, 7.f }  // Expected output.
    );
}

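// Quantized (QAsymm8) variant of the constant-usage test; input, constant and expected
// output all share the same quantization scale and offset.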
inline bool ConstantUsageUint8Test(const std::vector<BackendId>& backends)
{
    TensorInfo commonTensorInfo({ 2, 3 }, DataType::QuantisedAsymm8);

    const float scale = 0.023529f;
    const int8_t offset = -43;

    commonTensorInfo.SetQuantizationScale(scale);
    commonTensorInfo.SetQuantizationOffset(offset);

    return ConstantUsageTest(backends,
                             commonTensorInfo,
                             QuantizedVector<uint8_t>(scale, offset, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }), // Input.
                             QuantizedVector<uint8_t>(scale, offset, { 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }), // Const input.
                             QuantizedVector<uint8_t>(scale, offset, { 7.f, 7.f, 7.f, 7.f, 7.f, 7.f })  // Expected output.
    );
}

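// Compares two boolean-like values for logical equality: any non-zero value is treated as true.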
template<typename T>
bool CompareBoolean(T a, T b)
{
    return (a == 0 && b == 0) || (a != 0 && b != 0);
}

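// Generic end-to-end harness: optimizes and loads the given network, binds the supplied
// per-id input data, runs inference on the given backends and checks every output binding
// against the expected data (boolean outputs are compared by truthiness).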
template<DataType ArmnnIType, DataType ArmnnOType,
         typename TInput = ResolveType<ArmnnIType>, typename TOutput = ResolveType<ArmnnOType>>
void EndToEndLayerTestImpl(INetworkPtr network,
                           const std::map<int, std::vector<TInput>>& inputTensorData,
                           const std::map<int, std::vector<TOutput>>& expectedOutputData,
                           std::vector<BackendId> backends)
{
    // Creates the runtime in which the test will run.
    IRuntime::CreationOptions options;
    IRuntimePtr runtime(IRuntime::Create(options));

    // Optimizes the network for the requested backends.
    IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec());

    // Loads the optimized network into the runtime.
    NetworkId netId;
    runtime->LoadNetwork(netId, std::move(optNet));

    // Binds the supplied data to the network's input bindings.
    InputTensors inputTensors;
    inputTensors.reserve(inputTensorData.size());
    for (auto&& it : inputTensorData)
    {
        inputTensors.push_back({it.first,
                                ConstTensor(runtime->GetInputTensorInfo(netId, it.first), it.second.data())});
    }

    // Allocates storage for each expected output and binds it to the network's output bindings.
    OutputTensors outputTensors;
    outputTensors.reserve(expectedOutputData.size());
    std::map<int, std::vector<TOutput>> outputStorage;
    for (auto&& it : expectedOutputData)
    {
        std::vector<TOutput> out(it.second.size());
        outputStorage.emplace(it.first, out);
        outputTensors.push_back({it.first,
                                 Tensor(runtime->GetOutputTensorInfo(netId, it.first),
                                        outputStorage.at(it.first).data())});
    }

    // Runs the inference.
    runtime->EnqueueWorkload(netId, inputTensors, outputTensors);

    // Checks the results.
    for (auto&& it : expectedOutputData)
    {
        const std::vector<TOutput>& out = outputStorage.at(it.first);
        if (ArmnnOType == DataType::Boolean)
        {
            // Boolean outputs are compared for truthiness only: any non-zero value counts as true.
            for (unsigned int i = 0; i < out.size(); ++i)
            {
                BOOST_TEST(CompareBoolean<TOutput>(it.second[i], out[i]));
            }
        }
        else
        {
            BOOST_TEST(it.second == out);
        }
    }
}
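
// Illustrative sketch only (not called by any test in this file): shows how a backend test
// might drive EndToEndLayerTestImpl with a two-input Addition network. The layer wiring,
// tensor shape and data values below are assumptions chosen for the example, not part of
// the harness itself.
inline void ExampleAdditionEndToEnd(const std::vector<BackendId>& backends)
{
    INetworkPtr net(INetwork::Create());

    // Two graph inputs feeding a single Addition layer.
    IConnectableLayer* input0 = net->AddInputLayer(0);
    IConnectableLayer* input1 = net->AddInputLayer(1);
    IConnectableLayer* add    = net->AddAdditionLayer();
    IConnectableLayer* output = net->AddOutputLayer(0);

    input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
    input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
    add->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    const TensorInfo tensorInfo({ 1, 4 }, DataType::Float32);
    input0->GetOutputSlot(0).SetTensorInfo(tensorInfo);
    input1->GetOutputSlot(0).SetTensorInfo(tensorInfo);
    add->GetOutputSlot(0).SetTensorInfo(tensorInfo);

    // Map keys are the binding ids used when the input/output layers were added.
    const std::map<int, std::vector<float>> inputData
    {
        { 0, { 1.f, 2.f, 3.f, 4.f } },
        { 1, { 4.f, 3.f, 2.f, 1.f } }
    };
    const std::map<int, std::vector<float>> expectedOutput
    {
        { 0, { 5.f, 5.f, 5.f, 5.f } }
    };

    EndToEndLayerTestImpl<DataType::Float32, DataType::Float32>(std::move(net),
                                                                inputData,
                                                                expectedOutput,
                                                                backends);
}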

} // anonymous namespace