blob: 55dd3428b82b06657de70434f00e9f2f12fe0acc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include <armnn/Tensor.hpp>
#include <fmt/format.h>
#include <mapbox/variant.hpp>
namespace armnnUtils
{
/// Builds the list of input tensors to feed into an Arm NN network.
///
/// The binding at index i in @p inputBindings is paired with the container at
/// index i in @p inputDataContainers. Each container is visited with
/// mapbox::util::apply_visitor, and whichever alternative it holds must expose
/// size() and data().
///
/// @param inputBindings        Binding id (first) and tensor info (second) for each input.
/// @param inputDataContainers  One data container per binding, in the same order.
/// @return InputTensors whose ConstTensors are constructed from each container's
///         data() pointer. NOTE(review): this presumably references the caller's
///         buffers rather than copying them, so keep the containers alive while the
///         result is in use. Confirm against armnn::ConstTensor.
/// @throws armnn::Exception if the container count differs from the binding count,
///         or if a container's size differs from its tensor's element count.
template<typename TContainer>
inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
                                            const std::vector<TContainer>& inputDataContainers)
{
    const size_t bindingCount   = inputBindings.size();
    const size_t containerCount = inputDataContainers.size();
    if (bindingCount != containerCount)
    {
        throw armnn::Exception(fmt::format("The number of inputs does not match number of "
                                           "tensor data containers: {0} != {1}",
                                           bindingCount,
                                           containerCount));
    }

    armnn::InputTensors result;
    for (size_t idx = 0; idx < bindingCount; ++idx)
    {
        const armnn::BindingPointInfo& binding = inputBindings[idx];

        // Generic visitor: invoked with whichever alternative the variant holds.
        auto appendConstTensor = [&binding, &result](auto&& buffer)
        {
            const auto expectedElements = binding.second.GetNumElements();
            if (buffer.size() != expectedElements)
            {
                throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
                                                   expectedElements,
                                                   buffer.size()));
            }

            // Work on a copy of the binding's TensorInfo so that marking it
            // constant does not touch the caller's binding.
            armnn::TensorInfo constInfo(binding.second);
            constInfo.SetConstant(true);
            result.emplace_back(binding.first, armnn::ConstTensor(constInfo, buffer.data()));
        };

        mapbox::util::apply_visitor(appendConstTensor, inputDataContainers[idx]);
    }
    return result;
}
/// Builds the list of output tensors for an Arm NN network to write into.
///
/// The binding at index i in @p outputBindings is paired with the container at
/// index i in @p outputDataContainers. Each container is visited with
/// mapbox::util::apply_visitor, and whichever alternative it holds must expose
/// size() and data().
///
/// @param outputBindings        Binding id (first) and tensor info (second) for each output.
/// @param outputDataContainers  One data container per binding, in the same order.
///                              Taken by non-const reference because the Tensors are
///                              built from the containers' writable data() pointers.
/// @return OutputTensors whose Tensors are constructed from each container's data()
///         pointer. NOTE(review): this presumably references the caller's buffers
///         rather than copying them, so keep the containers alive while the result is
///         in use. Confirm against armnn::Tensor.
/// @throws armnn::Exception if the container count differs from the binding count,
///         or if a container's size differs from its tensor's element count.
template<typename TContainer>
inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
                                              std::vector<TContainer>& outputDataContainers)
{
    armnn::OutputTensors outputTensors;
    const size_t numOutputs = outputBindings.size();
    if (numOutputs != outputDataContainers.size())
    {
        // Adjacent string literals are joined verbatim, so the trailing space
        // after "number" is required (without it the message read "numberof").
        throw armnn::Exception(fmt::format("Number of outputs does not match number "
                                           "of tensor data containers: {0} != {1}",
                                           numOutputs,
                                           outputDataContainers.size()));
    }

    // Exactly one tensor is appended per binding; reserve up front.
    outputTensors.reserve(numOutputs);

    for (size_t i = 0; i < numOutputs; i++)
    {
        const armnn::BindingPointInfo& outputBinding = outputBindings[i];
        TContainer& outputData = outputDataContainers[i];
        mapbox::util::apply_visitor([&](auto&& value)
        {
            if (value.size() != outputBinding.second.GetNumElements())
            {
                // Report expected vs. actual element counts so mismatches can be diagnosed.
                throw armnn::Exception(fmt::format("Output tensor has incorrect size (expected {0} got {1})",
                                                   outputBinding.second.GetNumElements(),
                                                   value.size()));
            }
            armnn::Tensor outputTensor(outputBinding.second, value.data());
            outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
        },
        outputData);
    }
    return outputTensors;
}
} // namespace armnnUtils