MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string
- Removed unused operator codes
- Refactored some attributes, such as npu_block_type and fused_activation_function
- Refactored operator index calculation
- Refactored a number of operator sets
Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 5c2ddab..41e1529 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -37,6 +37,7 @@
from .npu_performance import PassCycles
from .numeric_util import full_shape
from .operation import NpuBlockType
+from .operation import Op
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
@@ -254,11 +255,7 @@
self.pareto_max_candidates = 16
self.ifm_stream_npu_blocks = set(
- (
- NpuBlockType.ConvolutionMxN,
- NpuBlockType.ConvolutionDepthWise,
- NpuBlockType.Pooling,
- )
+ (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
)
num_pareto_metrics = 4
@@ -652,7 +649,7 @@
def avoid_for_cascading(self, pred_candidate):
for op in pred_candidate.ops:
if (
- op.type == "ConcatSliceWrite"
+ op.type == Op.ConcatSliceWrite
and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
):
# For SRAM spilling, concat op is avoided as predecessor
@@ -981,9 +978,9 @@
use_NHCWB16 = False
use_fast_storage = False
continue
- if op.type == "ReduceSum" and output.dtype == DataType.int32:
+ if op.type == Op.ReduceSum and output.dtype == DataType.int32:
use_NHCWB16 = False
- elif op.type == "Reshape":
+ elif op.type == Op.Reshape:
# Detect no-op reshapes by comparing their full input and output tensor shapes.
inshape = full_shape(4, op.inputs[0].shape, 1)
outshape = full_shape(4, op.outputs[0].shape, 1)
@@ -995,7 +992,7 @@
incompatible_consumers = [
(
not consumer.run_on_npu
- or consumer.type == "Reshape"
+ or consumer.type == Op.Reshape
or (consumer is last_op_in_subgraph)
)
for consumer in op.outputs[0].consumer_list