[serialization_lib] Add acc_type to Conv Attrs

This adds acc_type to ConvAttribute and TransposeConvAttribute

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I73bab71b2eb90f6451fadee21d5bed1811ecbfd7
diff --git a/include/attribute.def b/include/attribute.def
index 2176f47..723543e 100644
--- a/include/attribute.def
+++ b/include/attribute.def
@@ -34,21 +34,23 @@
               int32_t, S, output_zp,
               DType,   S, acc_type)
 
-DEF_ATTRIBUTE(Conv, 6,
+DEF_ATTRIBUTE(Conv, 7,
               int32_t, V, pad,
               int32_t, V, stride,
               int32_t, V, dilation,
               int32_t, S, input_zp,
               int32_t, S, weight_zp,
-              bool,    S, local_bound)
+              bool,    S, local_bound,
+              DType,   S, acc_type)
 
-DEF_ATTRIBUTE(TransposeConv, 6,
+DEF_ATTRIBUTE(TransposeConv, 7,
               int32_t, V, out_pad,
               int32_t, V, stride,
               int32_t, V, output_shape,
               int32_t, S, input_zp,
               int32_t, S, weight_zp,
-              bool,    S, local_bound)
+              bool,    S, local_bound,
+              DType,   S, acc_type)
 
 DEF_ATTRIBUTE(Pad, 1,
               uint8_t, V, pad_const)
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 64d54bc..20f6993 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -759,7 +759,8 @@
     VT_DILATION = 8,
     VT_INPUT_ZP = 10,
     VT_WEIGHT_ZP = 12,
-    VT_LOCAL_BOUND = 14
+    VT_LOCAL_BOUND = 14,
+    VT_ACC_TYPE = 16
   };
   const ::flatbuffers::Vector<int32_t> *pad() const {
     return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_PAD);
@@ -779,6 +780,9 @@
   bool local_bound() const {
     return GetField<uint8_t>(VT_LOCAL_BOUND, 0) != 0;
   }
+  tosa::DType acc_type() const {
+    return static_cast<tosa::DType>(GetField<uint32_t>(VT_ACC_TYPE, 0));
+  }
   bool Verify(::flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_PAD) &&
@@ -790,6 +794,7 @@
            VerifyField<int32_t>(verifier, VT_INPUT_ZP, 4) &&
            VerifyField<int32_t>(verifier, VT_WEIGHT_ZP, 4) &&
            VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) &&
+           VerifyField<uint32_t>(verifier, VT_ACC_TYPE, 4) &&
            verifier.EndTable();
   }
 };
@@ -816,6 +821,9 @@
   void add_local_bound(bool local_bound) {
     fbb_.AddElement<uint8_t>(ConvAttribute::VT_LOCAL_BOUND, static_cast<uint8_t>(local_bound), 0);
   }
+  void add_acc_type(tosa::DType acc_type) {
+    fbb_.AddElement<uint32_t>(ConvAttribute::VT_ACC_TYPE, static_cast<uint32_t>(acc_type), 0);
+  }
   explicit ConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -834,8 +842,10 @@
     ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> dilation = 0,
     int32_t input_zp = 0,
     int32_t weight_zp = 0,
-    bool local_bound = false) {
+    bool local_bound = false,
+    tosa::DType acc_type = tosa::DType_UNKNOWN) {
   ConvAttributeBuilder builder_(_fbb);
+  builder_.add_acc_type(acc_type);
   builder_.add_weight_zp(weight_zp);
   builder_.add_input_zp(input_zp);
   builder_.add_dilation(dilation);
@@ -852,7 +862,8 @@
     const std::vector<int32_t> *dilation = nullptr,
     int32_t input_zp = 0,
     int32_t weight_zp = 0,
-    bool local_bound = false) {
+    bool local_bound = false,
+    tosa::DType acc_type = tosa::DType_UNKNOWN) {
   auto pad__ = pad ? _fbb.CreateVector<int32_t>(*pad) : 0;
   auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
   auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0;
@@ -863,7 +874,8 @@
       dilation__,
       input_zp,
       weight_zp,
-      local_bound);
+      local_bound,
+      acc_type);
 }
 
 struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
@@ -874,7 +886,8 @@
     VT_OUTPUT_SHAPE = 8,
     VT_INPUT_ZP = 10,
     VT_WEIGHT_ZP = 12,
-    VT_LOCAL_BOUND = 14
+    VT_LOCAL_BOUND = 14,
+    VT_ACC_TYPE = 16
   };
   const ::flatbuffers::Vector<int32_t> *out_pad() const {
     return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUT_PAD);
@@ -894,6 +907,9 @@
   bool local_bound() const {
     return GetField<uint8_t>(VT_LOCAL_BOUND, 0) != 0;
   }
+  tosa::DType acc_type() const {
+    return static_cast<tosa::DType>(GetField<uint32_t>(VT_ACC_TYPE, 0));
+  }
   bool Verify(::flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_OUT_PAD) &&
@@ -905,6 +921,7 @@
            VerifyField<int32_t>(verifier, VT_INPUT_ZP, 4) &&
            VerifyField<int32_t>(verifier, VT_WEIGHT_ZP, 4) &&
            VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) &&
+           VerifyField<uint32_t>(verifier, VT_ACC_TYPE, 4) &&
            verifier.EndTable();
   }
 };
@@ -931,6 +948,9 @@
   void add_local_bound(bool local_bound) {
     fbb_.AddElement<uint8_t>(TransposeConvAttribute::VT_LOCAL_BOUND, static_cast<uint8_t>(local_bound), 0);
   }
+  void add_acc_type(tosa::DType acc_type) {
+    fbb_.AddElement<uint32_t>(TransposeConvAttribute::VT_ACC_TYPE, static_cast<uint32_t>(acc_type), 0);
+  }
   explicit TransposeConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -949,8 +969,10 @@
     ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> output_shape = 0,
     int32_t input_zp = 0,
     int32_t weight_zp = 0,
-    bool local_bound = false) {
+    bool local_bound = false,
+    tosa::DType acc_type = tosa::DType_UNKNOWN) {
   TransposeConvAttributeBuilder builder_(_fbb);
+  builder_.add_acc_type(acc_type);
   builder_.add_weight_zp(weight_zp);
   builder_.add_input_zp(input_zp);
   builder_.add_output_shape(output_shape);
@@ -967,7 +989,8 @@
     const std::vector<int32_t> *output_shape = nullptr,
     int32_t input_zp = 0,
     int32_t weight_zp = 0,
-    bool local_bound = false) {
+    bool local_bound = false,
+    tosa::DType acc_type = tosa::DType_UNKNOWN) {
   auto out_pad__ = out_pad ? _fbb.CreateVector<int32_t>(*out_pad) : 0;
   auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
   auto output_shape__ = output_shape ? _fbb.CreateVector<int32_t>(*output_shape) : 0;
@@ -978,7 +1001,8 @@
       output_shape__,
       input_zp,
       weight_zp,
-      local_bound);
+      local_bound,
+      acc_type);
 }
 
 struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 2c7996a..9658edf 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -172,7 +172,9 @@
         self.ints.append((a.AddOutputZp, output_zp))
         self.ints.append((a.AddAccType, acc_type))
 
-    def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound):
+    def ConvAttribute(
+        self, pad, stride, dilation, input_zp, weight_zp, local_bound, acc_type
+    ):
         from tosa import ConvAttribute as a, Attribute
 
         self.utype = Attribute.Attribute().ConvAttribute
@@ -184,9 +186,10 @@
         self.ints.append((a.AddInputZp, input_zp))
         self.ints.append((a.AddWeightZp, weight_zp))
         self.bools.append((a.AddLocalBound, local_bound))
+        self.ints.append((a.AddAccType, acc_type))
 
     def TransposeConvAttribute(
-        self, outpad, stride, output_shape, input_zp, weight_zp, local_bound
+        self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type
     ):
         from tosa import TransposeConvAttribute as a, Attribute
 
@@ -199,6 +202,7 @@
         self.ints.append((a.AddInputZp, input_zp))
         self.ints.append((a.AddWeightZp, weight_zp))
         self.bools.append((a.AddLocalBound, local_bound))
+        self.ints.append((a.AddAccType, acc_type))
 
     def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
         from tosa import PadAttribute as a, Attribute
diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py
index b35b67c..dfa75dc 100644
--- a/python/tosa/ConvAttribute.py
+++ b/python/tosa/ConvAttribute.py
@@ -130,8 +130,15 @@
             return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
         return False
 
+    # ConvAttribute
+    def AccType(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+        return 0
+
 def ConvAttributeStart(builder):
-    builder.StartObject(6)
+    builder.StartObject(7)
 
 def Start(builder):
     ConvAttributeStart(builder)
@@ -190,6 +197,12 @@
 def AddLocalBound(builder, localBound):
     ConvAttributeAddLocalBound(builder, localBound)
 
+def ConvAttributeAddAccType(builder, accType):
+    builder.PrependUint32Slot(6, accType, 0)
+
+def AddAccType(builder, accType):
+    ConvAttributeAddAccType(builder, accType)
+
 def ConvAttributeEnd(builder):
     return builder.EndObject()
 
diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py
index a74a433..e5397a8 100644
--- a/python/tosa/TransposeConvAttribute.py
+++ b/python/tosa/TransposeConvAttribute.py
@@ -130,8 +130,15 @@
             return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
         return False
 
+    # TransposeConvAttribute
+    def AccType(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+        return 0
+
 def TransposeConvAttributeStart(builder):
-    builder.StartObject(6)
+    builder.StartObject(7)
 
 def Start(builder):
     TransposeConvAttributeStart(builder)
@@ -190,6 +197,12 @@
 def AddLocalBound(builder, localBound):
     TransposeConvAttributeAddLocalBound(builder, localBound)
 
+def TransposeConvAttributeAddAccType(builder, accType):
+    builder.PrependUint32Slot(6, accType, 0)
+
+def AddAccType(builder, accType):
+    TransposeConvAttributeAddAccType(builder, accType)
+
 def TransposeConvAttributeEnd(builder):
     return builder.EndObject()
 
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 028765d..79b83b1 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -170,6 +170,7 @@
   input_zp: int32;
   weight_zp: int32;
   local_bound: bool;
+  acc_type: DType;
 }
 
 table TransposeConvAttribute {
@@ -179,6 +180,7 @@
   input_zp: int32;
   weight_zp: int32;
   local_bound: bool;
+  acc_type: DType;
 }
 
 table PadAttribute {