blob: 55dd3428b82b06657de70434f00e9f2f12fe0acc [file] [log] [blame]
Jim Flynn2fd61002019-05-03 12:54:26 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
Jim Flynn2fd61002019-05-03 12:54:26 +01009
Colm Donelan5b5c2222020-09-09 12:48:16 +010010#include <fmt/format.h>
James Ward6d9f5c52020-09-28 11:56:35 +010011#include <mapbox/variant.hpp>
Jim Flynn2fd61002019-05-03 12:54:26 +010012
13namespace armnnUtils
14{
15
16template<typename TContainer>
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000017inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18 const std::vector<TContainer>& inputDataContainers)
Jim Flynn2fd61002019-05-03 12:54:26 +010019{
20 armnn::InputTensors inputTensors;
21
22 const size_t numInputs = inputBindings.size();
23 if (numInputs != inputDataContainers.size())
24 {
Colm Donelan5b5c2222020-09-09 12:48:16 +010025 throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26 "tensor data containers: {0} != {1}",
27 numInputs,
28 inputDataContainers.size()));
Jim Flynn2fd61002019-05-03 12:54:26 +010029 }
30
31 for (size_t i = 0; i < numInputs; i++)
32 {
33 const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34 const TContainer& inputData = inputDataContainers[i];
35
James Ward6d9f5c52020-09-28 11:56:35 +010036 mapbox::util::apply_visitor([&](auto&& value)
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000037 {
38 if (value.size() != inputBinding.second.GetNumElements())
39 {
Colm Donelan5b5c2222020-09-09 12:48:16 +010040 throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41 inputBinding.second.GetNumElements(),
42 value.size()));
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000043 }
Cathal Corbett5b8093c2021-10-22 11:12:07 +010044 armnn::TensorInfo inputTensorInfo = inputBinding.second;
45 inputTensorInfo.SetConstant(true);
46 armnn::ConstTensor inputTensor(inputTensorInfo, value.data());
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000047 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
48 },
49 inputData);
Jim Flynn2fd61002019-05-03 12:54:26 +010050 }
51
52 return inputTensors;
53}
54
55template<typename TContainer>
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000056inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
57 std::vector<TContainer>& outputDataContainers)
Jim Flynn2fd61002019-05-03 12:54:26 +010058{
59 armnn::OutputTensors outputTensors;
60
61 const size_t numOutputs = outputBindings.size();
62 if (numOutputs != outputDataContainers.size())
63 {
Colm Donelan5b5c2222020-09-09 12:48:16 +010064 throw armnn::Exception(fmt::format("Number of outputs does not match number"
65 "of tensor data containers: {0} != {1}",
66 numOutputs,
67 outputDataContainers.size()));
Jim Flynn2fd61002019-05-03 12:54:26 +010068 }
69
70 for (size_t i = 0; i < numOutputs; i++)
71 {
72 const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73 TContainer& outputData = outputDataContainers[i];
74
James Ward6d9f5c52020-09-28 11:56:35 +010075 mapbox::util::apply_visitor([&](auto&& value)
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000076 {
77 if (value.size() != outputBinding.second.GetNumElements())
78 {
79 throw armnn::Exception("Output tensor has incorrect size");
80 }
Jim Flynn2fd61002019-05-03 12:54:26 +010081
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000082 armnn::Tensor outputTensor(outputBinding.second, value.data());
83 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84 },
85 outputData);
Jim Flynn2fd61002019-05-03 12:54:26 +010086 }
87
88 return outputTensors;
89}
90
Matthew Bentham4cefc412019-06-18 16:14:34 +010091} // namespace armnnUtils