Remove QuantInfo types
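Any information still needed (the zero points) has been moved into the
attributes of each operator: the Pool, Conv and TransposeConv attributes
gain zero-point fields, and new MatMul, FullyConnected and Negate
attributes are added. This aligns with the structure of the attributes in
the TOSA specification and generally simplifies the code. The serializer's
TOSA minor version is bumped from 25 to 30.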
Change-Id: I8243e91b09de1a9115f8af09c5e7def7e8f2866b
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 4d7d7bf..10372e4 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -30,7 +30,7 @@
# Keep version number in sync with the version default value with schema/tosa.fbs
TOSA_VERSION_MAJOR = 0
-TOSA_VERSION_MINOR = 25
+TOSA_VERSION_MINOR = 30
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True
TOSA_VERSION = [
@@ -141,7 +141,7 @@
def __init__(self):
super().__init__()
- def PoolAttribute(self, kernel, stride, pad):
+ def PoolAttribute(self, kernel, stride, pad, input_zp, output_zp):
from tosa import PoolAttribute as a, Attribute
self.utype = Attribute.Attribute().PoolAttribute
@@ -150,8 +150,10 @@
self.intvecs.append((a.AddPad, pad))
self.intvecs.append((a.AddKernel, kernel))
self.intvecs.append((a.AddStride, stride))
+ self.ints.append((a.AddInputZp, input_zp))
+ self.ints.append((a.AddOutputZp, output_zp))
- def ConvAttribute(self, pad, stride, dilation):
+ def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp):
from tosa import ConvAttribute as a, Attribute
self.utype = Attribute.Attribute().ConvAttribute
@@ -160,8 +162,10 @@
self.intvecs.append((a.AddPad, pad))
self.intvecs.append((a.AddStride, stride))
self.intvecs.append((a.AddDilation, dilation))
+ self.ints.append((a.AddInputZp, input_zp))
+ self.ints.append((a.AddWeightZp, weight_zp))
- def TransposeConvAttribute(self, outpad, stride, output_shape):
+ def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp):
from tosa import TransposeConvAttribute as a, Attribute
self.utype = Attribute.Attribute().TransposeConvAttribute
@@ -170,6 +174,8 @@
self.intvecs.append((a.AddOutPad, outpad))
self.intvecs.append((a.AddStride, stride))
self.intvecs.append((a.AddOutputShape, output_shape))
+ self.ints.append((a.AddInputZp, input_zp))
+ self.ints.append((a.AddWeightZp, weight_zp))
def PadAttribute(self, padding, pad_const_int, pad_const_fp):
from tosa import PadAttribute as a, Attribute
@@ -311,43 +317,32 @@
self.intvecs.append((a.AddTable, table))
+ def MatMulAttribute(self, A_zp, B_zp):
+ from tosa import MatMulAttribute as a, Attribute
-class TosaSerializerQuantInfo(TosaSerializerUnion):
- """This class handles encapsulating all of the enumerated types for quantinfo"""
+ self.utype = Attribute.Attribute().MatMulAttribute
+ self.optFcns = (a.Start, a.End)
- def __init__(self):
- super().__init__()
+ self.ints.append((a.AddAZp, A_zp))
+ self.ints.append((a.AddBZp, B_zp))
- def ConvQuantInfo(self, input_zp, weight_zp):
- from tosa import ConvQuantInfo as q, QuantInfo
+ def FullyConnectedAttribute(self, input_zp, weight_zp):
+ from tosa import FullyConnectedAttribute as a, Attribute
- self.utype = QuantInfo.QuantInfo().ConvQuantInfo
- self.optFcns = (q.Start, q.End)
- self.ints.append((q.AddInputZp, input_zp))
- self.ints.append((q.AddWeightZp, weight_zp))
+ self.utype = Attribute.Attribute().FullyConnectedAttribute
+ self.optFcns = (a.Start, a.End)
- def UnaryQuantInfo(self, input_zp, output_zp):
- from tosa import UnaryQuantInfo as q, QuantInfo
+ self.ints.append((a.AddInputZp, input_zp))
+ self.ints.append((a.AddWeightZp, weight_zp))
- self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
- self.optFcns = (q.Start, q.End)
- self.ints.append((q.AddInputZp, input_zp))
- self.ints.append((q.AddOutputZp, output_zp))
+ def NegateAttribute(self, input1_zp, output_zp):
+ from tosa import NegateAttribute as a, Attribute
- def MatMulQuantInfo(self, a_zp, b_zp):
- from tosa import MatMulQuantInfo as q, QuantInfo
+ self.utype = Attribute.Attribute().NegateAttribute
+ self.optFcns = (a.Start, a.End)
- self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
- self.optFcns = (q.Start, q.End)
- self.ints.append((q.AddAZp, a_zp))
- self.ints.append((q.AddBZp, b_zp))
-
- def PadQuantInfo(self, input_zp):
- from tosa import PadQuantInfo as q, QuantInfo
-
- self.utype = QuantInfo.QuantInfo().PadQuantInfo
- self.optFcns = (q.Start, q.End)
- self.ints.append((q.AddInputZp, input_zp))
+ self.ints.append((a.AddInput1Zp, input1_zp))
+ self.ints.append((a.AddOutputZp, output_zp))
class TosaSerializerTensor:
@@ -467,12 +462,11 @@
class TosaSerializerOperator:
- def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None):
+ def __init__(self, op, inputs, outputs, attributes=None):
self.op = op
self.attributes = attributes
self.inputs = TosaSerializer.toList(inputs)
self.outputs = TosaSerializer.toList(outputs)
- self.quantInfo = quantInfo
def __str__(self):
str = "Op {}\n----\n".format(self.op)
@@ -491,13 +485,10 @@
fb_outputs = TosaSerializer.serializeStrVec(
builder, self.outputs, TosaOperator.StartOutputsVector
)
- # Need to serialize quant_info and attributes enums still
+ # Need to serialize attributes enums still
if self.attributes is not None:
fb_attributes = self.attributes.serialize(builder)
- if self.quantInfo is not None:
- fb_qinfo = self.quantInfo.serialize(builder)
-
TosaOperator.Start(builder)
TosaOperator.AddOp(builder, self.op)
TosaOperator.AddInputs(builder, fb_inputs)
@@ -505,9 +496,6 @@
if self.attributes is not None:
TosaOperator.AddAttributeType(builder, self.attributes.utype)
TosaOperator.AddAttribute(builder, fb_attributes)
- if self.quantInfo is not None:
- TosaOperator.AddQuantInfoType(builder, self.quantInfo.utype)
- TosaOperator.AddQuantInfo(builder, fb_qinfo)
return TosaOperator.End(builder)
@@ -544,10 +532,8 @@
def addOutput(self, name):
self.outputs.append(name)
- def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
- self.operators.append(
- TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
- )
+ def addOperator(self, op, inputs, outputs, attributes=None):
+ self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes))
def serialize(self, builder):
fb_name = builder.CreateString(self.name)
@@ -671,13 +657,16 @@
self.currBasicBlock.addOutput(name)
return tens
- def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
+ def addOperator(self, op, inputs, outputs, attributes=None):
if op == TosaOp.Op().CONST:
raise Exception("Use addConstTensor() to add CONST ops")
return self.currBasicBlock.addOperator(
- op, inputs, outputs, attributes, quant_info
+ op,
+ inputs,
+ outputs,
+ attributes,
)
def setExpectedReturnCode(self, val, fail, desc=""):
@@ -861,21 +850,48 @@
ConvAttribute.StartDilationVector = (
ConvAttribute.ConvAttributeStartDilationVector
)
+ ConvAttribute.AddInputZp = ConvAttribute.ConvAttributeAddInputZp
+ ConvAttribute.AddWeightZp = ConvAttribute.ConvAttributeAddWeightZp
ConvAttribute.End = ConvAttribute.ConvAttributeEnd
- from tosa import ConvQuantInfo
+ from tosa import FullyConnectedAttribute
- if not hasattr(ConvQuantInfo, "Start"):
- ConvQuantInfo.Start = ConvQuantInfo.ConvQuantInfoStart
- ConvQuantInfo.AddInputZp = ConvQuantInfo.ConvQuantInfoAddInputZp
- ConvQuantInfo.AddWeightZp = ConvQuantInfo.ConvQuantInfoAddWeightZp
- ConvQuantInfo.End = ConvQuantInfo.ConvQuantInfoEnd
- from tosa import MatMulQuantInfo
+ if not hasattr(FullyConnectedAttribute, "Start"):
+ FullyConnectedAttribute.Start = (
+ FullyConnectedAttribute.FullyConnectedAttributeStart
+ )
+ FullyConnectedAttribute.AddInputZp = (
+ FullyConnectedAttribute.FullyConnectedAttributeAddInputZp
+ )
+ FullyConnectedAttribute.AddWeightZp = (
+ FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp
+ )
+ FullyConnectedAttribute.End = (
+ FullyConnectedAttribute.FullyConnectedAttributeEnd
+ )
+ from tosa import MatMulAttribute
- if not hasattr(MatMulQuantInfo, "Start"):
- MatMulQuantInfo.Start = MatMulQuantInfo.MatMulQuantInfoStart
- MatMulQuantInfo.AddAZp = MatMulQuantInfo.MatMulQuantInfoAddAZp
- MatMulQuantInfo.AddBZp = MatMulQuantInfo.MatMulQuantInfoAddBZp
- MatMulQuantInfo.End = MatMulQuantInfo.MatMulQuantInfoEnd
+ if not hasattr(MatMulAttribute, "Start"):
+ MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart
+ MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp
+ MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp
+ MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd
+ from tosa import PoolAttribute
+
+ if not hasattr(PoolAttribute, "Start"):
+ PoolAttribute.Start = PoolAttribute.PoolAttributeStart
+ PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad
+ PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector
+ PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel
+ PoolAttribute.StartKernelVector = (
+ PoolAttribute.PoolAttributeStartKernelVector
+ )
+ PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride
+ PoolAttribute.StartStrideVector = (
+ PoolAttribute.PoolAttributeStartStrideVector
+ )
+ PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp
+ PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp
+ PoolAttribute.End = PoolAttribute.PoolAttributeEnd
from tosa import MulAttribute
if not hasattr(MulAttribute, "Start"):
@@ -893,12 +909,6 @@
PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt
PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp
PadAttribute.End = PadAttribute.PadAttributeEnd
- from tosa import PadQuantInfo
-
- if not hasattr(PadQuantInfo, "Start"):
- PadQuantInfo.Start = PadQuantInfo.PadQuantInfoStart
- PadQuantInfo.AddInputZp = PadQuantInfo.PadQuantInfoAddInputZp
- PadQuantInfo.End = PadQuantInfo.PadQuantInfoEnd
from tosa import PoolAttribute
if not hasattr(PoolAttribute, "Start"):
@@ -913,6 +923,8 @@
PoolAttribute.StartStrideVector = (
PoolAttribute.PoolAttributeStartStrideVector
)
+ PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp
+ PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp
PoolAttribute.End = PoolAttribute.PoolAttributeEnd
from tosa import RescaleAttribute
@@ -1048,8 +1060,6 @@
TosaOperator.StartOutputsVector = (
TosaOperator.TosaOperatorStartOutputsVector
)
- TosaOperator.AddQuantInfoType = TosaOperator.TosaOperatorAddQuantInfoType
- TosaOperator.AddQuantInfo = TosaOperator.TosaOperatorAddQuantInfo
TosaOperator.End = TosaOperator.TosaOperatorEnd
from tosa import TosaTensor
@@ -1095,16 +1105,15 @@
TransposeConvAttribute.StartOutputShapeVector = (
TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector
)
+ TransposeConvAttribute.AddInputZp = (
+ TransposeConvAttribute.TransposeConvAttributeAddInputZp
+ )
+ TransposeConvAttribute.AddWeightZp = (
+ TransposeConvAttribute.TransposeConvAttributeAddWeightZp
+ )
TransposeConvAttribute.End = (
TransposeConvAttribute.TransposeConvAttributeEnd
)
- from tosa import UnaryQuantInfo
-
- if not hasattr(UnaryQuantInfo, "Start"):
- UnaryQuantInfo.Start = UnaryQuantInfo.UnaryQuantInfoStart
- UnaryQuantInfo.AddInputZp = UnaryQuantInfo.UnaryQuantInfoAddInputZp
- UnaryQuantInfo.AddOutputZp = UnaryQuantInfo.UnaryQuantInfoAddOutputZp
- UnaryQuantInfo.End = UnaryQuantInfo.UnaryQuantInfoEnd
from tosa import Version
if not hasattr(Version, "Start"):
@@ -1114,6 +1123,35 @@
Version.Add_patch = Version.VersionAdd_patch
Version.Add_draft = Version.VersionAdd_draft
Version.End = Version.VersionEnd
+ from tosa import MatMulAttribute
+
+ if not hasattr(MatMulAttribute, "Start"):
+ MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart
+ MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp
+ MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp
+ MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd
+ from tosa import FullyConnectedAttribute
+
+ if not hasattr(FullyConnectedAttribute, "Start"):
+ FullyConnectedAttribute.Start = (
+ FullyConnectedAttribute.FullyConnectedAttributeStart
+ )
+ FullyConnectedAttribute.AddInputZp = (
+ FullyConnectedAttribute.FullyConnectedAttributeAddInputZp
+ )
+ FullyConnectedAttribute.AddWeightZp = (
+ FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp
+ )
+ FullyConnectedAttribute.End = (
+ FullyConnectedAttribute.FullyConnectedAttributeEnd
+ )
+ from tosa import NegateAttribute
+
+ if not hasattr(NegateAttribute, "Start"):
+ NegateAttribute.Start = NegateAttribute.NegateAttributeStart
+ NegateAttribute.AddInput1Zp = NegateAttribute.NegateAttributeAddInput1Zp
+ NegateAttribute.AddOutputZp = NegateAttribute.NegateAttributeAddOutputZp
+ NegateAttribute.End = NegateAttribute.NegateAttributeEnd
from tosa import WhileLoopAttribute
if not hasattr(WhileLoopAttribute, "Start"):