blob: 47e0a320b899c18550c55fc62c7ed602feac9a5f [file] [log] [blame]
Jim Flynn2fd61002019-05-03 12:54:26 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
9#include <vector>
10
11#include <boost/format.hpp>
12#include <boost/variant/apply_visitor.hpp>
13
14namespace armnnUtils
15{
16
17template<typename TContainer>
18inline armnn::InputTensors MakeInputTensors(
19 const std::vector<armnn::BindingPointInfo>& inputBindings,
20 const std::vector<TContainer>& inputDataContainers)
21{
22 armnn::InputTensors inputTensors;
23
24 const size_t numInputs = inputBindings.size();
25 if (numInputs != inputDataContainers.size())
26 {
27 throw armnn::Exception(boost::str(boost::format("Number of inputs does not match number of "
28 "tensor data containers: %1% != %2%") % numInputs % inputDataContainers.size()));
29 }
30
31 for (size_t i = 0; i < numInputs; i++)
32 {
33 const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34 const TContainer& inputData = inputDataContainers[i];
35
36 boost::apply_visitor([&](auto&& value)
37 {
38 if (value.size() != inputBinding.second.GetNumElements())
39 {
Matthew Bentham4cefc412019-06-18 16:14:34 +010040 std::ostringstream msg;
41 msg << "Input tensor has incorrect size (expected "
42 << inputBinding.second.GetNumElements() << " got "
43 << value.size();
44 throw armnn::Exception(msg.str());
Jim Flynn2fd61002019-05-03 12:54:26 +010045 }
46
47 armnn::ConstTensor inputTensor(inputBinding.second, value.data());
48 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
49 },
50 inputData);
51 }
52
53 return inputTensors;
54}
55
56template<typename TContainer>
57inline armnn::OutputTensors MakeOutputTensors(
58 const std::vector<armnn::BindingPointInfo>& outputBindings,
59 std::vector<TContainer>& outputDataContainers)
60{
61 armnn::OutputTensors outputTensors;
62
63 const size_t numOutputs = outputBindings.size();
64 if (numOutputs != outputDataContainers.size())
65 {
66 throw armnn::Exception(boost::str(boost::format("Number of outputs does not match number of "
67 "tensor data containers: %1% != %2%") % numOutputs % outputDataContainers.size()));
68 }
69
70 for (size_t i = 0; i < numOutputs; i++)
71 {
72 const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73 TContainer& outputData = outputDataContainers[i];
74
75 boost::apply_visitor([&](auto&& value)
76 {
77 if (value.size() != outputBinding.second.GetNumElements())
78 {
79 throw armnn::Exception("Output tensor has incorrect size");
80 }
81
82 armnn::Tensor outputTensor(outputBinding.second, value.data());
83 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84 },
85 outputData);
86 }
87
88 return outputTensors;
89}
90
Matthew Bentham4cefc412019-06-18 16:14:34 +010091} // namespace armnnUtils