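Add TosaSerializerRegion to the Python version of serialization_lib

TosaSerializer now holds a list of TosaSerializerRegion objects, starting
with a "main" region that contains a "main" basic block, and serializes
them into TosaGraph regions instead of TosaGraph blocks. The tensor and
operator helpers on TosaSerializer delegate to the current region, and
startBasicBlock() is replaced by startRegion() together with
TosaSerializerRegion.addBasicBlock(). TosaSerializerBasicBlock now also
receives pathPrefix and saveConstsToFile. Local variables that shadowed
the built-in str are renamed.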

Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Ibd15f21aa24168730c904224f08fd55e27aae41f
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 85955aa..2d03d49 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
 #
 #    Licensed under the Apache License, Version 2.0 (the "License");
 #    you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 from enum import IntEnum, unique
 from tosa import (
     TosaGraph,
+    TosaRegion,
     TosaBasicBlock,
     TosaTensor,
     TosaOperator,
@@ -404,12 +405,12 @@
         self.placeholderFilename = placeholderFilename
 
     def __str__(self):
-        str = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
+        concatString = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
             self.name,
             self.shape,
             DTypeNames[self.dtype],
         )
-        return str
+        return concatString
 
     def setDtype(self, dtype):
         self.dtype = dtype
@@ -500,14 +501,14 @@
         self.outputs = TosaSerializer.toList(outputs)
 
     def __str__(self):
-        str = "Op {}\n----\n".format(self.op)
+        concatString = "Op {}\n----\n".format(self.op)
 
         for i in self.inputs:
-            str = str + "  Input:  {}\n".format(i)
+            concatString = concatString + "  Input:  {}\n".format(i)
         for o in self.outputs:
-            str = str + "  Output: {}\n".format(o)
+            concatString = concatString + "  Output: {}\n".format(o)
 
-        return str
+        return concatString
 
     def serialize(self, builder):
         fb_inputs = TosaSerializer.serializeStrVec(
@@ -532,9 +533,11 @@
 
 
 class TosaSerializerBasicBlock:
-    def __init__(self, name):
+    def __init__(self, name, pathPrefix, saveConstsToFile=False):
         self.name = name
+        self.pathPrefix = pathPrefix
         self.operators = []
+        self.saveConstsToFile = saveConstsToFile
 
         # Dict assures uniqueness, but allows us to look up by name
         self.tensors = dict()
@@ -592,44 +595,33 @@
         return TosaBasicBlock.End(builder)
 
 
-@unique
-class TensorDir(IntEnum):
-    PLACEHOLDER = 0
-    CONST = 1
-    INTERMEDIATE = 2
-    RESULT = 3
-
-
-class TosaSerializer:
-    def __init__(self, pathPrefix, saveConstsToFile=False):
-        self.add_compat_methods()
-        # Get the global TOSA version if not already defined
-
-        self.builder = flatbuffers.Builder(0)
-
+class TosaSerializerRegion:
+    def __init__(self, name, pathPrefix, saveConstsToFile=False):
+        self.name = name
         self.basicBlocks = []
-        self.startBasicBlock("main")
-        self.pathPrefix = pathPrefix
-
-        # Enables inspection of constant data outside of graph
-        self.saveConstsToFile = saveConstsToFile
-
-        # Indicies used for adding/naming tensors
         self.currInputIdx = 0
         self.currConstIdx = 0
         self.currLayerIdx = 1
         self.currResultIdx = 0
+        self.pathPrefix = pathPrefix
+        self.saveConstsToFile = saveConstsToFile
 
-        # Is this an illegal test that is expected to fail?
-        self.expectedReturnCode = 0
-        self.expectedFailure = False
-        self.expectedFailureDesc = ""
+    def addBasicBlock(self, name, pathPrefix, saveConstsToFile):
+        self.currBasicBlock = TosaSerializerBasicBlock(
+            name, pathPrefix, saveConstsToFile
+        )
+        self.basicBlocks.append(self.currBasicBlock)
 
-    def __str__(self):
-        str = ""
-        for bb in self.basicBlocks:
-            str = str + bb.__str__()
-        return str
+    def serialize(self, builder):
+        fb_name = builder.CreateString(self.name)
+        fbv_basicBlocks = TosaSerializer.serializeObjVec(
+            builder, self.basicBlocks, TosaRegion.StartBlocksVector
+        )
+
+        TosaRegion.Start(builder)
+        TosaRegion.AddName(builder, fb_name)
+        TosaRegion.AddBlocks(builder, fbv_basicBlocks)
+        return TosaRegion.End(builder)
 
     def addPlaceholder(self, shape, dtype, vals):
         if not self.currBasicBlock:
@@ -666,7 +658,6 @@
         return tens
 
     def addIntermediate(self, shape, dtype):
-
         if not self.currBasicBlock:
             raise Exception("addTensor called without valid basic block")
 
@@ -696,7 +687,6 @@
         return tens
 
     def addOperator(self, op, inputs, outputs, attributes=None):
-
         if op == TosaOp.Op().CONST:
             raise Exception("Use addConstTensor() to add CONST ops")
 
@@ -707,6 +697,62 @@
             attributes,
         )
 
+
+@unique
+class TensorDir(IntEnum):
+    PLACEHOLDER = 0
+    CONST = 1
+    INTERMEDIATE = 2
+    RESULT = 3
+
+
+class TosaSerializer:
+    def __init__(self, pathPrefix, saveConstsToFile=False):
+        self.add_compat_methods()
+        # Get the global TOSA version if not already defined
+
+        self.builder = flatbuffers.Builder(0)
+
+        self.regions = []
+        self.startRegion("main", pathPrefix, saveConstsToFile)
+
+        # Enables inspection of constant data outside of graph
+        self.saveConstsToFile = saveConstsToFile
+
+        self.currRegion.addBasicBlock("main", pathPrefix, self.saveConstsToFile)
+
+        # Is this an illegal test that is expected to fail?
+        self.expectedReturnCode = 0
+        self.expectedFailure = False
+        self.expectedFailureDesc = ""
+
+    def __str__(self):
+        concatString = ""
+        for region in self.regions:
+            concatString += "".join(str(bb) for bb in region.basicBlocks)
+        return concatString
+
+    def addPlaceholder(self, shape, dtype, vals):
+        return self.currRegion.addPlaceholder(shape, dtype, vals)
+
+    def addConst(self, shape, dtype, vals):
+        return self.currRegion.addConst(shape, dtype, vals)
+
+    def addIntermediate(self, shape, dtype):
+        return self.currRegion.addIntermediate(shape, dtype)
+
+    def addInputTensor(self, tensor):
+        self.currRegion.addInputTensor(tensor)
+
+    def addOutputTensor(self, tensor):
+        self.currRegion.addOutputTensor(tensor)
+
+    def addOutput(self, shape, dtype):
+        return self.currRegion.addOutput(shape, dtype)
+
+    def addOperator(self, op, inputs, outputs, attributes=None):
+        return self.currRegion.addOperator(op, inputs, outputs, attributes)
+
     def setExpectedReturnCode(self, val, fail, desc=""):
 
         self.expectedReturnCode = val
@@ -724,13 +770,13 @@
         Version.Add_draft(builder, TOSA_VERSION[3])
         version = Version.End(builder)
 
-        fbv_bb = TosaSerializer.serializeObjVec(
-            builder, self.basicBlocks, TosaGraph.StartBlocksVector
+        fbv_region = TosaSerializer.serializeObjVec(
+            builder, self.regions, TosaGraph.StartRegionsVector
         )
 
         TosaGraph.Start(builder)
         TosaGraph.AddVersion(builder, version)
-        TosaGraph.AddBlocks(builder, fbv_bb)
+        TosaGraph.AddRegions(builder, fbv_region)
         graph = TosaGraph.End(builder)
 
         self.builder.Finish(graph, TOSA_GRAPH_IDENTIFIER)
@@ -747,16 +793,17 @@
         ofm_name = []
         ofm_file = []
 
-        for b in self.basicBlocks:
-            if b.name == "main":
-                for i in b.inputs:
-                    ifm_name.append(i)
-                    ifm_file.append(b.tensors[i].placeholderFilename)
-                for o in b.outputs:
-                    ofm_name.append(o)
-                    # Make up an OFM filename here.  One isn't generated until the
-                    # reference tool is run, so any name is a good name
-                    ofm_file.append("ref-{}.npy".format(o))
+        for region in self.regions:
+            for block in region.basicBlocks:
+                if block and block.name == "main":  # record "main" block I/O only
+                    for i in block.inputs:
+                        ifm_name.append(i)
+                        ifm_file.append(block.tensors[i].placeholderFilename)
+                    for o in block.outputs:
+                        ofm_name.append(o)
+                        # Make up an OFM filename here.  One isn't generated until the
+                        # reference tool is run, so any name is a good name
+                        ofm_file.append("ref-{}.npy".format(o))
 
         test_desc["ifm_name"] = ifm_name
         test_desc["ifm_file"] = ifm_file
@@ -769,9 +816,9 @@
 
         return json.dumps(test_desc, indent="  ")
 
-    def startBasicBlock(self, name):
-        self.currBasicBlock = TosaSerializerBasicBlock(name)
-        self.basicBlocks.append(self.currBasicBlock)
+    def startRegion(self, name, pathPrefix, saveConstsToFile):
+        self.currRegion = TosaSerializerRegion(name, pathPrefix, saveConstsToFile)
+        self.regions.append(self.currRegion)
 
     @staticmethod
     def serializeStrVec(builder, vec, start_fcn):
@@ -1090,8 +1137,8 @@
         if not hasattr(TosaGraph, "Start"):
             TosaGraph.Start = TosaGraph.TosaGraphStart
             TosaGraph.AddVersion = TosaGraph.TosaGraphAddVersion
-            TosaGraph.AddBlocks = TosaGraph.TosaGraphAddBlocks
-            TosaGraph.StartBlocksVector = TosaGraph.TosaGraphStartBlocksVector
+            TosaGraph.AddRegions = TosaGraph.TosaGraphAddRegions
+            TosaGraph.StartRegionsVector = TosaGraph.TosaGraphStartRegionsVector
             TosaGraph.End = TosaGraph.TosaGraphEnd
         from tosa import TosaOperator