blob: bf5a37b00f5bf96a0d47e4cb268864b38a00c558 [file] [log] [blame]
Jim Flynn2fd61002019-05-03 12:54:26 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
9#include <vector>
10
11#include <boost/format.hpp>
12#include <boost/variant/apply_visitor.hpp>
13
14namespace armnnUtils
15{
16
17template<typename TContainer>
18inline armnn::InputTensors MakeInputTensors(
19 const std::vector<armnn::BindingPointInfo>& inputBindings,
20 const std::vector<TContainer>& inputDataContainers)
21{
22 armnn::InputTensors inputTensors;
23
24 const size_t numInputs = inputBindings.size();
25 if (numInputs != inputDataContainers.size())
26 {
27 throw armnn::Exception(boost::str(boost::format("Number of inputs does not match number of "
28 "tensor data containers: %1% != %2%") % numInputs % inputDataContainers.size()));
29 }
30
31 for (size_t i = 0; i < numInputs; i++)
32 {
33 const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34 const TContainer& inputData = inputDataContainers[i];
35
36 boost::apply_visitor([&](auto&& value)
37 {
38 if (value.size() != inputBinding.second.GetNumElements())
39 {
40 throw armnn::Exception("Input tensor has incorrect size");
41 }
42
43 armnn::ConstTensor inputTensor(inputBinding.second, value.data());
44 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
45 },
46 inputData);
47 }
48
49 return inputTensors;
50}
51
52template<typename TContainer>
53inline armnn::OutputTensors MakeOutputTensors(
54 const std::vector<armnn::BindingPointInfo>& outputBindings,
55 std::vector<TContainer>& outputDataContainers)
56{
57 armnn::OutputTensors outputTensors;
58
59 const size_t numOutputs = outputBindings.size();
60 if (numOutputs != outputDataContainers.size())
61 {
62 throw armnn::Exception(boost::str(boost::format("Number of outputs does not match number of "
63 "tensor data containers: %1% != %2%") % numOutputs % outputDataContainers.size()));
64 }
65
66 for (size_t i = 0; i < numOutputs; i++)
67 {
68 const armnn::BindingPointInfo& outputBinding = outputBindings[i];
69 TContainer& outputData = outputDataContainers[i];
70
71 boost::apply_visitor([&](auto&& value)
72 {
73 if (value.size() != outputBinding.second.GetNumElements())
74 {
75 throw armnn::Exception("Output tensor has incorrect size");
76 }
77
78 armnn::Tensor outputTensor(outputBinding.second, value.data());
79 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
80 },
81 outputData);
82 }
83
84 return outputTensors;
85}
86
87} // namespace armnnUtils