/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
#include "ckw/types/ConstantData.h"

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <iomanip>
#include <iterator>
#include <limits>
#include <locale>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
28
29namespace ckw
30{
namespace
{
    /** Convert a single-precision value to its textual representation.
     *
     * The value is printed in scientific notation. With std::scientific the precision
     * counts the digits after the decimal point, so this emits max_digits10 + 1
     * significant digits, which is enough for the text to convert back to the same float.
     *
     * The stream is imbued with the classic "C" locale so the decimal separator is always
     * '.', regardless of any global locale the host application may have installed.
     *
     * @param[in] value Value to convert.
     *
     * @return The formatted string.
     */
    template<typename T>
    inline typename std::enable_if<std::is_same<T, float>::value, std::string>::type to_str(T value)
    {
        std::stringstream ss;
        ss.imbue(std::locale::classic());
        ss << std::scientific << std::setprecision(std::numeric_limits<T>::max_digits10) << value;
        return ss.str();
    }

    /** Convert a non-float, non-bool value (e.g. int32_t / uint32_t) with std::to_string.
     *
     * @param[in] value Value to convert.
     *
     * @return The formatted string.
     */
    template<typename T>
    inline typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, bool>::value, std::string>::type to_str(T value)
    {
        return std::to_string(value);
    }

    /** Convert a boolean to "1" (true) or "0" (false).
     *
     * @param[in] value Value to convert.
     *
     * @return "1" or "0".
     */
    template<typename T>
    inline typename std::enable_if<std::is_same<T, bool>::value, std::string>::type to_str(T value)
    {
        return std::to_string(static_cast<int>(value));
    }
} // namespace
53
54template<typename T>
55ConstantData::ConstantData(std::initializer_list<std::initializer_list<T>> values, DataType data_type)
56 : _data_type(data_type)
57{
58 CKW_ASSERT(validate<T>(data_type));
59 CKW_ASSERT(values.size() > 0);
60
61 for(auto value_arr: values)
62 {
63 // Each row must have the same number of elements
64 CKW_ASSERT(value_arr.size() == (*values.begin()).size());
65
66 StringVector vec;
67 std::transform(value_arr.begin(), value_arr.end(),
68 std::back_inserter(vec),
69 [](T val) { return to_str(val); });
70
71 _values.push_back(std::move(vec));
72 }
73}
74
75template<typename T>
76bool ConstantData::validate(DataType data_type)
77{
78 switch(data_type)
79 {
80 case DataType::Fp32:
81 case DataType::Fp16:
82 return std::is_same<T, float>::value;
83 case DataType::Bool:
84 return std::is_same<T, bool>::value;
85 case DataType::Int32:
86 case DataType::Int16:
87 case DataType::Int8:
88 return std::is_same<T, int32_t>::value;
89 case DataType::Uint32:
90 case DataType::Uint16:
91 case DataType::Uint8:
92 return std::is_same<T, uint32_t>::value;
93 default:
94 CKW_THROW_MSG("Unknown data type!");
95 break;
96 }
97}
98
99// Necessary instantiations for compiler to recognize
100template ConstantData::ConstantData(std::initializer_list<std::initializer_list<int32_t>>, DataType);
101template ConstantData::ConstantData(std::initializer_list<std::initializer_list<uint32_t>>, DataType);
102template ConstantData::ConstantData(std::initializer_list<std::initializer_list<bool>>, DataType);
103template ConstantData::ConstantData(std::initializer_list<std::initializer_list<float>>, DataType);
104
105template bool ConstantData::validate<int32_t>(DataType);
106template bool ConstantData::validate<uint32_t>(DataType);
107template bool ConstantData::validate<bool>(DataType);
108template bool ConstantData::validate<float>(DataType);
109
110const std::vector<std::vector<std::string>>& ConstantData::values() const
111{
112 return _values;
113}
114
115DataType ConstantData::data_type() const
116{
117 return _data_type;
118}
119
120} // namespace ckw