COMPMID-979: Add NHWC data layout to the tensor's metadata (Part 2)
Change-Id: I24aa35a85834abf0c9954aba714aeae654615b44
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122646
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h
index c91299f..24ba521 100644
--- a/arm_compute/core/Helpers.h
+++ b/arm_compute/core/Helpers.h
@@ -655,7 +655,7 @@
*
- * @return The int conversion of the requested data layout index.
+ * @return The index of the requested data layout dimension.
*/
-inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLayoutDimension data_layout_dimension);
+inline size_t get_data_layout_dimension_index(const DataLayout data_layout, const DataLayoutDimension data_layout_dimension);
} // namespace arm_compute
#include "arm_compute/core/Helpers.inl"
diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl
index ff85773..3db8369 100644
--- a/arm_compute/core/Helpers.inl
+++ b/arm_compute/core/Helpers.inl
@@ -370,9 +370,9 @@
return index;
}
-inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLayoutDimension data_layout_dimension)
+inline size_t get_data_layout_dimension_index(const DataLayout data_layout, const DataLayoutDimension data_layout_dimension)
{
- ARM_COMPUTE_ERROR_ON_MSG(info.data_layout() == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!");
+ ARM_COMPUTE_ERROR_ON_MSG(data_layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!");
/* Return the index based on the data layout
* [N C H W]
@@ -382,13 +382,13 @@
switch(data_layout_dimension)
{
case DataLayoutDimension::CHANNEL:
- return (info.data_layout() == DataLayout::NCHW) ? 2 : 0;
+ return (data_layout == DataLayout::NCHW) ? 2 : 0;
break;
case DataLayoutDimension::HEIGHT:
- return (info.data_layout() == DataLayout::NCHW) ? 1 : 2;
+ return (data_layout == DataLayout::NCHW) ? 1 : 2;
break;
case DataLayoutDimension::WIDTH:
- return (info.data_layout() == DataLayout::NCHW) ? 0 : 1;
+ return (data_layout == DataLayout::NCHW) ? 0 : 1;
break;
case DataLayoutDimension::BATCHES:
return 3;
diff --git a/arm_compute/core/ITensor.h b/arm_compute/core/ITensor.h
index 202b50a..1ef9c6d 100644
--- a/arm_compute/core/ITensor.h
+++ b/arm_compute/core/ITensor.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,7 +24,7 @@
#ifndef __ARM_COMPUTE_ITENSOR_H__
#define __ARM_COMPUTE_ITENSOR_H__
-#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/ITensorInfo.h"
#include <cstdint>
diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h
index 50a1eb2..167fb41 100644
--- a/arm_compute/core/ITensorInfo.h
+++ b/arm_compute/core/ITensorInfo.h
@@ -132,6 +132,13 @@
* @return Dimension of the requested dimension
*/
virtual size_t dimension(size_t index) const = 0;
+ /** Return the size of the requested data layout dimension (not its index)
+ *
+ * @param[in] dimension Data layout dimension (e.g. CHANNEL, HEIGHT, WIDTH, BATCHES) whose size is requested
+ *
+ * @return Size of the requested dimension: the tensor shape value at the index that @p dimension maps to in this tensor's data layout
+ */
+ virtual size_t dimension(DataLayoutDimension dimension) const = 0;
/** The strides in bytes for accessing each dimension of the tensor
*
* @return Strides in bytes for each tensor dimension
diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h
index f9ed99b..882e4ec 100644
--- a/arm_compute/core/SubTensorInfo.h
+++ b/arm_compute/core/SubTensorInfo.h
@@ -127,6 +127,12 @@
{
return _tensor_shape[index];
}
+ size_t dimension(DataLayoutDimension dimension) const override
+ {
+ ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+ // Map the logical dimension to its index in the parent's data layout, then return the size stored at that index
+ return _tensor_shape[get_data_layout_dimension_index(_parent->data_layout(), dimension)];
+ }
const Strides &strides_in_bytes() const override
{
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h
index 27cf5ba..97f9d03 100644
--- a/arm_compute/core/TensorInfo.h
+++ b/arm_compute/core/TensorInfo.h
@@ -28,6 +28,7 @@
#include "ITensorInfo.h"
#include "arm_compute/core/Coordinates.h"
+#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Strides.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
@@ -228,6 +229,11 @@
{
return _tensor_shape[index];
}
+ size_t dimension(DataLayoutDimension dimension) const override
+ {
+ // Map the logical dimension to its index in this tensor's data layout, then return the size stored at that index
+ return _tensor_shape[get_data_layout_dimension_index(_data_layout, dimension)];
+ }
const Strides &strides_in_bytes() const override
{
return _strides_in_bytes;