blob: 5f17f782f3fa11b0bcb5a64f4dacb7f7872ee9e3 [file] [log] [blame]
Aron Virginas-Tar70104002018-10-24 15:33:28 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <armnn/ArmNN.hpp>
8
9#include <backends/test/QuantizeHelper.hpp>
10
11#include <vector>
12
13namespace
14{
15
16using namespace armnn;
17
18template<typename T>
19bool ConstantUsageTest(const std::vector<BackendId>& computeDevice,
20 const TensorInfo& commonTensorInfo,
21 const std::vector<T>& inputData,
22 const std::vector<T>& constantData,
23 const std::vector<T>& expectedOutputData)
24{
25 // Create runtime in which test will run
26 IRuntime::CreationOptions options;
27 IRuntimePtr runtime(IRuntime::Create(options));
28
29 // Builds up the structure of the network.
30 INetworkPtr net(INetwork::Create());
31
32 IConnectableLayer* input = net->AddInputLayer(0);
33 IConnectableLayer* constant = net->AddConstantLayer(ConstTensor(commonTensorInfo, constantData));
34 IConnectableLayer* add = net->AddAdditionLayer();
35 IConnectableLayer* output = net->AddOutputLayer(0);
36
37 input->GetOutputSlot(0).Connect(add->GetInputSlot(0));
38 constant->GetOutputSlot(0).Connect(add->GetInputSlot(1));
39 add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
40
41 // Sets the tensors in the network.
42 input->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
43 constant->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
44 add->GetOutputSlot(0).SetTensorInfo(commonTensorInfo);
45
46 // optimize the network
47 IOptimizedNetworkPtr optNet = Optimize(*net, computeDevice, runtime->GetDeviceSpec());
48
49 // Loads it into the runtime.
50 NetworkId netId;
51 runtime->LoadNetwork(netId, std::move(optNet));
52
53 // Creates structures for input & output.
54 std::vector<T> outputData(inputData.size());
55
56 InputTensors inputTensors
57 {
58 {0, ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())}
59 };
60 OutputTensors outputTensors
61 {
62 {0, Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
63 };
64
65 // Does the inference.
66 runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
67
68 // Checks the results.
69 return outputData == expectedOutputData;
70}
71
72inline bool ConstantUsageFloat32Test(const std::vector<BackendId>& backends)
73{
74 const TensorInfo commonTensorInfo({ 2, 3 }, DataType::Float32);
75
76 return ConstantUsageTest(backends,
77 commonTensorInfo,
78 std::vector<float>{ 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }, // Input.
79 std::vector<float>{ 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }, // Const input.
80 std::vector<float>{ 7.f, 7.f, 7.f, 7.f, 7.f, 7.f } // Expected output.
81 );
82}
83
84inline bool ConstantUsageUint8Test(const std::vector<BackendId>& backends)
85{
86 TensorInfo commonTensorInfo({ 2, 3 }, DataType::QuantisedAsymm8);
87
88 const float scale = 0.023529f;
89 const int8_t offset = -43;
90
91 commonTensorInfo.SetQuantizationScale(scale);
92 commonTensorInfo.SetQuantizationOffset(offset);
93
94 return ConstantUsageTest(backends,
95 commonTensorInfo,
96 QuantizedVector<uint8_t>(scale, offset, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }), // Input.
97 QuantizedVector<uint8_t>(scale, offset, { 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }), // Const input.
98 QuantizedVector<uint8_t>(scale, offset, { 7.f, 7.f, 7.f, 7.f, 7.f, 7.f }) // Expected output.
99 );
100}
101
102} // anonymous namespace