blob: ade2f2dd20cfedacec387ecdb4175449dd199202 [file] [log] [blame]
// Copyright (c) 2020-2023, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
15
16#ifndef _TOSA_NUMPY_UTILS_H
17#define _TOSA_NUMPY_UTILS_H
18
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <type_traits>
#include <vector>

#include "cfloat.h"
#include "half.hpp"
29
Won Jeona8141522024-04-29 23:57:27 +000030using bf16 = ct::cfloat<int16_t, 8, true, true, true>;
31using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
32using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
33
Eric Kunze2364dcd2021-04-26 11:06:57 -070034class NumpyUtilities
35{
36public:
37 enum NPError
38 {
39 NO_ERROR = 0,
40 FILE_NOT_FOUND,
41 FILE_IO_ERROR,
42 FILE_TYPE_MISMATCH,
43 HEADER_PARSE_ERROR,
44 BUFFER_SIZE_MISMATCH,
Jerry Ge13a32912023-07-03 16:36:41 +000045 DATA_TYPE_NOT_SUPPORTED,
Eric Kunze2364dcd2021-04-26 11:06:57 -070046 };
47
TatWai Chong679bdad2023-07-31 15:15:12 -070048 template <typename T>
49 static const char* getDTypeString(bool& is_bool)
50 {
51 is_bool = false;
52 if (std::is_same<T, bool>::value)
53 {
54 is_bool = true;
55 return "'|b1'";
56 }
57 if (std::is_same<T, uint8_t>::value)
58 {
59 return "'|u1'";
60 }
61 if (std::is_same<T, int8_t>::value)
62 {
63 return "'|i1'";
64 }
65 if (std::is_same<T, uint16_t>::value)
66 {
67 return "'<u2'";
68 }
69 if (std::is_same<T, int16_t>::value)
70 {
71 return "'<i2'";
72 }
73 if (std::is_same<T, int32_t>::value)
74 {
75 return "'<i4'";
76 }
77 if (std::is_same<T, int64_t>::value)
78 {
79 return "'<i8'";
80 }
81 if (std::is_same<T, float>::value)
82 {
83 return "'<f4'";
84 }
85 if (std::is_same<T, double>::value)
86 {
87 return "'<f8'";
88 }
89 if (std::is_same<T, half_float::half>::value)
90 {
91 return "'<f2'";
92 }
Won Jeona8141522024-04-29 23:57:27 +000093 if (std::is_same<T, bf16>::value)
94 {
95 return "'<V2'";
96 }
97 if (std::is_same<T, fp8e4m3>::value)
98 {
99 return "'<V1'";
100 }
101 if (std::is_same<T, fp8e5m2>::value)
102 {
103 return "'<f1'";
104 }
TatWai Chong679bdad2023-07-31 15:15:12 -0700105 assert(false && "unsupported Dtype");
106 };
Eric Kunze2364dcd2021-04-26 11:06:57 -0700107
TatWai Chong679bdad2023-07-31 15:15:12 -0700108 template <typename T>
109 static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf)
110 {
111 std::vector<int32_t> shape = { static_cast<int32_t>(elems) };
112 return writeToNpyFile(filename, shape, databuf);
113 }
Tai Ly3ef34fb2023-04-04 20:34:05 +0000114
TatWai Chong679bdad2023-07-31 15:15:12 -0700115 template <typename T>
116 static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf)
117 {
118 bool is_bool;
119 const char* dtype_str = getDTypeString<T>(is_bool);
120 return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool);
121 }
James Ward485a11d2022-08-05 13:48:37 +0100122
TatWai Chong679bdad2023-07-31 15:15:12 -0700123 template <typename T>
124 static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf)
125 {
126 bool is_bool;
127 const char* dtype_str = getDTypeString<T>(is_bool);
128 return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool);
129 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700130
TatWai Chong679bdad2023-07-31 15:15:12 -0700131 template <typename D, typename S>
132 static void copyBufferByElement(D* dest_buf, S* src_buf, int num)
133 {
Won Jeon780ffb52023-08-21 13:32:36 -0700134 static_assert(sizeof(D) >= sizeof(S), "The size of dest_buf must be equal to or larger than that of src_buf");
TatWai Chong679bdad2023-07-31 15:15:12 -0700135 for (int i = 0; i < num; ++i)
136 {
137 dest_buf[i] = src_buf[i];
138 }
139 }
Tai Ly3ef34fb2023-04-04 20:34:05 +0000140
Eric Kunze2364dcd2021-04-26 11:06:57 -0700141private:
142 static NPError writeToNpyFileCommon(const char* filename,
143 const char* dtype_str,
144 const size_t elementsize,
145 const std::vector<int32_t>& shape,
146 const void* databuf,
147 bool bool_translate);
148 static NPError readFromNpyFileCommon(const char* filename,
149 const char* dtype_str,
150 const size_t elementsize,
151 const uint32_t elems,
152 void* databuf,
153 bool bool_translate);
154 static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
Jerry Ge13a32912023-07-03 16:36:41 +0000155 static NPError getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700156 static NPError writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str);
157};
158
TatWai Chong39b5edc2023-08-18 17:58:17 +0000159template <>
160NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
161
Eric Kunze2364dcd2021-04-26 11:06:57 -0700162#endif // _TOSA_NUMPY_UTILS_H