Add meta-data to express dynamic shapes in ITensorInfo
Add `tensor_dims_state` and `set_tensor_dims_state` to query and inject
shape dynamism.
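The state is represented by an array of integers whose index maps to the
corresponding shape dimension index.
A dimension state of -1 marks the corresponding dimension as dynamic.
`is_dynamic()` is now derived from this state (true if any entry is -1),
replacing the stored `_is_dynamic` flag and `set_is_dynamic()`.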
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I3a8a5ad109b90d4df8545b460a9f8dfcc13dfa0f
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4784
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h
index 31f2732..42a969e 100644
--- a/arm_compute/core/TensorInfo.h
+++ b/arm_compute/core/TensorInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -224,6 +224,7 @@
ITensorInfo &set_num_channels(int num_channels) override;
ITensorInfo &set_format(Format format) override;
ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
+ ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override; // One state entry per shape dimension; an entry of -1 marks that dimension as dynamic
ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override;
ITensorInfo &set_data_layout(const DataLayout &data_layout) override;
ITensorInfo &reset_padding() override;
@@ -262,6 +263,10 @@
{
return _tensor_shape;
}
+ const TensorDimsState &tensor_dims_state() const override // Per-dimension shape state (index = dimension index, -1 = dynamic)
+ {
+ return _dims_state;
+ }
DataType data_type() const override
{
return _data_type;
@@ -288,18 +293,13 @@
}
bool is_dynamic() const override
{
- return _is_dynamic;
+ return std::find(std::cbegin(_dims_state), std::cend(_dims_state), -1) != std::cend(_dims_state); // Derived, not stored: dynamic iff any dimension state is -1 (NOTE(review): consider a shared named constant for this sentinel)
}
ITensorInfo &set_is_resizable(bool is_resizable) override
{
_is_resizable = is_resizable;
return *this;
}
- ITensorInfo &set_is_dynamic(bool is_dynamic) override
- {
- _is_dynamic = is_dynamic;
- return *this;
- }
ValidRegion valid_region() const override
{
return _valid_region;
@@ -329,10 +329,10 @@
Strides _strides_in_bytes;
size_t _num_channels;
TensorShape _tensor_shape;
+ TensorDimsState _dims_state; // Per-dimension shape state (-1 = dynamic); replaces the removed _is_dynamic flag
DataType _data_type;
Format _format;
bool _is_resizable;
- bool _is_dynamic;
ValidRegion _valid_region;
PaddingSize _padding;
QuantizationInfo _quantization_info;