Update reference model/serialization library to 0.21.0 with unit tests added/updated
- update tosa.GATHER
- update tosa.RESIZE
- add tosa.SCATTER

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 302e4f4..0e57a7b 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -158,6 +158,29 @@
         return shape_list
 
     @staticmethod
+    def tgScatter(testGen, opName, rank):
+        pl, const = opName['operands']
+
+        assert(pl == 2)
+        assert(const == 0)
+        assert(rank == 3)
+
+        values_in_shape = testGen.makeShape(rank)
+
+        # Constrain the batch size when a maximum batch size is configured
+        if testGen.args.max_batch_size:
+            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
+
+        W = testGen.randInt(testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1])
+        input_shape = [values_in_shape[0], W, values_in_shape[2]]
+
+        shape_list = []
+        shape_list.append(values_in_shape.copy())
+        shape_list.append(input_shape.copy())
+
+        return shape_list
+
+    @staticmethod
     def tgBroadcastFuzz(testGen, op, rank):
         shape = testGen.makeShape(rank)
 
@@ -650,6 +673,8 @@
                 outputDTypeList = [ DType.INT8 ]
             elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                 outputDTypeList = [ DType.INT48 ]
+            elif dtype == DType.FLOAT:
+                outputDTypeList = [ DType.FLOAT ]
             else:
                 continue
 
@@ -659,19 +684,52 @@
                     # Randomly generate legal output dimensions and shift
                     # and then compute the stride and offset based on them
                     output_dims = [ testGen.randInt(), testGen.randInt() ]
+                    in_center_h = (ifm_shape[1] - 1) / 2.0
+                    in_center_w = (ifm_shape[2] - 1) / 2.0
+                    out_center_h = (output_dims[0] - 1) / 2.0
+                    out_center_w = (output_dims[1] - 1) / 2.0
 
-                    shift = testGen.randInt(1, 11)
+                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
+                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
+                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
+                    fp_offset_x = in_center_w - fp_stride_x * out_center_w
 
-                    stride = [ (ifm_shape[1] << shift) // output_dims[0],
-                               (ifm_shape[2] << shift) // output_dims[1] ]
+                    if outputDType == DType.FLOAT:
+                        shift = 0
+                        stride = [0, 0]
+                        offset = [0, 0]
+                        stride_fp = [ fp_stride_y, fp_stride_x]
+                        offset_fp = [ fp_offset_y, fp_offset_x]
+                        arg_list.append(('mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}'.format(m, output_dims[0], output_dims[1],
+                                                                                                  testGen.typeStr(outputDType), stride_fp[0], stride_fp[1],
+                                                                                                  offset_fp[0], offset_fp[1]),
+                                             [m, stride, offset, shift, stride_fp, offset_fp, output_dims, dtype, outputDType]))
+                    else:
+                        shift = 11
+                        unit = float(1 << shift)
+                        stride_y = int(round(fp_stride_y * unit))
+                        stride_x = int(round(fp_stride_x * unit))
+                        offset_y = int(round(fp_offset_y * unit))
+                        offset_x = int(round(fp_offset_x * unit))
 
-                    offset = [ testGen.randInt(-stride[0], (ifm_shape[1] << shift) - (output_dims[0] - 1) * stride[0]),
-                               testGen.randInt(-stride[1], (ifm_shape[2] << shift) - (output_dims[1] - 1) * stride[1]) ]
+                        while (stride_y >= 32768 or stride_x >= 32768 or offset_y >= 32768 or offset_x >= 32768 or offset_y < -32768 or offset_x < -32768):
+                            shift = shift - 1
+                            unit = float(1 << shift)
+                            stride_y = int(round(fp_stride_y * unit))
+                            stride_x = int(round(fp_stride_x * unit))
+                            offset_y = int(round(fp_offset_y * unit))
+                            offset_x = int(round(fp_offset_x * unit))
 
-                    arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1],
-                                                                                              testGen.typeStr(outputDType), stride[0], stride[1],
-                                                                                              offset[0], offset[1]),
-                                         [m, stride, offset, shift, output_dims, dtype, outputDType]))
+                        stride = [ stride_y, stride_x]
+                        offset = [ offset_y, offset_x]
+
+                        stride_fp = [0.0, 0.0]
+                        offset_fp = [0.0, 0.0]
+
+                        arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1],
+                                                                                                  testGen.typeStr(outputDType), stride[0], stride[1],
+                                                                                                  offset[0], offset[1]),
+                                             [m, stride, offset, shift, stride_fp, offset_fp, output_dims, dtype, outputDType]))
 
         return arg_list
 
@@ -1139,29 +1197,44 @@
         return result_tens
 
 
-    def build_gather(self, op, values, axis):
+    def build_gather(self, op, values):
 
         # Create a new indicies tensor
         # here with data that doesn't exceed the dimensions of the values tensor
 
-        max_val = values.shape[axis]
-        indicies_arr = np.int32(self.rng.integers(low=0, high=max_val, size=[self.randInt(1, max_val + 1)]))
+        K = values.shape[1] # K
+        W = self.randInt(self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]) # W
+        indicies_arr = np.int32(self.rng.integers(low=0, high=K, size=[values.shape[0], W])) # (N, W)
         indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, Usage.INDEX, [], indicies_arr)
 
-        result_tens = OutputShaper.gatherOp(self.ser, values, indicies, axis)
+        result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
 
-        attr = ts.TosaSerializerAttribute()
-        attr.AxisAttribute(axis)
-
-        self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name], attr)
+        self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])
 
         return result_tens
 
-    def build_resize(self, op, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
-        result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype)
+    def build_scatter(self, op, values_in, input):
+
+        # Create a new indices tensor of shape (N, W) here, with entries
+        # drawn from low=0, high=K so they index dimension 1 of values_in
+
+        K = values_in.shape[1] # K
+        W = input.shape[1] # W
+        indicies_arr = np.int32(self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])) # (N, W)
+        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, Usage.INDEX, [], indicies_arr)
+
+        result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
+
+        self.ser.addOperator(op, [values_in.name, indicies.name, input.name], [result_tens.name])
+
+        return result_tens
+
+    def build_resize(self, op, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype):
+        result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype)
 
         attr = ts.TosaSerializerAttribute()
-        attr.ResizeAttribute(output_dims, stride, offset, shift, mode)
+
+        attr.ResizeAttribute(output_dims, stride, offset, shift, stride_fp, offset_fp, mode)
 
         self.ser.addOperator(op, [input.name], [result_tens.name], attr)
         return result_tens
@@ -1966,10 +2039,20 @@
         # Scatter/Gather
         'gather':
         { 'op':        Op.GATHER,
+          # Only specify 'values' tensor here. 'indices' is generated in op building stage
           'operands':  (1, 0),
-          'build_fcn': (build_gather, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
-          'types':     TYPE_INT },
+          'rank':      (3, 3),
+          'build_fcn': (build_gather, TosaTensorGen.tgBasic, None),
+          'types':     TYPE_INT_FP },
 
+        'scatter':
+        { 'op':        Op.SCATTER,
+          # Specify the 'values_in' and 'input' tensors here.
+          # 'indices' is generated in op building stage
+          'operands':  (2, 0),
+          'rank':      (3, 3),
+          'build_fcn': (build_scatter, TosaTensorGen.tgScatter, None),
+          'types':     TYPE_INT_FP },
 
         # Image operations
         'resize':
@@ -1977,7 +2060,7 @@
           'operands':  (1, 0),
           'rank':      (4, 4),
           'build_fcn': ( build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
-          'types':    [ DType.INT8, DType.INT16 ] },
+          'types':    [ DType.INT8, DType.INT16, DType.FLOAT ] },
 
 
         # Data nodes
@@ -2319,11 +2402,27 @@
         return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
 
     @staticmethod
-    def gatherOp(ser, values, indicies, axis):
-        # indicies minus the axis + values - the indexes used to look up values.
-        output_shape = [*values.shape[0:axis],  indicies.shape[0],  *values.shape[axis+1:]]
+    def gatherOp(ser, values, indices):
+        assert len(values.shape) == 3
+        assert len(indices.shape) == 2
+        assert values.shape[0] == indices.shape[0]
 
-        return ser.addOutput(output_shape, values.dtype, indicies.usage, indicies.dformat)
+        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
+
+        return ser.addOutput(output_shape, values.dtype, values.usage, values.dformat)
+
+    @staticmethod
+    def scatterOp(ser, values_in, indices, input):
+        assert len(values_in.shape) == 3
+        assert len(indices.shape) == 2
+        assert len(input.shape) == 3
+        assert values_in.shape[0] == indices.shape[0] # N
+        assert input.shape[1] == indices.shape[1] # W
+        assert values_in.shape[2] == input.shape[2] # C
+
+        output_shape = values_in.shape
+
+        return ser.addOutput(output_shape, values_in.dtype, values_in.usage, values_in.dformat)
 
     @staticmethod
     def tableOp(ser, input, table):
@@ -2331,12 +2430,16 @@
         return ser.addOutput(input.shape, DType.INT32, input.usage, input.dformat)
 
     @staticmethod
-    def resizeOp(ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
+    def resizeOp(ser, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype):
 
         output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
 
-        if stride[0] <= 0 or stride[1] <= 0:
-            ser.setExpectedFailure(True, 'Negative or zero stride')
+        if input_dtype == DType.FLOAT:
+            if stride_fp[0] <= 0 or stride_fp[1] <= 0:
+                ser.setExpectedFailure(True, 'Negative or zero stride')
+        else:
+            if stride[0] <= 0 or stride[1] <= 0:
+                ser.setExpectedFailure(True, 'Negative or zero stride')
 
         if mode == ResizeMode.BILINEAR:
             if input_dtype == DType.INT8:
@@ -2345,6 +2448,9 @@
             elif input_dtype == DType.INT16:
                 if output_dtype != DType.INT48:
                     ser.setexpectedfailure(true, 'Invalid output data type')
+            elif input_dtype == DType.FLOAT:  # FLOAT input must produce FLOAT output
+                if output_dtype != DType.FLOAT:
+                    ser.setExpectedFailure(True, 'Invalid output data type')
             else:
                 ser.setexpectedfailure(true, 'Invalid input data type')
 
@@ -2355,6 +2461,9 @@
             elif input_dtype == DType.INT16:
                 if output_dtype != DType.INT16:
                     ser.setexpectedfailure(true, 'Invalid output data type')
+            elif input_dtype == DType.FLOAT:  # FLOAT input must produce FLOAT output
+                if output_dtype != DType.FLOAT:
+                    ser.setExpectedFailure(True, 'Invalid output data type')
             else:
                 ser.setexpectedfailure(true, 'Invalid input data type')