blob: 66ccdbf1d9af048ddd444dd4c289fefe55055aef [file] [log] [blame]
Mike Kelly386ff1a2021-03-29 15:04:50 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <ResolveType.hpp>
9
10#include <armnn/IWorkingMemHandle.hpp>
11#include <armnn/INetwork.hpp>
12
13#include <backendsCommon/test/CommonTestUtils.hpp>
14
15#include <boost/test/unit_test.hpp>
16
17#include <vector>
18
19namespace armnn
20{
21
22namespace experimental
23{
24
25template<DataType ArmnnIType, DataType ArmnnOType,
26 typename TInput = ResolveType <ArmnnIType>, typename TOutput = ResolveType <ArmnnOType>>
27void AsyncEndToEndTestImpl(INetworkPtr network,
28 const std::map<int, std::vector<TInput>>& inputTensorData,
29 const std::map<int, std::vector<TOutput>>& expectedOutputData,
30 std::vector<BackendId> backends,
31 float tolerance = 0.000001f)
32{
33 // Create Runtime in which test will run
34 IRuntime::CreationOptions options;
35 IRuntimePtr runtime(IRuntime::Create(options));
36
37 // Optimize the Network
38 IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec());
39
40 // Creates AsyncNetwork
41 NetworkId networkId = 0;
42 std::string errorMessage;
Mike Kelly55a8ffd2021-04-07 20:10:49 +010043 const INetworkProperties networkProperties(false, false, true);
44 runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);
Mike Kelly386ff1a2021-03-29 15:04:50 +010045
46 InputTensors inputTensors;
47 inputTensors.reserve(inputTensorData.size());
48 for (auto&& it : inputTensorData)
49 {
50 inputTensors.push_back({it.first,
Mike Kelly55a8ffd2021-04-07 20:10:49 +010051 ConstTensor(runtime->GetInputTensorInfo(networkId, it.first), it.second.data())});
Mike Kelly386ff1a2021-03-29 15:04:50 +010052 }
53
54 OutputTensors outputTensors;
55 outputTensors.reserve(expectedOutputData.size());
56 std::map<int, std::vector<TOutput>> outputStorage;
57 for (auto&& it : expectedOutputData)
58 {
59 std::vector<TOutput> out(it.second.size());
60 outputStorage.emplace(it.first, out);
61 outputTensors.push_back({it.first,
Mike Kelly55a8ffd2021-04-07 20:10:49 +010062 Tensor(runtime->GetOutputTensorInfo(networkId, it.first),
Mike Kelly386ff1a2021-03-29 15:04:50 +010063 outputStorage.at(it.first).data())});
64 }
65
66 // Create WorkingMemHandle for this async network
Mike Kelly55a8ffd2021-04-07 20:10:49 +010067 std::unique_ptr<IWorkingMemHandle> workingMemHandle = runtime->CreateWorkingMemHandle(networkId);
Mike Kelly386ff1a2021-03-29 15:04:50 +010068 IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get();
69
70 // Run the async network
Mike Kelly55a8ffd2021-04-07 20:10:49 +010071 runtime->Execute(workingMemHandleRef, inputTensors, outputTensors);
Mike Kelly386ff1a2021-03-29 15:04:50 +010072
73 // Checks the results.
74 for (auto&& it : expectedOutputData)
75 {
76 std::vector<TOutput> out = outputStorage.at(it.first);
77 for (unsigned int i = 0; i < out.size(); ++i)
78 {
79 BOOST_CHECK(Compare<ArmnnOType>(it.second[i], out[i], tolerance) == true);
80 }
81 }
82}
83
84template<typename armnn::DataType DataType>
85INetworkPtr CreateStridedSliceNetwork(const TensorShape& inputShape,
86 const TensorShape& outputShape,
87 const std::vector<int>& beginData,
88 const std::vector<int>& endData,
89 const std::vector<int>& stridesData,
90 int beginMask = 0,
91 int endMask = 0,
92 int shrinkAxisMask = 0,
93 int ellipsisMask = 0,
94 int newAxisMask = 0,
95 const float qScale = 1.0f,
96 const int32_t qOffset = 0)
97{
98 using namespace armnn;
99 // Builds up the structure of the network.
100 INetworkPtr net(INetwork::Create());
101
102 TensorInfo inputTensorInfo(inputShape, DataType, qScale, qOffset);
103 TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
104
105 armnn::StridedSliceDescriptor stridedSliceDescriptor;
106 stridedSliceDescriptor.m_Begin = beginData;
107 stridedSliceDescriptor.m_End = endData;
108 stridedSliceDescriptor.m_Stride = stridesData;
109 stridedSliceDescriptor.m_BeginMask = beginMask;
110 stridedSliceDescriptor.m_EndMask = endMask;
111 stridedSliceDescriptor.m_ShrinkAxisMask = shrinkAxisMask;
112 stridedSliceDescriptor.m_EllipsisMask = ellipsisMask;
113 stridedSliceDescriptor.m_NewAxisMask = newAxisMask;
114
115 IConnectableLayer* input = net->AddInputLayer(0, "Input_Layer");
116 IConnectableLayer* stridedSlice = net->AddStridedSliceLayer(stridedSliceDescriptor, "splitter");
117 IConnectableLayer* output = net->AddOutputLayer(0);
118
119 Connect(input, stridedSlice, inputTensorInfo, 0, 0);
120 Connect(stridedSlice, output, outputTensorInfo, 0, 0);
121
122 return net;
123}
124
125template<armnn::DataType ArmnnType>
126void StridedSlicedEndToEndTest(const std::vector<BackendId>& backends)
127{
128 using namespace armnn;
129 using T = ResolveType<ArmnnType>;
130
131 const TensorShape& inputShape = {3, 2, 3, 1};
132 const TensorShape& outputShape = {1, 2, 3, 1};
133 const std::vector<int>& beginData = {1, 0, 0, 0};
134 const std::vector<int>& endData = {2, 2, 3, 1};
135 const std::vector<int>& stridesData = {1, 1, 1, 1};
136 int beginMask = 0;
137 int endMask = 0;
138 int shrinkAxisMask = 0;
139 int ellipsisMask = 0;
140 int newAxisMask = 0;
141
142 // Builds up the structure of the network
143 INetworkPtr net = CreateStridedSliceNetwork<ArmnnType>(inputShape,
144 outputShape,
145 beginData,
146 endData,
147 stridesData,
148 beginMask,
149 endMask,
150 shrinkAxisMask,
151 ellipsisMask,
152 newAxisMask);
153
154 BOOST_TEST_CHECKPOINT("create a network");
155
156 // Creates structures for input & output.
157 std::vector<T> inputData{
158 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
159
160 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
161
162 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
163 };
164
165 std::vector<T> outputExpected{
166 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f
167 };
168
169 std::map<int, std::vector<T>> inputTensorData = {{0, inputData}};
170 std::map<int, std::vector<T>> expectedOutputData = {{0, outputExpected}};
171
172 AsyncEndToEndTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
173}
174
175} // experimental namespace
176
177} // armnn namespace
178