blob: b4653cd8db4291bd95476d473255c418bad96544 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Finn Williamsb49ed182021-06-29 15:50:08 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#pragma once
7
Matteo Martincighc601aa62019-10-29 15:03:22 +00008#include "Schema.hpp"
9
keidav01222c7532019-03-14 17:12:10 +000010#include <armnn/Descriptors.hpp>
11#include <armnn/IRuntime.hpp>
12#include <armnn/TypesUtils.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +000013#include <armnn/BackendRegistry.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010014#include <armnn/utility/Assert.hpp>
keidav01222c7532019-03-14 17:12:10 +000015
Finn Williamsb49ed182021-06-29 15:50:08 +010016#include "../TfLiteParser.hpp"
Matteo Martincighc601aa62019-10-29 15:03:22 +000017
18#include <ResolveType.hpp>
19
20#include <test/TensorHelpers.hpp>
21
James Ward58dec6b2020-09-11 17:32:44 +010022#include <fmt/format.h>
Sadik Armagan1625efc2021-06-10 18:24:34 +010023#include <doctest/doctest.h>
keidav01222c7532019-03-14 17:12:10 +000024
telsoa01c577f2c2018-08-31 09:22:23 +010025#include "flatbuffers/idl.h"
26#include "flatbuffers/util.h"
keidav01222c7532019-03-14 17:12:10 +000027#include "flatbuffers/flexbuffers.h"
telsoa01c577f2c2018-08-31 09:22:23 +010028
29#include <schema_generated.h>
Matteo Martincighc601aa62019-10-29 15:03:22 +000030
telsoa01c577f2c2018-08-31 09:22:23 +010031
32using armnnTfLiteParser::ITfLiteParser;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010033using armnnTfLiteParser::ITfLiteParserPtr;
telsoa01c577f2c2018-08-31 09:22:23 +010034
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010035using TensorRawPtr = const tflite::TensorT *;
telsoa01c577f2c2018-08-31 09:22:23 +010036struct ParserFlatbuffersFixture
37{
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +000038 ParserFlatbuffersFixture() :
Finn Williamsb49ed182021-06-29 15:50:08 +010039 m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions())),
40 m_NetworkIdentifier(0),
41 m_DynamicNetworkIdentifier(1)
telsoa01c577f2c2018-08-31 09:22:23 +010042 {
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010043 ITfLiteParser::TfLiteParserOptions options;
44 options.m_StandInLayerForUnsupported = true;
Sadik Armagand109a4d2020-07-28 10:42:13 +010045 options.m_InferAndValidate = true;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010046
Finn Williamsb49ed182021-06-29 15:50:08 +010047 m_Parser = std::make_unique<armnnTfLiteParser::TfLiteParserImpl>(
48 armnn::Optional<ITfLiteParser::TfLiteParserOptions>(options));
telsoa01c577f2c2018-08-31 09:22:23 +010049 }
50
51 std::vector<uint8_t> m_GraphBinary;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010052 std::string m_JsonString;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010053 armnn::IRuntimePtr m_Runtime;
54 armnn::NetworkId m_NetworkIdentifier;
Finn Williamsb49ed182021-06-29 15:50:08 +010055 armnn::NetworkId m_DynamicNetworkIdentifier;
56 bool m_TestDynamic;
57 std::unique_ptr<armnnTfLiteParser::TfLiteParserImpl> m_Parser;
telsoa01c577f2c2018-08-31 09:22:23 +010058
59 /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
60 /// so they don't need to be passed to the single-input-single-output overload of RunTest().
61 std::string m_SingleInputName;
62 std::string m_SingleOutputName;
63
Finn Williamsb49ed182021-06-29 15:50:08 +010064 void Setup(bool testDynamic = true)
65 {
66 m_TestDynamic = testDynamic;
67 loadNetwork(m_NetworkIdentifier, false);
68
69 if (m_TestDynamic)
70 {
71 loadNetwork(m_DynamicNetworkIdentifier, true);
72 }
73 }
74
75 std::unique_ptr<tflite::ModelT> MakeModelDynamic(std::vector<uint8_t> graphBinary)
76 {
77 const uint8_t* binaryContent = graphBinary.data();
78 const size_t len = graphBinary.size();
79 if (binaryContent == nullptr)
80 {
81 throw armnn::InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
82 CHECK_LOCATION().AsString()));
83 }
84 flatbuffers::Verifier verifier(binaryContent, len);
85 if (verifier.VerifyBuffer<tflite::Model>() == false)
86 {
87 throw armnn::ParseException(fmt::format("Buffer doesn't conform to the expected Tensorflow Lite "
88 "flatbuffers format. size:{} {}",
89 len,
90 CHECK_LOCATION().AsString()));
91 }
92 auto model = tflite::UnPackModel(binaryContent);
93
94 for (auto const& subgraph : model->subgraphs)
95 {
96 std::vector<int32_t> inputIds = subgraph->inputs;
97 for (unsigned int tensorIndex = 0; tensorIndex < subgraph->tensors.size(); ++tensorIndex)
98 {
99 if (std::find(inputIds.begin(), inputIds.end(), tensorIndex) != inputIds.end())
100 {
101 continue;
102 }
103 for (auto const& tensor : subgraph->tensors)
104 {
105 if (tensor->shape_signature.size() != 0)
106 {
107 continue;
108 }
109
110 for (unsigned int i = 0; i < tensor->shape.size(); ++i)
111 {
112 tensor->shape_signature.push_back(-1);
113 }
114 }
115 }
116 }
117
118 return model;
119 }
120
121 void loadNetwork(armnn::NetworkId networkId, bool loadDynamic)
telsoa01c577f2c2018-08-31 09:22:23 +0100122 {
123 bool ok = ReadStringToBinary();
124 if (!ok) {
125 throw armnn::Exception("LoadNetwork failed while reading binary input");
126 }
127
Finn Williamsb49ed182021-06-29 15:50:08 +0100128 armnn::INetworkPtr network = loadDynamic ? m_Parser->LoadModel(MakeModelDynamic(m_GraphBinary))
129 : m_Parser->CreateNetworkFromBinary(m_GraphBinary);
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000130
131 if (!network) {
132 throw armnn::Exception("The parser failed to create an ArmNN network");
133 }
134
135 auto optimized = Optimize(*network, { armnn::Compute::CpuRef },
136 m_Runtime->GetDeviceSpec());
137 std::string errorMessage;
138
Finn Williamsb49ed182021-06-29 15:50:08 +0100139 armnn::Status ret = m_Runtime->LoadNetwork(networkId, move(optimized), errorMessage);
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000140
141 if (ret != armnn::Status::Success)
telsoa01c577f2c2018-08-31 09:22:23 +0100142 {
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000143 throw armnn::Exception(
James Ward58dec6b2020-09-11 17:32:44 +0100144 fmt::format("The runtime failed to load the network. "
145 "Error was: {}. in {} [{}:{}]",
146 errorMessage,
147 __func__,
148 __FILE__,
149 __LINE__));
telsoa01c577f2c2018-08-31 09:22:23 +0100150 }
151 }
152
153 void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName)
154 {
155 // Store the input and output name so they don't need to be passed to the single-input-single-output RunTest().
156 m_SingleInputName = inputName;
157 m_SingleOutputName = outputName;
158 Setup();
159 }
160
161 bool ReadStringToBinary()
162 {
Rob Hughesff3c4262019-12-20 17:43:16 +0000163 std::string schemafile(g_TfLiteSchemaText, g_TfLiteSchemaText + g_TfLiteSchemaText_len);
telsoa01c577f2c2018-08-31 09:22:23 +0100164
165 // parse schema first, so we can use it to parse the data after
166 flatbuffers::Parser parser;
167
Matthew Bentham6c8e8e72019-01-15 17:57:00 +0000168 bool ok = parser.Parse(schemafile.c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100169 ARMNN_ASSERT_MSG(ok, "Failed to parse schema file");
telsoa01c577f2c2018-08-31 09:22:23 +0100170
171 ok &= parser.Parse(m_JsonString.c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100172 ARMNN_ASSERT_MSG(ok, "Failed to parse json input");
telsoa01c577f2c2018-08-31 09:22:23 +0100173
174 if (!ok)
175 {
176 return false;
177 }
178
179 {
180 const uint8_t * bufferPtr = parser.builder_.GetBufferPointer();
181 size_t size = static_cast<size_t>(parser.builder_.GetSize());
182 m_GraphBinary.assign(bufferPtr, bufferPtr+size);
183 }
184 return ok;
185 }
186
187 /// Executes the network with the given input tensor and checks the result against the given output tensor.
keidav011b3e2ea2019-02-21 10:07:37 +0000188 /// This assumes the network has a single input and a single output.
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000189 template <std::size_t NumOutputDimensions,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000190 armnn::DataType ArmnnType>
telsoa01c577f2c2018-08-31 09:22:23 +0100191 void RunTest(size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000192 const std::vector<armnn::ResolveType<ArmnnType>>& inputData,
193 const std::vector<armnn::ResolveType<ArmnnType>>& expectedOutputData);
telsoa01c577f2c2018-08-31 09:22:23 +0100194
195 /// Executes the network with the given input tensors and checks the results against the given output tensors.
196 /// This overload supports multiple inputs and multiple outputs, identified by name.
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000197 template <std::size_t NumOutputDimensions,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000198 armnn::DataType ArmnnType>
telsoa01c577f2c2018-08-31 09:22:23 +0100199 void RunTest(size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000200 const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType>>>& inputData,
201 const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType>>>& expectedOutputData);
telsoa01c577f2c2018-08-31 09:22:23 +0100202
keidav011b3e2ea2019-02-21 10:07:37 +0000203 /// Multiple Inputs, Multiple Outputs w/ Variable Datatypes and different dimension sizes.
204 /// Executes the network with the given input tensors and checks the results against the given output tensors.
205 /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
206 /// the input datatype to be different to the output
207 template <std::size_t NumOutputDimensions,
208 armnn::DataType ArmnnType1,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000209 armnn::DataType ArmnnType2>
keidav011b3e2ea2019-02-21 10:07:37 +0000210 void RunTest(size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000211 const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType1>>>& inputData,
Sadik Armagand109a4d2020-07-28 10:42:13 +0100212 const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType2>>>& expectedOutputData,
213 bool isDynamic = false);
keidav011b3e2ea2019-02-21 10:07:37 +0000214
Sadik Armagan26868492021-01-22 14:25:31 +0000215 /// Multiple Inputs with different DataTypes, Multiple Outputs w/ Variable DataTypes
216 /// Executes the network with the given input tensors and checks the results against the given output tensors.
217 /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
218 /// the input datatype to be different to the output
219 template <std::size_t NumOutputDimensions,
220 armnn::DataType inputType1,
221 armnn::DataType inputType2,
222 armnn::DataType outputType>
223 void RunTest(size_t subgraphId,
224 const std::map<std::string, std::vector<armnn::ResolveType<inputType1>>>& input1Data,
225 const std::map<std::string, std::vector<armnn::ResolveType<inputType2>>>& input2Data,
226 const std::map<std::string, std::vector<armnn::ResolveType<outputType>>>& expectedOutputData);
keidav011b3e2ea2019-02-21 10:07:37 +0000227
228 /// Multiple Inputs, Multiple Outputs w/ Variable Datatypes and different dimension sizes.
229 /// Executes the network with the given input tensors and checks the results against the given output tensors.
230 /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
231 /// the input datatype to be different to the output
232 template<armnn::DataType ArmnnType1,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000233 armnn::DataType ArmnnType2>
keidav011b3e2ea2019-02-21 10:07:37 +0000234 void RunTest(std::size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000235 const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType1>>>& inputData,
236 const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType2>>>& expectedOutputData);
keidav011b3e2ea2019-02-21 10:07:37 +0000237
keidav01222c7532019-03-14 17:12:10 +0000238 static inline std::string GenerateDetectionPostProcessJsonString(
239 const armnn::DetectionPostProcessDescriptor& descriptor)
240 {
241 flexbuffers::Builder detectPostProcess;
242 detectPostProcess.Map([&]() {
243 detectPostProcess.Bool("use_regular_nms", descriptor.m_UseRegularNms);
244 detectPostProcess.Int("max_detections", descriptor.m_MaxDetections);
245 detectPostProcess.Int("max_classes_per_detection", descriptor.m_MaxClassesPerDetection);
246 detectPostProcess.Int("detections_per_class", descriptor.m_DetectionsPerClass);
247 detectPostProcess.Int("num_classes", descriptor.m_NumClasses);
248 detectPostProcess.Float("nms_score_threshold", descriptor.m_NmsScoreThreshold);
249 detectPostProcess.Float("nms_iou_threshold", descriptor.m_NmsIouThreshold);
250 detectPostProcess.Float("h_scale", descriptor.m_ScaleH);
251 detectPostProcess.Float("w_scale", descriptor.m_ScaleW);
252 detectPostProcess.Float("x_scale", descriptor.m_ScaleX);
253 detectPostProcess.Float("y_scale", descriptor.m_ScaleY);
254 });
255 detectPostProcess.Finish();
256
257 // Create JSON string
258 std::stringstream strStream;
259 std::vector<uint8_t> buffer = detectPostProcess.GetBuffer();
260 std::copy(buffer.begin(), buffer.end(),std::ostream_iterator<int>(strStream,","));
261
262 return strStream.str();
263 }
264
telsoa01c577f2c2018-08-31 09:22:23 +0100265 void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector<int32_t>& shape,
266 tflite::TensorType tensorType, uint32_t buffer, const std::string& name,
267 const std::vector<float>& min, const std::vector<float>& max,
268 const std::vector<float>& scale, const std::vector<int64_t>& zeroPoint)
269 {
Sadik Armagan1625efc2021-06-10 18:24:34 +0100270 CHECK(tensors);
271 CHECK_EQ(shapeSize, tensors->shape.size());
272 CHECK(std::equal(shape.begin(), shape.end(), tensors->shape.begin(), tensors->shape.end()));
273 CHECK_EQ(tensorType, tensors->type);
274 CHECK_EQ(buffer, tensors->buffer);
275 CHECK_EQ(name, tensors->name);
276 CHECK(tensors->quantization);
277 CHECK(std::equal(min.begin(), min.end(), tensors->quantization.get()->min.begin(),
278 tensors->quantization.get()->min.end()));
279 CHECK(std::equal(max.begin(), max.end(), tensors->quantization.get()->max.begin(),
280 tensors->quantization.get()->max.end()));
281 CHECK(std::equal(scale.begin(), scale.end(), tensors->quantization.get()->scale.begin(),
282 tensors->quantization.get()->scale.end()));
283 CHECK(std::equal(zeroPoint.begin(), zeroPoint.end(),
telsoa01c577f2c2018-08-31 09:22:23 +0100284 tensors->quantization.get()->zero_point.begin(),
Sadik Armagan1625efc2021-06-10 18:24:34 +0100285 tensors->quantization.get()->zero_point.end()));
telsoa01c577f2c2018-08-31 09:22:23 +0100286 }
Sadik Armagan26868492021-01-22 14:25:31 +0000287
288private:
289 /// Fills the InputTensors with given input data
290 template <armnn::DataType dataType>
291 void FillInputTensors(armnn::InputTensors& inputTensors,
292 const std::map<std::string, std::vector<armnn::ResolveType<dataType>>>& inputData,
293 size_t subgraphId);
telsoa01c577f2c2018-08-31 09:22:23 +0100294};
295
Sadik Armagan26868492021-01-22 14:25:31 +0000296/// Fills the InputTensors with given input data
297template <armnn::DataType dataType>
298void ParserFlatbuffersFixture::FillInputTensors(
299 armnn::InputTensors& inputTensors,
300 const std::map<std::string, std::vector<armnn::ResolveType<dataType>>>& inputData,
301 size_t subgraphId)
302{
303 for (auto&& it : inputData)
304 {
305 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(subgraphId, it.first);
306 armnn::VerifyTensorInfoDataType(bindingInfo.second, dataType);
307 inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
308 }
309}
310
keidav011b3e2ea2019-02-21 10:07:37 +0000311/// Single Input, Single Output
312/// Executes the network with the given input tensor and checks the result against the given output tensor.
313/// This overload assumes the network has a single input and a single output.
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000314template <std::size_t NumOutputDimensions,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000315 armnn::DataType armnnType>
telsoa01c577f2c2018-08-31 09:22:23 +0100316void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000317 const std::vector<armnn::ResolveType<armnnType>>& inputData,
318 const std::vector<armnn::ResolveType<armnnType>>& expectedOutputData)
telsoa01c577f2c2018-08-31 09:22:23 +0100319{
keidav011b3e2ea2019-02-21 10:07:37 +0000320 RunTest<NumOutputDimensions, armnnType>(subgraphId,
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000321 { { m_SingleInputName, inputData } },
322 { { m_SingleOutputName, expectedOutputData } });
telsoa01c577f2c2018-08-31 09:22:23 +0100323}
324
keidav011b3e2ea2019-02-21 10:07:37 +0000325/// Multiple Inputs, Multiple Outputs
326/// Executes the network with the given input tensors and checks the results against the given output tensors.
327/// This overload supports multiple inputs and multiple outputs, identified by name.
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000328template <std::size_t NumOutputDimensions,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000329 armnn::DataType armnnType>
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000330void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000331 const std::map<std::string, std::vector<armnn::ResolveType<armnnType>>>& inputData,
332 const std::map<std::string, std::vector<armnn::ResolveType<armnnType>>>& expectedOutputData)
telsoa01c577f2c2018-08-31 09:22:23 +0100333{
keidav011b3e2ea2019-02-21 10:07:37 +0000334 RunTest<NumOutputDimensions, armnnType, armnnType>(subgraphId, inputData, expectedOutputData);
335}
336
337/// Multiple Inputs, Multiple Outputs w/ Variable Datatypes
338/// Executes the network with the given input tensors and checks the results against the given output tensors.
339/// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
340/// the input datatype to be different to the output
341template <std::size_t NumOutputDimensions,
342 armnn::DataType armnnType1,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000343 armnn::DataType armnnType2>
keidav011b3e2ea2019-02-21 10:07:37 +0000344void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000345 const std::map<std::string, std::vector<armnn::ResolveType<armnnType1>>>& inputData,
Sadik Armagand109a4d2020-07-28 10:42:13 +0100346 const std::map<std::string, std::vector<armnn::ResolveType<armnnType2>>>& expectedOutputData,
347 bool isDynamic)
keidav011b3e2ea2019-02-21 10:07:37 +0000348{
Rob Hughesfc6bf052019-12-16 17:10:51 +0000349 using DataType2 = armnn::ResolveType<armnnType2>;
350
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000351 // Setup the armnn input tensors from the given vectors.
352 armnn::InputTensors inputTensors;
Sadik Armagan26868492021-01-22 14:25:31 +0000353 FillInputTensors<armnnType1>(inputTensors, inputData, subgraphId);
telsoa01c577f2c2018-08-31 09:22:23 +0100354
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000355 // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
Sadik Armagan483c8112021-06-01 09:24:52 +0100356 std::map<std::string, std::vector<DataType2>> outputStorage;
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000357 armnn::OutputTensors outputTensors;
358 for (auto&& it : expectedOutputData)
359 {
Narumol Prangnawarat386681a2019-04-29 16:40:55 +0100360 armnn::LayerBindingId outputBindingId = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first).first;
361 armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkIdentifier, outputBindingId);
362
363 // Check that output tensors have correct number of dimensions (NumOutputDimensions specified in test)
364 auto outputNumDimensions = outputTensorInfo.GetNumDimensions();
Sadik Armagan1625efc2021-06-10 18:24:34 +0100365 CHECK_MESSAGE((outputNumDimensions == NumOutputDimensions),
James Ward58dec6b2020-09-11 17:32:44 +0100366 fmt::format("Number of dimensions expected {}, but got {} for output layer {}",
367 NumOutputDimensions,
368 outputNumDimensions,
369 it.first));
Narumol Prangnawarat386681a2019-04-29 16:40:55 +0100370
371 armnn::VerifyTensorInfoDataType(outputTensorInfo, armnnType2);
Sadik Armagan483c8112021-06-01 09:24:52 +0100372 outputStorage.emplace(it.first, std::vector<DataType2>(outputTensorInfo.GetNumElements()));
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000373 outputTensors.push_back(
Narumol Prangnawarat386681a2019-04-29 16:40:55 +0100374 { outputBindingId, armnn::Tensor(outputTensorInfo, outputStorage.at(it.first).data()) });
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000375 }
telsoa01c577f2c2018-08-31 09:22:23 +0100376
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000377 m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
telsoa01c577f2c2018-08-31 09:22:23 +0100378
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000379 // Compare each output tensor to the expected values
380 for (auto&& it : expectedOutputData)
381 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100382 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
Sadik Armagan483c8112021-06-01 09:24:52 +0100383 auto outputExpected = it.second;
Bruno Goncalves90211252021-07-11 21:49:00 -0300384 if (std::is_same<DataType2, uint8_t>::value)
385 {
386 auto result = CompareTensors(outputExpected, outputStorage[it.first],
387 bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
388 true, isDynamic);
389 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
390 }
391 else
392 {
393 auto result = CompareTensors(outputExpected, outputStorage[it.first],
394 bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
395 false, isDynamic);
396 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
397 }
telsoa01c577f2c2018-08-31 09:22:23 +0100398 }
Finn Williamsb49ed182021-06-29 15:50:08 +0100399
400 if (isDynamic)
401 {
402 m_Runtime->EnqueueWorkload(m_DynamicNetworkIdentifier, inputTensors, outputTensors);
403
404 // Compare each output tensor to the expected values
405 for (auto&& it : expectedOutputData)
406 {
407 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
408 auto outputExpected = it.second;
409 auto result = CompareTensors(outputExpected, outputStorage[it.first],
410 bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
411 false, isDynamic);
412 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
413 }
414 }
telsoa01c577f2c2018-08-31 09:22:23 +0100415}
keidav011b3e2ea2019-02-21 10:07:37 +0000416
417/// Multiple Inputs, Multiple Outputs w/ Variable Datatypes and different dimension sizes.
418/// Executes the network with the given input tensors and checks the results against the given output tensors.
419/// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
420/// the input datatype to be different to the output.
421template <armnn::DataType armnnType1,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000422 armnn::DataType armnnType2>
keidav011b3e2ea2019-02-21 10:07:37 +0000423void ParserFlatbuffersFixture::RunTest(std::size_t subgraphId,
Rob Hughesfc6bf052019-12-16 17:10:51 +0000424 const std::map<std::string, std::vector<armnn::ResolveType<armnnType1>>>& inputData,
425 const std::map<std::string, std::vector<armnn::ResolveType<armnnType2>>>& expectedOutputData)
keidav011b3e2ea2019-02-21 10:07:37 +0000426{
Rob Hughesfc6bf052019-12-16 17:10:51 +0000427 using DataType2 = armnn::ResolveType<armnnType2>;
428
keidav011b3e2ea2019-02-21 10:07:37 +0000429 // Setup the armnn input tensors from the given vectors.
430 armnn::InputTensors inputTensors;
Sadik Armagan26868492021-01-22 14:25:31 +0000431 FillInputTensors<armnnType1>(inputTensors, inputData, subgraphId);
keidav011b3e2ea2019-02-21 10:07:37 +0000432
433 armnn::OutputTensors outputTensors;
434 outputTensors.reserve(expectedOutputData.size());
435 std::map<std::string, std::vector<DataType2>> outputStorage;
436 for (auto&& it : expectedOutputData)
437 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100438 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
keidav011b3e2ea2019-02-21 10:07:37 +0000439 armnn::VerifyTensorInfoDataType(bindingInfo.second, armnnType2);
440
441 std::vector<DataType2> out(it.second.size());
442 outputStorage.emplace(it.first, out);
443 outputTensors.push_back({ bindingInfo.first,
444 armnn::Tensor(bindingInfo.second,
445 outputStorage.at(it.first).data()) });
446 }
447
448 m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
449
450 // Checks the results.
451 for (auto&& it : expectedOutputData)
452 {
Rob Hughesfc6bf052019-12-16 17:10:51 +0000453 std::vector<armnn::ResolveType<armnnType2>> out = outputStorage.at(it.first);
keidav011b3e2ea2019-02-21 10:07:37 +0000454 {
455 for (unsigned int i = 0; i < out.size(); ++i)
456 {
Sadik Armagan1625efc2021-06-10 18:24:34 +0100457 CHECK(doctest::Approx(it.second[i]).epsilon(0.000001f) == out[i]);
keidav011b3e2ea2019-02-21 10:07:37 +0000458 }
459 }
460 }
Aron Virginas-Tarc975f922019-10-23 17:38:17 +0100461}
Sadik Armagan26868492021-01-22 14:25:31 +0000462
463/// Multiple Inputs with different DataTypes, Multiple Outputs w/ Variable DataTypes
464/// Executes the network with the given input tensors and checks the results against the given output tensors.
465/// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
466/// the input datatype to be different to the output
467template <std::size_t NumOutputDimensions,
468 armnn::DataType inputType1,
469 armnn::DataType inputType2,
470 armnn::DataType outputType>
471void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
472 const std::map<std::string, std::vector<armnn::ResolveType<inputType1>>>& input1Data,
473 const std::map<std::string, std::vector<armnn::ResolveType<inputType2>>>& input2Data,
474 const std::map<std::string, std::vector<armnn::ResolveType<outputType>>>& expectedOutputData)
475{
476 using DataType2 = armnn::ResolveType<outputType>;
477
478 // Setup the armnn input tensors from the given vectors.
479 armnn::InputTensors inputTensors;
480 FillInputTensors<inputType1>(inputTensors, input1Data, subgraphId);
481 FillInputTensors<inputType2>(inputTensors, input2Data, subgraphId);
482
483 // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
Sadik Armagan483c8112021-06-01 09:24:52 +0100484 std::map<std::string, std::vector<DataType2>> outputStorage;
Sadik Armagan26868492021-01-22 14:25:31 +0000485 armnn::OutputTensors outputTensors;
486 for (auto&& it : expectedOutputData)
487 {
488 armnn::LayerBindingId outputBindingId = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first).first;
489 armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkIdentifier, outputBindingId);
490
491 // Check that output tensors have correct number of dimensions (NumOutputDimensions specified in test)
492 auto outputNumDimensions = outputTensorInfo.GetNumDimensions();
Sadik Armagan1625efc2021-06-10 18:24:34 +0100493 CHECK_MESSAGE((outputNumDimensions == NumOutputDimensions),
Sadik Armagan26868492021-01-22 14:25:31 +0000494 fmt::format("Number of dimensions expected {}, but got {} for output layer {}",
495 NumOutputDimensions,
496 outputNumDimensions,
497 it.first));
498
499 armnn::VerifyTensorInfoDataType(outputTensorInfo, outputType);
Sadik Armagan483c8112021-06-01 09:24:52 +0100500 outputStorage.emplace(it.first, std::vector<DataType2>(outputTensorInfo.GetNumElements()));
Sadik Armagan26868492021-01-22 14:25:31 +0000501 outputTensors.push_back(
502 { outputBindingId, armnn::Tensor(outputTensorInfo, outputStorage.at(it.first).data()) });
503 }
504
505 m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
506
507 // Compare each output tensor to the expected values
508 for (auto&& it : expectedOutputData)
509 {
510 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
Sadik Armagan483c8112021-06-01 09:24:52 +0100511 auto outputExpected = it.second;
Bruno Goncalves90211252021-07-11 21:49:00 -0300512 if (std::is_same<DataType2, uint8_t>::value)
513 {
514 auto result = CompareTensors(outputExpected, outputStorage[it.first],
515 bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), true);
516 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
517 }
518 else
519 {
520 auto result = CompareTensors(outputExpected, outputStorage[it.first],
521 bindingInfo.second.GetShape(), bindingInfo.second.GetShape());
522 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
523 }
Sadik Armagan26868492021-01-22 14:25:31 +0000524 }
525}