MLBEDSW-2836 Change sets to tuples

Replace membership checks against sets with checks against tuples.
When neither uniqueness nor complex set operations are required, it is
quicker to use a tuple instead of a set.

Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: Ie8732c8d46067244963936c53f0ec81adda50372
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 4a85750..d0d3d7c 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -43,9 +43,9 @@
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
-passthrough_nodes = set((Op.Identity,))
+passthrough_nodes = (Op.Identity,)
 
-memory_only_ops = set((Op.Reshape,))
+memory_only_ops = (Op.Reshape,)
 
 
 def remove_passthrough_tensor(tens, arch, nng):
@@ -663,7 +663,7 @@
 
 # Reorder activation op if it's after the memory only operations
 def fixup_act_reorder(op, arch, nng):
-    if op.type.is_relu_op() or op.type in set((Op.Sigmoid, Op.Tanh)):
+    if op.type.is_relu_op() or op.type in (Op.Sigmoid, Op.Tanh):
         prep_op = get_prepend_op(op)
         if prep_op is not None:
             act_op = op.clone("_reordered")
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 30e22ff..d057d17 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -55,8 +55,8 @@
                 new_start_coord[idx] += split_offset[idx]
                 new_end_coord[idx] += split_offset[idx]
 
-        if split_offset is None and npu_block_type in set(
-            (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
+        if (split_offset is None) and (
+            npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
         ):
             # these types of operations do a "dot product" or sum over the entire IFM
             new_start_coord[-1] = 0
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 12edf5e..b287785 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -178,7 +178,7 @@
                 visit_tensor(inp)
                 inp.consumer_list.append(op)
 
-            if op.type in set((Op.Placeholder, Op.SubgraphInput)):
+            if op.type in (Op.Placeholder, Op.SubgraphInput):
                 assert len(op.outputs) == 1
                 self.input_tensors.append(op.outputs[0])
 
@@ -321,7 +321,7 @@
         all_ops = self.get_all_ops()
         unique_ops = []
         for op in all_ops:
-            if op.type in set((Op.Const, Op.Identity, Op.Placeholder)):
+            if op.type in (Op.Const, Op.Identity, Op.Placeholder):
                 continue
 
             attrs = op.attrs.copy()
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index d28df97..2d7a1b0 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -45,7 +45,7 @@
     ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
     kernel = ps2.primary_op.kernel
 
-    if ps2.npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
+    if ps2.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
         op = ps2.primary_op
         ifm_block_depth = arch.calc_ifm_block_depth(op.ifm.shape[-1], op.ifm.dtype.size_in_bits())
     else:
@@ -499,7 +499,7 @@
     ifm_read_multiple = 1
     weight_read_multiple = 0
 
-    if ps.placement in set((PassPlacement.MemoryOnly, PassPlacement.StartupInit)):
+    if ps.placement in (PassPlacement.MemoryOnly, PassPlacement.StartupInit):
         return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass
 
     min_block_size = arch.min_block_sizes[ps.npu_block_type]
@@ -537,13 +537,11 @@
             ifm_block_depth, ofm_block, primary_op.kernel, ifm_resampling_mode=ifm_tensor.resampling_mode
         )
 
-        if npu_block_type in set(
-            (
-                NpuBlockType.ConvolutionMxN,
-                NpuBlockType.ConvolutionDepthWise,
-                NpuBlockType.Pooling,
-                NpuBlockType.ReduceSum,
-            )
+        if npu_block_type in (
+            NpuBlockType.ConvolutionMxN,
+            NpuBlockType.ConvolutionDepthWise,
+            NpuBlockType.Pooling,
+            NpuBlockType.ReduceSum,
         ):
             # extent the ifm to full dimension
             ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
@@ -640,8 +638,8 @@
             n_kernel_xy = kernel_dims[0] * kernel_dims[1]
             n_input_channels_at_a_time = block_config[2]
 
-            if npu_block_type == NpuBlockType.Pooling or block_traversal in set(
-                (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
+            if (npu_block_type == NpuBlockType.Pooling) or (
+                block_traversal in (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
             ):
                 n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
                 n_kernel_xy = max(
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index ea2eaa4..9bc04f2 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -72,7 +72,7 @@
 
 binary_elem_wise_main_ops = Op.op_set(Op.is_binary_elementwise_op)
 
-unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)  # Unary element-wise operations
+unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
 
 elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
 
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index ca7e6bc..ce49fc2 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -46,7 +46,7 @@
 BASE_PTR_INDEX_MEM2MEM = int((1 << 8) | (3 << 0))
 
 
-UNARY_ELEMWISE_OPS = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,))
+UNARY_ELEMWISE_OPS = (NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ)
 
 
 def to_npu_kernel(kernel: Kernel) -> NpuKernel:
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index b07b4dc..f6e628c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -557,8 +557,10 @@
         return self.consumer_list
 
     def get_address_ranges_for_coordinates(self, start_coord, end_coord):
-        if self.sub_purpose in set(
-            (TensorSubPurpose.RollingBufferX, TensorSubPurpose.RollingBufferY, TensorSubPurpose.RollingBufferXY)
+        if self.sub_purpose in (
+            TensorSubPurpose.RollingBufferX,
+            TensorSubPurpose.RollingBufferY,
+            TensorSubPurpose.RollingBufferXY,
         ):
             # build dummy coordinates that cover the entire buffer
             start_coord = [0] * len(start_coord)
@@ -637,7 +639,7 @@
                 augmented_shape[1] = 1
 
         else:
-            assert self.format in set((TensorFormat.Unknown, TensorFormat.WeightsCompressed))
+            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
             return None, None
 
         strides = [0] * len(augmented_shape)
@@ -774,9 +776,7 @@
         return address_offset
 
     def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area):
-        if self.mem_area == scratch_tensor_mem_area and (self.mem_type in set((MemType.Scratch, MemType.Scratch_fast))):
-            return True
-        return False
+        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))
 
     def equivalent(self, tens):
         return self.equivalence_id == tens.equivalence_id
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 93b97f6..df52478 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -121,7 +121,7 @@
         if dtype == DataType.uint8:
             tens.quantization.quant_min = 0
             tens.quantization.quant_max = (1 << dtype.bits) - 1
-        elif dtype in set((DataType.int8, DataType.int16, DataType.int32, DataType.int64)):
+        elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int64):
             tens.quantization.quant_min = -(1 << (dtype.bits - 1))
             tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1
 
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index e82fb5e..f747d47 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -76,7 +76,7 @@
         self.scratch_fast_buf_id = 1  # Always assign scratch_fast to buffer 1
         self.buffers_to_write = []  # have an empty array there
 
-        self.ops_to_ignore = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
+        self.ops_to_ignore = (Op.Const, Op.Placeholder, Op.SubgraphInput)
 
         self.tensors_to_reshape = {}
 
@@ -405,7 +405,7 @@
         # Ensure that the order of the offsets match the order of the tensors
         for tens, idx in self.tensor_map.items():
             # Set offsets for tensor allocated in Tensor Arena or in the scratch_fast area
-            if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)):
+            if tens.mem_type in (MemType.Scratch, MemType.Scratch_fast):
                 offsets[idx] = np.int32(tens.address) if tens.address is not None else np.int32(0)
 
         self.nng.metadata.append(("OfflineMemoryAllocation", np.array([version, subgraph_idx, nbr_tensors] + offsets)))
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 40ebcd0..fce17d1 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -535,8 +535,7 @@
 
             if ps.scale_tensor is not None:
                 rescale_for_faf = False
-                activation_ops = set((Op.Sigmoid, Op.Tanh))
-                if (ps.ops[-1].type in activation_ops) and (ps.npu_block_type != NpuBlockType.ElementWise):
+                if (ps.ops[-1].type in (Op.Sigmoid, Op.Tanh)) and (ps.npu_block_type != NpuBlockType.ElementWise):
                     rescale_for_faf = True
                 calc_scales_and_pack_biases(ps.scale_tensor, arch, ofm_depth_step, rescale_for_faf)
                 if ps.scale_tensor.ops[0].type == Op.DMA: