blob: 01c7babe2432c75bf650a46d66c7bd138aef570c [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <ResolveType.hpp>
9
10#include <armnn/ArmNN.hpp>
11
12#include <reference/workloads/Encoders.hpp>
13
14#include <vector>
15
16// Utility tenmplate to convert a collection of values to the correct type
17template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
18std::vector<T> ConvertToDataType(const std::vector<float>& input,
19 const armnn::TensorInfo& inputTensorInfo)
20{
21 std::vector<T> output(input.size());
22 auto outputTensorInfo = inputTensorInfo;
23 outputTensorInfo.SetDataType(ArmnnType);
24
25 std::unique_ptr<armnn::Encoder<float>> pOutputEncoder = armnn::MakeEncoder<float>(outputTensorInfo, output.data());
26 armnn::Encoder<float>& rOutputEncoder = *pOutputEncoder;
27
28 for (auto it = input.begin(); it != input.end(); ++it)
29 {
30 rOutputEncoder.Set(*it);
31 ++rOutputEncoder;
32 }
33 return output;
34}
35
36// Utility tenmplate to convert a single value to the correct type
37template <typename T>
38T ConvertToDataType(const float& value,
39 const armnn::TensorInfo& tensorInfo)
40{
41 std::vector<T> output(1);
42 std::unique_ptr<armnn::Encoder<float>> pEncoder = armnn::MakeEncoder<float>(tensorInfo, output.data());
43 armnn::Encoder<float>& rEncoder = *pEncoder;
44 rEncoder.Set(value);
45 return output[0];
46}