Add conversions of U8 to/from BF16 and FP8
Adds a type to PadAttribute and ClampAttribute so that their pad_const
and max_val/min_val can be deserialized according to that type.
Adds functions for converting U8 arrays to/from BF16/FP8 values.
Also refactors and exposes TosaSerializer.convertDataToUint8Vec,
which converts dtype/data to a uint8 list for serialization.
Also modifies convertDataToUint8Vec to serialize bf16 values as
2 bytes each and fp8 values as a single byte each.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 5c53f57..f5f9e58 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -18,6 +18,7 @@
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
+#include "float_utils.h"
#include "numpy_utils.h"
#include "tosa_generated.h"
#include <cstdint>
@@ -411,6 +412,9 @@
tosa_err_t LoadFileSchema(const char* schema_filename);
// data format conversion. little-endian.
+ static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
@@ -421,6 +425,9 @@
static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
static tosa_err_t
ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);