[Serialization_lib] Support StatefulOps for TOSA

- Add variable in TosaTensor to schema file
- Update TosaSerializationTensor regarding variable change
- Rename internal zero_pad() and expose interface as ForceAlignTensorData()

Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I36fa64eb0802cb5b8d3564ea7233460ef8c9f539
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 22819f1..b2805a8 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -2204,7 +2204,8 @@
     VT_NAME = 4,
     VT_SHAPE = 6,
     VT_TYPE = 8,
-    VT_DATA = 10
+    VT_DATA = 10,
+    VT_VARIABLE = 12
   };
   const ::flatbuffers::String *name() const {
     return GetPointer<const ::flatbuffers::String *>(VT_NAME);
@@ -2218,6 +2219,9 @@
   const ::flatbuffers::Vector<uint8_t> *data() const {
     return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_DATA);
   }
+  bool variable() const {
+    return GetField<uint8_t>(VT_VARIABLE, 0) != 0;
+  }
   bool Verify(::flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_NAME) &&
@@ -2227,6 +2231,7 @@
            VerifyField<uint32_t>(verifier, VT_TYPE, 4) &&
            VerifyOffset(verifier, VT_DATA) &&
            verifier.VerifyVector(data()) &&
+           VerifyField<uint8_t>(verifier, VT_VARIABLE, 1) &&
            verifier.EndTable();
   }
 };
@@ -2247,6 +2252,9 @@
   void add_data(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data) {
     fbb_.AddOffset(TosaTensor::VT_DATA, data);
   }
+  void add_variable(bool variable) {
+    fbb_.AddElement<uint8_t>(TosaTensor::VT_VARIABLE, static_cast<uint8_t>(variable), 0);
+  }
   explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -2263,12 +2271,14 @@
     ::flatbuffers::Offset<::flatbuffers::String> name = 0,
     ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape = 0,
     tosa::DType type = tosa::DType_UNKNOWN,
-    ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0) {
+    ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0,
+    bool variable = false) {
   TosaTensorBuilder builder_(_fbb);
   builder_.add_data(data);
   builder_.add_type(type);
   builder_.add_shape(shape);
   builder_.add_name(name);
+  builder_.add_variable(variable);
   return builder_.Finish();
 }
 
@@ -2277,7 +2287,8 @@
     const char *name = nullptr,
     const std::vector<int32_t> *shape = nullptr,
     tosa::DType type = tosa::DType_UNKNOWN,
-    const std::vector<uint8_t> *data = nullptr) {
+    const std::vector<uint8_t> *data = nullptr,
+    bool variable = false) {
   auto name__ = name ? _fbb.CreateString(name) : 0;
   auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
   if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); }
@@ -2287,7 +2298,8 @@
       name__,
       shape__,
       type,
-      data__);
+      data__,
+      variable);
 }
 
 struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index cae6a27..bf44c11 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -112,11 +112,13 @@
     TosaSerializationTensor(const flatbuffers::String* name,
                             const flatbuffers::Vector<int32_t>* shape,
                             DType dtype,
-                            const flatbuffers::Vector<uint8_t>* data);
+                            const flatbuffers::Vector<uint8_t>* data,
+                            bool variable = false);
     TosaSerializationTensor(const std::string& name,
                             const std::vector<int32_t>& shape,
                             DType dtype,
-                            const std::vector<uint8_t>& data);
+                            const std::vector<uint8_t>& data,
+                            bool variable = false);
     TosaSerializationTensor();
     ~TosaSerializationTensor();
 
@@ -129,10 +131,14 @@
     {
         return _shape;
     }
-    DType GetDtype()
+    DType GetDtype() const
     {
         return _dtype;
     }
+    bool GetVariable() const
+    {
+        return _variable;
+    }
     const std::vector<uint8_t>& GetData() const
     {
         return _data;
@@ -169,6 +175,7 @@
     DType _dtype;                /* data type enumeration, see tosa_isa_generated.h */
     std::vector<int32_t> _shape; /* shape of the tensor */
     std::string _name;           /* name of the tensor, used for solving dependency */
+    bool _variable;              /* is this a variable tensor */
     std::vector<uint8_t> _data;  /* data array */
 };
 
@@ -368,6 +375,8 @@
     static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
     static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
 
+    static void ForceAlignTensorData(std::vector<uint8_t>& buf);
+
     // version
     const TosaVersion& GetVersion()
     {
diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py
index 850ff8f..d8264f2 100644
--- a/python/tosa/TosaTensor.py
+++ b/python/tosa/TosaTensor.py
@@ -96,8 +96,15 @@
         o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
         return o == 0
 
+    # TosaTensor
+    def Variable(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+        return False
+
 def TosaTensorStart(builder):
-    builder.StartObject(4)
+    builder.StartObject(5)
 
 def Start(builder):
     TosaTensorStart(builder)
@@ -138,6 +145,12 @@
 def StartDataVector(builder, numElems: int) -> int:
     return TosaTensorStartDataVector(builder, numElems)
 
+def TosaTensorAddVariable(builder, variable):
+    builder.PrependBoolSlot(4, variable, 0)
+
+def AddVariable(builder, variable):
+    TosaTensorAddVariable(builder, variable)
+
 def TosaTensorEnd(builder):
     return builder.EndObject()
 
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index f101fa3..0943f11 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -281,6 +281,7 @@
   shape:[int32];                    // shape of the tensor
   type:DType;                       // data type of the tensor
   data: [ubyte] (force_align: 8);   // raw data array if it's a constant tensor.
+  variable: bool;                   // is this a variable tensor
 }
 
 table TosaOperator {
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index cbb862f..cb44f17 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -22,10 +22,11 @@
 TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
                                                  const flatbuffers::Vector<int32_t>* shape,
                                                  DType dtype,
-                                                 const flatbuffers::Vector<uint8_t>* data)
+                                                 const flatbuffers::Vector<uint8_t>* data,
+                                                 bool variable)
 {
-    _dtype = dtype;
-
+    _dtype    = dtype;
+    _variable = variable;
     if (shape)
     {
         std::copy(shape->begin(), shape->end(), std::back_inserter(_shape));
@@ -43,18 +44,21 @@
 TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
                                                  const std::vector<int32_t>& shape,
                                                  DType dtype,
-                                                 const std::vector<uint8_t>& data)
+                                                 const std::vector<uint8_t>& data,
+                                                 bool variable)
 {
-    _dtype = dtype;
-    _shape = shape;
-    _name  = name;
-    _data  = data;
+    _dtype    = dtype;
+    _variable = variable;
+    _shape    = shape;
+    _name     = name;
+    _data     = data;
 }
 
 TosaSerializationTensor::TosaSerializationTensor()
 {
-    _dtype = DType_UNKNOWN;
-    _name  = "UNKNOWN";
+    _dtype    = DType_UNKNOWN;
+    _variable = false;
+    _name     = "UNKNOWN";
 }
 
 TosaSerializationTensor::~TosaSerializationTensor()
@@ -514,12 +518,14 @@
             {
                 auto curr_tensor = fb_tosa_tensors->Get(j);
 
-                auto tensor_name  = curr_tensor->name();
-                auto tensor_shape = curr_tensor->shape();
-                auto tensor_type  = curr_tensor->type();
-                auto tensor_data  = curr_tensor->data();
+                auto tensor_name     = curr_tensor->name();
+                auto tensor_shape    = curr_tensor->shape();
+                auto tensor_type     = curr_tensor->type();
+                auto tensor_variable = curr_tensor->variable();
+                auto tensor_data     = curr_tensor->data();
 
-                new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data);
+                new_tensor =
+                    new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable);
                 if (new_tensor)
                 {
                     block_tensors_container.push_back(new_tensor);
@@ -676,8 +682,10 @@
                 auto tensor_name     = _builder.CreateString(tensor->GetName().c_str());
                 auto tensor_shape    = _builder.CreateVector(tensor->GetShape());
                 auto tensor_dtype    = tensor->GetDtype();
+                bool tensor_variable = tensor->GetVariable();
                 auto tensor_data     = _builder.CreateVector(tensor->GetData());
-                auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data);
+                auto fboffset_tensor =
+                    CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable);
                 fboffset_block_tensors.push_back(fboffset_tensor);
             }
             auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);
@@ -702,7 +710,7 @@
     return TOSA_OK;
 }
 
-void zero_pad(std::vector<uint8_t>& buf)
+void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
 {
     while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0)
     {
@@ -721,7 +729,7 @@
         out.push_back(*val_u16 & 0xFF);
         out.push_back((*val_u16 >> 8) & 0xFF);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }
 
@@ -736,7 +744,7 @@
         out.push_back((*val_u32 >> 16) & 0xFF);
         out.push_back((*val_u32 >> 24) & 0xFF);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }
 
@@ -753,7 +761,7 @@
         out.push_back((*val_u64 >> 32) & 0xFF);
         out.push_back((*val_u64 >> 40) & 0xFF);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }
 
@@ -768,7 +776,7 @@
         out.push_back((*val_u32 >> 16) & 0xFF);
         out.push_back((*val_u32 >> 24) & 0xFF);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }
 
@@ -781,7 +789,7 @@
         out.push_back(*val_u16 & 0xFF);
         out.push_back((*val_u16 >> 8) & 0xFF);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }
 
@@ -793,7 +801,7 @@
         uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val);
         out.push_back(*val_u8);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }
 
@@ -824,7 +832,7 @@
         uint8_t val_u8    = static_cast<uint8_t>(val_packed);
         out.push_back(val_u8);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }
 
@@ -836,7 +844,7 @@
         uint8_t val_u8 = val;
         out.push_back(val_u8);
     }
-    zero_pad(out);
+    ForceAlignTensorData(out);
     return TOSA_OK;
 }