Add compute kernel writer arguments export

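* The argument information is extracted from the prototype argument
  registry.
* Rename the public enums TensorComponent and TensorStorage to
  TensorComponentType and TensorStorageType.
* Add a storage type parameter to declare_tensor_argument
  (defaults to TensorStorageType::BufferUint8Ptr).
* Replace Kernel::operands() with Kernel::register_operand().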

Partially resolves: COMPMID-6283
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: Ia6d69b7c2a2e411597e76a7e03b7c92199a16990
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9848
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/compute_kernel_writer/prototype/CMakeLists.txt b/compute_kernel_writer/prototype/CMakeLists.txt
index 0def9ea..3d6a192 100644
--- a/compute_kernel_writer/prototype/CMakeLists.txt
+++ b/compute_kernel_writer/prototype/CMakeLists.txt
@@ -34,6 +34,7 @@
     src/TileOperand.cpp
     src/TensorOperand.cpp
     src/TensorTileSampler.cpp
+    src/KernelArgument.cpp
 )
 
 target_compile_options(ckw_prototype
diff --git a/compute_kernel_writer/prototype/examples/add_exp_store.cpp b/compute_kernel_writer/prototype/examples/add_exp_store.cpp
index 9529268..6a98845 100644
--- a/compute_kernel_writer/prototype/examples/add_exp_store.cpp
+++ b/compute_kernel_writer/prototype/examples/add_exp_store.cpp
@@ -23,6 +23,7 @@
  */
 
 #include "ckw/Error.h"
+#include "ckw/KernelArgument.h"
 #include "ckw/KernelWriter.h"
 #include "ckw/TensorOperand.h"
 #include "ckw/TensorTileSampler.h"
@@ -163,9 +164,9 @@
     const TensorInfo src1_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 1);
     const TensorInfo dst_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 2);
 
-    ExampleComponentArgument src0(writer->declare_tensor_argument("src0", src0_info));
-    ExampleComponentArgument src1(writer->declare_tensor_argument("src1", src1_info));
-    ExampleComponentArgument dst(writer->declare_tensor_argument("dst", dst_info));
+    ExampleComponentArgument src0(writer->declare_tensor_argument("src0", src0_info, TensorStorageType::BufferUint8Ptr));
+    ExampleComponentArgument src1(writer->declare_tensor_argument("src1", src1_info, TensorStorageType::BufferUint8Ptr));
+    ExampleComponentArgument dst(writer->declare_tensor_argument("dst", dst_info, TensorStorageType::BufferUint8Ptr));
 
     ExampleComponentArgument ans;
 
@@ -173,6 +174,28 @@
     op_exp(writer, { &ans, &ans });
     op_store(writer, { &ans, &dst });
 
+    const auto arguments = kernel.arguments();
+
+    std::cout << "\n====================\nArguments:\n====================\n";
+
+    for(auto &arg : arguments)
+    {
+        switch(arg.type())
+        {
+            case ckw::KernelArgument::Type::TensorStorage:
+                std::cout << "* Tensor storage:   ID = " << arg.id() << ", type = " << std::hex << "0x" << static_cast<uint32_t>(arg.tensor_storage_type()) << std::dec << "\n";
+                break;
+
+            case ckw::KernelArgument::Type::TensorComponent:
+                std::cout << "* Tensor component: ID = " << arg.id() << ", type = " << std::hex << "0x" << static_cast<uint32_t>(arg.tensor_component_type()) << std::dec << "\n";
+                break;
+
+            default:
+                CKW_ASSERT(false);
+        }
+    }
+
+    std::cout << "\n====================\nCode:\n====================\n";
     const auto code = root_writer.generate_code();
     std::cout << code;
 
diff --git a/compute_kernel_writer/prototype/include/ckw/Kernel.h b/compute_kernel_writer/prototype/include/ckw/Kernel.h
index 527206f..3deb2ac 100644
--- a/compute_kernel_writer/prototype/include/ckw/Kernel.h
+++ b/compute_kernel_writer/prototype/include/ckw/Kernel.h
@@ -25,16 +25,20 @@
 #ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
 #define CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
 
+#include "ckw/KernelArgument.h"
 #include "ckw/OperandBase.h"
 #include "ckw/types/GpuTargetLanguage.h"
 
 #include <map>
 #include <memory>
 #include <string>
+#include <vector>
 
 namespace ckw
 {
 
+class TileOperand;
+
 namespace prototype
 {
 class GpuKernelWriterDataHolder;
@@ -57,11 +61,20 @@
     /** Get the name of the kernel function. */
     const std::string &name() const;
 
-    /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */
-    const ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands() const;
+    /** Get the list of kernel arguments. */
+    ::std::vector<KernelArgument> arguments() const;
 
-    /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */
-    ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands();
+    /** (Internal use only) Register the tile operand.
+     *
+     * @param operand The tile operand to be registered.
+     */
+    TileOperand &register_operand(::std::unique_ptr<TileOperand> operand);
+
+    /** (Internal use only) Register the tensor operand.
+     *
+     * @param operand The tensor operand to be registered.
+     */
+    TensorOperand &register_operand(::std::unique_ptr<TensorOperand> operand);
 
     /** (Internal use only) Get the implementation data. */
     prototype::GpuKernelWriterDataHolder *impl();
@@ -70,6 +83,7 @@
     ::std::string                                             _name;
     ::std::unique_ptr<prototype::GpuKernelWriterDataHolder>   _kernel;
     ::std::map<::std::string, ::std::unique_ptr<OperandBase>> _operands;
+    ::std::map<int32_t, TensorOperand *>                      _tensor_id_operands;
 };
 
 } // namespace ckw
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelArgument.h b/compute_kernel_writer/prototype/include/ckw/KernelArgument.h
new file mode 100644
index 0000000..af8bcde
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/KernelArgument.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
+#define CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
+
+#include "ckw/TensorInfo.h"
+#include <cstdint>
+
+namespace ckw
+{
+
+class TensorOperand;
+class TensorComponentOperand;
+
+/** A kernel argument which can be either a tensor storage or a tensor component. */
+class KernelArgument
+{
+public:
+    /** The type of kernel argument. */
+    enum class Type : int32_t
+    {
+        /** The argument that provides the read and/or write access to the tensor data.
+         *
+         * See @ref ckw::TensorStorageType for the list of supported storage types.
+         */
+        TensorStorage,
+
+        /** The argument that provides extra information about the tensor.
+         *
+         * See @ref ckw::TensorComponentType for the list of supported components.
+         */
+        TensorComponent,
+    };
+
+    /** Initialize a new instance of kernel argument class for a tensor storage argument.
+     *
+     * @param[in] tensor The tensor whose storage is exposed to kernel arguments.
+     */
+    KernelArgument(TensorOperand &tensor); // Non-explicit: Kernel::arguments() relies on the implicit conversion.
+
+    /** Initialize a new instance of kernel argument class for a tensor component argument.
+     *
+     * @param[in] tensor_component The tensor component to be exposed to kernel arguments.
+     */
+    KernelArgument(TensorComponentOperand &tensor_component); // Non-explicit: Kernel::arguments() relies on the implicit conversion.
+
+    /** Get the type of kernel argument. */
+    Type type() const;
+
+    /** Get the argument ID.
+     *
+     * For both tensor storage and tensor component arguments, this is the tensor info ID of the owning tensor.
+     */
+    int32_t id() const;
+
+    /** Get the type of tensor storage.
+     *
+     * This method can only be used for tensor storage arguments.
+     */
+    TensorStorageType tensor_storage_type() const;
+
+    /** Get the tensor component type.
+     *
+     * This method can only be used for tensor component arguments.
+     */
+    TensorComponentType tensor_component_type() const;
+
+private:
+    Type    _type; // Selects which member of _sub_id is valid.
+    int32_t _id;   // Tensor info ID of the owning tensor.
+
+    union SubId // Storage type or component type, depending on _type.
+    {
+        int32_t             unknown;
+        TensorStorageType   tensor_storage_type;
+        TensorComponentType tensor_component_type;
+    };
+
+    SubId _sub_id{ 0 }; // Zero-initializes "unknown"; the constructors then set the member matching _type.
+};
+
+} // namespace ckw
+
+#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
index 2bf443c..146fdac 100644
--- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
@@ -88,12 +88,13 @@
 
     /** Declare a tensor argument.
      *
-     * @param[in] name The name of the tensor.
-     * @param[in] info The tensor info.
+     * @param[in] name         The name of the tensor.
+     * @param[in] info         The tensor info.
+     * @param[in] storage_type The tensor storage type.
      *
      * @return The @ref TensorOperand object.
      */
-    TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info);
+    TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type = TensorStorageType::BufferUint8Ptr);
 
     /** Declare a compile-time constant scalar argument.
      *
@@ -117,10 +118,9 @@
     TileOperand &declare_tile(const std::string &name, TArgs &&...args)
     {
         const auto var_name = generate_variable_name(name);
-        auto       operand  = new TileOperand(var_name, ::std::forward<TArgs>(args)...);
-        register_operand(operand, true);
+        auto       operand  = std::make_unique<TileOperand>(var_name, ::std::forward<TArgs>(args)...);
 
-        return *operand;
+        return declare_tile_operand(std::move(operand));
     }
 
     // =============================================================================================
@@ -272,14 +272,11 @@
      */
     ::std::string generate_variable_name(const std::string &name) const;
 
-    /** Register the operand to the kernel.
+    /** Declare the tile operand.
      *
-     * The operand is uniquely owned by the kernel afterward.
-     *
-     * @param[in] operand   The operand to be registered.
-     * @param[in] declaring Whether the tile declaration is generated.
+     * @param[in] operand   The tile operand to be declared.
      */
-    void register_operand(OperandBase *operand, bool declaring);
+    TileOperand &declare_tile_operand(std::unique_ptr<TileOperand> operand);
 
 private:
     Kernel                                                *_kernel;
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
index 8eaa6ae..55f8101 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
@@ -67,7 +67,7 @@
  *  The data type is represented as an integer. The value of the integer value
  *  is assigned to retrieve the information through the @ref TensorComponentBitmask.
  */
-enum class TensorComponent : uint32_t
+enum class TensorComponentType : uint32_t
 {
     Unknown            = 0x00000000,
     OffsetFirstElement = 0x01000000,
@@ -88,7 +88,7 @@
 
 /** Compute Kernel Writer tensor storage. The tensor storage represents the type of tensor memory object.
  */
-enum class TensorStorage : uint32_t
+enum class TensorStorageType : uint32_t
 {
     Unknown            = 0x00000000,
     BufferUint8Ptr     = 0x01000000,
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
index 3a2509e..6d88932 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
@@ -48,10 +48,11 @@
 public:
     /** Initialize a new instance of @ref TensorOperand class.
      *
-     * @param[in] name       The name of the tensor.
-     * @param[in] info       The tensor info.
+     * @param[in] name         The name of the tensor.
+     * @param[in] info         The tensor info.
+     * @param[in] storage_type The tensor storage type.
      */
-    TensorOperand(const ::std::string &name, const TensorInfo &info);
+    TensorOperand(const ::std::string &name, const TensorInfo &info, TensorStorageType storage_type);
 
     /** No copy constructor. */
     TensorOperand(const TensorOperand &other) = delete;
@@ -71,6 +72,9 @@
     /** Get the tensor info. */
     TensorInfo &info();
 
+    /** Get the tensor storage type. */
+    TensorStorageType storage_type() const;
+
     /** Get the data type. */
     virtual DataType data_type() const override;
 
@@ -96,43 +100,44 @@
     TensorOperand &tile_sampler(const TensorTileSampler &value);
 
     /** Get the operand that contains the stride in y dimension of the tensor. */
-    TileOperand &stride1();
+    TensorComponentOperand &stride1();
 
     /** Get the operand that contains the stride in z dimension of the tensor. */
-    TileOperand &stride2();
+    TensorComponentOperand &stride2();
 
     /** Get the operand that contains the stride in w dimension of the tensor. */
-    TileOperand &stride3();
+    TensorComponentOperand &stride3();
 
     /** Get the operand that contains the stride in w dimension of the tensor. */
-    TileOperand &stride4();
+    TensorComponentOperand &stride4();
 
     /** Get the operand that contains the size of dimension 0 of the tensor. */
-    TileOperand &dim0();
+    TensorComponentOperand &dim0();
 
     /** Get the operand that contains the size of dimension 1 of the tensor. */
-    TileOperand &dim1();
+    TensorComponentOperand &dim1();
 
     /** Get the operand that contains the size of dimension 2 of the tensor. */
-    TileOperand &dim2();
+    TensorComponentOperand &dim2();
 
     /** Get the operand that contains the size of dimension 3 of the tensor. */
-    TileOperand &dim3();
+    TensorComponentOperand &dim3();
 
     /** Get the operand that contains the size of dimension 4 of the tensor. */
-    TileOperand &dim4();
+    TensorComponentOperand &dim4();
 
     /** Get the operand that contains the size of dimensions 1 and 2 collapsed. */
-    TileOperand &dim1_dim2();
+    TensorComponentOperand &dim1_dim2();
 
     /** Get the operand that contains the size of dimensions 1, 2 and 3 collapsed. */
-    TileOperand &dim1_dim2_dim3();
+    TensorComponentOperand &dim1_dim2_dim3();
 
     /** Get the operand that contains the offset in bytes to the first element. */
-    TileOperand &offset_first_element_in_bytes();
+    TensorComponentOperand &offset_first_element_in_bytes();
 
 private:
-    TensorInfo _info;
+    TensorInfo        _info;
+    TensorStorageType _storage_type;
 
     TileOperand      *_tile{ nullptr };
     TensorTileSampler _tile_sampler{};
@@ -161,10 +166,19 @@
 public:
     /** Initialize a new instance of @ref TensorComponentOperand class.
      *
-     * @param[in] name      The name of the operand.
+     * @param[in] tensor    The tensor operand.
      * @param[in] component The tensor info component.
      */
-    TensorComponentOperand(const ::std::string &name, TensorComponent component);
+    TensorComponentOperand(TensorOperand &tensor, TensorComponentType component);
+
+    /** Get the tensor operand. */
+    TensorOperand &tensor();
+
+    /** Get the tensor operand. */
+    const TensorOperand &tensor() const;
+
+    /** Get the tensor component. */
+    TensorComponentType component_type() const;
 
     /** (Internal use only) Create the implementation operand.
      *
@@ -173,7 +187,8 @@
     virtual prototype::Operand create_impl_operand(prototype::IGpuKernelWriter *writer) const override;
 
 private:
-    TensorComponent _component;
+    TensorOperand      &_tensor;
+    TensorComponentType _component;
 };
 
 } // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp
index 692d504..884b69a 100644
--- a/compute_kernel_writer/prototype/src/Kernel.cpp
+++ b/compute_kernel_writer/prototype/src/Kernel.cpp
@@ -23,6 +23,7 @@
  */
 
 #include "ckw/Kernel.h"
+#include "ckw/TensorOperand.h"
 #include "ckw/types/GpuTargetLanguage.h"
 #include "src/Prototype.h"
 
@@ -30,7 +31,7 @@
 {
 
 Kernel::Kernel(const char *name, GpuTargetLanguage language)
-    : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}
+    : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}, _tensor_id_operands{} // Both operand registries start empty.
 {
 }
 
@@ -43,14 +44,102 @@
     return _name;
 }
 
-const std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands() const
+std::vector<KernelArgument> Kernel::arguments() const
 {
-    return _operands;
+    std::vector<KernelArgument> arguments;
+
+    const auto impl_args = _kernel->arguments.tensor_argument_declarations(); // Tensor arguments from the prototype argument registry.
+
+    for(auto tensor_arg : impl_args)
+    {
+        auto tensor = _tensor_id_operands.at(tensor_arg->format().id); // std::map::at throws std::out_of_range for an unregistered tensor ID.
+        arguments.push_back(*tensor); // The storage argument precedes the tensor's component arguments.
+
+        for(auto component_arg : tensor_arg->component_declarations()) // NOTE(review): the accessors below are non-const; confirm they do not mutate kernel state.
+        {
+            switch(component_arg)
+            {
+                case TensorComponentType::OffsetFirstElement:
+                    arguments.push_back(tensor->offset_first_element_in_bytes());
+                    break;
+
+                case TensorComponentType::Stride1:
+                    arguments.push_back(tensor->stride1());
+                    break;
+
+                case TensorComponentType::Stride2:
+                    arguments.push_back(tensor->stride2());
+                    break;
+
+                case TensorComponentType::Stride3:
+                    arguments.push_back(tensor->stride3());
+                    break;
+
+                case TensorComponentType::Stride4:
+                    arguments.push_back(tensor->stride4());
+                    break;
+
+                case TensorComponentType::Dim0:
+                    arguments.push_back(tensor->dim0());
+                    break;
+
+                case TensorComponentType::Dim1:
+                    arguments.push_back(tensor->dim1());
+                    break;
+
+                case TensorComponentType::Dim2:
+                    arguments.push_back(tensor->dim2());
+                    break;
+
+                case TensorComponentType::Dim3:
+                    arguments.push_back(tensor->dim3());
+                    break;
+
+                case TensorComponentType::Dim4:
+                    arguments.push_back(tensor->dim4());
+                    break;
+
+                case TensorComponentType::Dim1xDim2:
+                    arguments.push_back(tensor->dim1_dim2());
+                    break;
+
+                case TensorComponentType::Dim1xDim2xDim3:
+                    arguments.push_back(tensor->dim1_dim2_dim3());
+                    break;
+
+                default:
+                    CKW_ASSERT(false); // Unknown or unsupported component type.
+            }
+        }
+    }
+
+    return arguments;
 }
 
-std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands()
+TileOperand &Kernel::register_operand(std::unique_ptr<TileOperand> operand)
 {
-    return _operands;
+    const auto &name = operand->name();
+    auto        ptr  = operand.get(); // Non-owning handle kept for the return value; ownership moves into _operands.
+
+    CKW_ASSERT(_operands.find(name) == _operands.end()); // Operand names must be unique within the kernel.
+    _operands[name] = std::move(operand);
+
+    return *ptr;
+}
+
+TensorOperand &Kernel::register_operand(std::unique_ptr<TensorOperand> operand)
+{
+    const auto  id   = operand->info().id();
+    const auto &name = operand->name();
+    auto        ptr  = operand.get();
+
+    CKW_ASSERT(_tensor_id_operands.find(id) == _tensor_id_operands.end()); // Each tensor info ID may be registered only once.
+    CKW_ASSERT(_operands.find(name) == _operands.end());                   // Operand names must be unique within the kernel.
+
+    _tensor_id_operands[id] = operand.get(); // Non-owning lookup by tensor info ID, used by arguments().
+    _operands[name]         = std::move(operand); // Owning storage.
+
+    return *ptr;
+}
 
 prototype::GpuKernelWriterDataHolder *Kernel::impl()
diff --git a/compute_kernel_writer/prototype/src/KernelArgument.cpp b/compute_kernel_writer/prototype/src/KernelArgument.cpp
new file mode 100644
index 0000000..2b4d7c8
--- /dev/null
+++ b/compute_kernel_writer/prototype/src/KernelArgument.cpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ckw/KernelArgument.h"
+#include "ckw/Error.h"
+#include "ckw/TensorOperand.h"
+
+namespace ckw
+{
+
+KernelArgument::KernelArgument(TensorOperand &tensor)
+    : _type(Type::TensorStorage), _id(tensor.info().id())
+{
+    _sub_id.tensor_storage_type = tensor.storage_type(); // Storage argument: record the storage type.
+}
+
+KernelArgument::KernelArgument(TensorComponentOperand &tensor_component)
+    : _type(Type::TensorComponent), _id(tensor_component.tensor().info().id())
+{
+    _sub_id.tensor_component_type = tensor_component.component_type(); // Component argument: record the component type.
+}
+
+KernelArgument::Type KernelArgument::type() const
+{
+    return _type;
+}
+
+int32_t KernelArgument::id() const
+{
+    return _id;
+}
+
+TensorStorageType KernelArgument::tensor_storage_type() const
+{
+    CKW_ASSERT(_type == Type::TensorStorage); // The union member is only meaningful for storage arguments.
+    return _sub_id.tensor_storage_type;
+}
+
+TensorComponentType KernelArgument::tensor_component_type() const
+{
+    CKW_ASSERT(_type == Type::TensorComponent); // The union member is only meaningful for component arguments.
+    return _sub_id.tensor_component_type;
+}
+
+} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 73458ef..1ac9ede 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -24,6 +24,7 @@
 
 #include "ckw/KernelWriter.h"
 #include "ckw/Error.h"
+#include "ckw/TensorInfo.h"
 #include "ckw/TensorOperand.h"
 #include "src/Prototype.h"
 
@@ -85,26 +86,24 @@
 // Tensor and tile declaration
 // =================================================================================================
 
-TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
+TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
 {
     const auto var_name = generate_variable_name(name);
 
     _impl->declare_argument(var_name, create_impl_tensor_info(info));
 
-    auto operand = new TensorOperand(var_name, info);
-    register_operand(operand, false);
+    auto &operand = _kernel->register_operand(std::make_unique<TensorOperand>(var_name, info, storage_type)); // The kernel takes ownership of the operand.
 
-    return *operand;
+    return operand;
 }
 
 TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
 {
     const auto var_name = generate_variable_name(name);
 
-    auto operand = new TileOperand(var_name, value);
-    register_operand(operand, false);
+    auto &operand = _kernel->register_operand(std::make_unique<TileOperand>(var_name, value)); // Registered only: unlike declare_tile_operand(), no tile declaration is emitted.
 
-    return *operand;
+    return operand;
 }
 
 std::string KernelWriter::generate_variable_name(const std::string &name) const
@@ -116,21 +115,21 @@
     return var_name.str();
 }
 
-void KernelWriter::register_operand(OperandBase *operand, bool declaring)
+TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> operand_ptr)
 {
-    const auto &name     = operand->name();
-    auto       &operands = _kernel->operands();
+    auto       &operand = _kernel->register_operand(std::move(operand_ptr)); // The kernel takes ownership of the operand.
+    const auto &name    = operand.name();
 
-    CKW_ASSERT(operands.find(name) == operands.end());
-    operands[name] = std::unique_ptr<OperandBase>(operand);
-
-    if(declaring && !operand->is_constant())
+    if(!operand.is_constant()) // Constant tiles get no tile declaration.
     {
-        const auto tile = reinterpret_cast<TileOperand *>(operand);
+        const auto &info = operand.tile_info();
 
-        const auto &info = tile->tile_info();
-        _impl->declare_tile(tile->name(), prototype::TileInfo(info.data_type(), info.width(), info.height()));
+        _impl->declare_tile(
+            name,
+            prototype::TileInfo(info.data_type(), info.width(), info.height()));
     }
+
+    return operand;
 }
 
 // =================================================================================================
@@ -143,7 +142,7 @@
         tensor.name(),
         prototype::GpuSampler{
             sampler.format(),
-            prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
+            prototype::to_gpu_tensor_storage(tensor.storage_type()),
             sampler.address_mode_x(),
             sampler.address_mode_y(),
             sampler.address_mode_z() });
@@ -164,7 +163,7 @@
         tensor.name(),
         prototype::GpuSampler{
             sampler.format(),
-            prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
+            prototype::to_gpu_tensor_storage(tensor.storage_type()),
             sampler.address_mode_x(),
             sampler.address_mode_y(),
             sampler.address_mode_z() });
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
index 18f284b..b9f1efa 100644
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ b/compute_kernel_writer/prototype/src/Prototype.h
@@ -561,7 +561,7 @@
     IndexMask = 0x0000000f,
 };
 
-enum class TensorComponentType : int32_t
+enum class TensorComponentGroup : int32_t
 {
     OffsetFirstElement = 0x00000100,
     Stride             = 0x00001000,
@@ -570,62 +570,39 @@
     Constant           = 0x01000000
 };
 
-enum class TensorComponent : int32_t
-{
-    Unknown            = 0x00000000,
-    OffsetFirstElement = 0x00000100,
-    Stride1            = 0x00001001,
-    Stride2            = 0x00001002,
-    Stride3            = 0x00001003,
-    Stride4            = 0x00001004,
-    Dim0               = 0x00010000,
-    Dim1               = 0x00010001,
-    Dim2               = 0x00010002,
-    Dim3               = 0x00010003,
-    Dim4               = 0x00010004,
-    C                  = 0x00010000, // Dim0
-    W                  = 0x00010001, // Dim1
-    H                  = 0x00010002, // Dim2
-    D                  = 0x00010003,
-    N                  = 0x00010004,
-    Dim1xDim2          = 0x00100021,
-    Dim1xDim2xDim3     = 0x00100321,
-    WxH                = 0x00100021,
-    WxHxD              = 0x00100321
-};
-
-inline std::string to_string(TensorComponent x)
+inline std::string to_string(TensorComponentType x)
 {
     switch(x)
     {
-        case TensorComponent::Unknown:
+        case TensorComponentType::Unknown:
             return "Unknown";
-        case TensorComponent::OffsetFirstElement:
+        case TensorComponentType::OffsetFirstElement:
             return "OffsetFirstElement";
-        case TensorComponent::Stride1:
+        case TensorComponentType::Stride1:
             return "Stride1";
-        case TensorComponent::Stride2:
+        case TensorComponentType::Stride2:
             return "Stride2";
-        case TensorComponent::Stride3:
+        case TensorComponentType::Stride3:
             return "Stride3";
-        case TensorComponent::Stride4:
+        case TensorComponentType::Stride4:
             return "Stride4";
-        case TensorComponent::Dim0:
+        case TensorComponentType::Dim0:
             return "Dim0";
-        case TensorComponent::Dim1:
+        case TensorComponentType::Dim1:
             return "Dim1";
-        case TensorComponent::Dim2:
+        case TensorComponentType::Dim2:
             return "Dim2";
-        case TensorComponent::Dim3:
+        case TensorComponentType::Dim3:
             return "Dim3";
-        case TensorComponent::Dim4:
+        case TensorComponentType::Dim4:
             return "Dim4";
-        case TensorComponent::Dim1xDim2:
+        case TensorComponentType::Dim1xDim2:
             return "Dim1xDim2";
-        case TensorComponent::Dim1xDim2xDim3:
+        case TensorComponentType::Dim1xDim2xDim3:
             return "Dim1xDim2xDim3";
         default:
             assert(false);
+            return ""; // Only reached when assert() is compiled out (NDEBUG); keeps every path returning.
     }
 }
 
@@ -640,7 +617,7 @@
      *
      * @return  the tensor component as a string
      */
-    virtual std::string component(TensorComponent x) = 0;
+    virtual std::string component(TensorComponentType x) = 0;
 
     /** Method to get the tensor component type declaration as a string
      *
@@ -658,7 +635,7 @@
      *
      * @return a vector containing the tensor component declarations
      */
-    virtual std::vector<TensorComponent> component_declarations() const = 0;
+    virtual std::vector<TensorComponentType> component_declarations() const = 0;
 
     /** Method to get the name of the tensor argument.
      *
@@ -693,6 +670,50 @@
     Image3dWriteOnly = 0x0031
 };
 
+inline GpuTensorStorage to_gpu_tensor_storage(TensorStorageType s) // Public ckw storage type -> prototype GpuTensorStorage.
+{
+    switch(s)
+    {
+        case TensorStorageType::Unknown:
+            return GpuTensorStorage::Unknown;
+
+        case TensorStorageType::BufferUint8Ptr:
+            return GpuTensorStorage::BufferUint8Ptr;
+
+        case TensorStorageType::Texture2dReadOnly:
+            return GpuTensorStorage::Image2dReadOnly; // ckw "Texture2d" maps to prototype "Image2d".
+
+        case TensorStorageType::Texture2dWriteOnly:
+            return GpuTensorStorage::Image2dWriteOnly;
+
+        default:
+            assert(false); // Storage types without a mapping here are unsupported.
+            return GpuTensorStorage::Unknown;
+    }
+}
+
+inline TensorStorageType to_tensor_storage(GpuTensorStorage s) // Inverse of to_gpu_tensor_storage() for the four mapped values.
+{
+    switch(s)
+    {
+        case GpuTensorStorage::Unknown:
+            return TensorStorageType::Unknown;
+
+        case GpuTensorStorage::BufferUint8Ptr:
+            return TensorStorageType::BufferUint8Ptr;
+
+        case GpuTensorStorage::Image2dReadOnly:
+            return TensorStorageType::Texture2dReadOnly;
+
+        case GpuTensorStorage::Image2dWriteOnly:
+            return TensorStorageType::Texture2dWriteOnly;
+
+        default:
+            assert(false); // Unmapped prototype storages (e.g. Image3dWriteOnly) are unsupported.
+            return TensorStorageType::Unknown;
+    }
+}
+
 class IGpuTensorArgument : public ITensorArgument
 {
 public:
@@ -732,9 +753,9 @@
     }
 
     // Methods to override
-    std::string component(TensorComponent x) override
+    std::string component(TensorComponentType x) override
     {
-        if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentType::Constant)))
+        if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::Constant)))
         {
             int32_t idx = static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentIndex::IndexMask);
             return std::to_string(idx - 1);
@@ -742,19 +763,19 @@
 
         if(_return_by_value_when_possible)
         {
-            if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentType::Dimension)))
+            if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::Dimension)))
             {
                 int32_t idx = static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentIndex::IndexMask);
                 return std::to_string(_format.shape[idx]);
             }
 
-            if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentType::FoldedDimension)))
+            if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::FoldedDimension)))
             {
                 switch(x)
                 {
-                    case TensorComponent::Dim1xDim2:
+                    case TensorComponentType::Dim1xDim2:
                         return std::to_string(_format.shape[1] * _format.shape[2]);
-                    case TensorComponent::Dim1xDim2xDim3:
+                    case TensorComponentType::Dim1xDim2xDim3:
-                        return std::to_string(_format.shape[1] * _format.shape[2] * _format.shape[2]);
+                        return std::to_string(_format.shape[1] * _format.shape[2] * _format.shape[3]); // Product of dims 1, 2 and 3
                     default:
                         std::cout << "Unsupported folded dimension" << std::endl;
@@ -817,7 +838,7 @@
         return _storage_required;
     }
 
-    std::vector<TensorComponent> component_declarations() const override
+    std::vector<TensorComponentType> component_declarations() const override // Returns the stored _components_required list
     {
         return _components_required;
     }
@@ -845,31 +866,31 @@
         return var_name;
     }
 
-    std::string build_component_name(TensorComponent x) const
+    std::string build_component_name(TensorComponentType x) const
     {
         std::string var_name = _basename;
 
         switch(x)
         {
-            case TensorComponent::OffsetFirstElement:
+            case TensorComponentType::OffsetFirstElement:
                 return var_name + "_offset_first_element";
-            case TensorComponent::Stride1:
+            case TensorComponentType::Stride1:
                 return var_name + "_stride1";
-            case TensorComponent::Stride2:
+            case TensorComponentType::Stride2:
                 return var_name + "_stride2";
-            case TensorComponent::Stride3:
+            case TensorComponentType::Stride3:
                 return var_name + "_stride3";
-            case TensorComponent::Dim0:
+            case TensorComponentType::Dim0:
                 return var_name + "_dim0";
-            case TensorComponent::Dim1:
+            case TensorComponentType::Dim1:
                 return var_name + "_dim1";
-            case TensorComponent::Dim2:
+            case TensorComponentType::Dim2:
                 return var_name + "_dim2";
-            case TensorComponent::Dim3:
+            case TensorComponentType::Dim3:
                 return var_name + "_dim3";
-            case TensorComponent::Dim1xDim2:
+            case TensorComponentType::Dim1xDim2:
                 return var_name + "_dim1xdim2";
-            case TensorComponent::Dim1xDim2xDim3:
+            case TensorComponentType::Dim1xDim2xDim3:
                 return var_name + "_dim1xdim2xdim3";
             default:
                 std::cout << "Unsupported component" << std::endl;
@@ -881,7 +902,7 @@
 
     bool                          _return_by_value_when_possible{ false };
     std::vector<GpuTensorStorage> _storage_required{};
-    std::vector<TensorComponent>  _components_required{};
+    std::vector<TensorComponentType>  _components_required{};
 };
 
 /**
@@ -1745,15 +1766,7 @@
     ScalarTileCoord _coord{};
 };
 
-enum class GpuSamplerTensorStorage : int32_t
-{
-    Unknown          = static_cast<int32_t>(GpuTensorStorage::Unknown),
-    BufferUint8Ptr   = static_cast<int32_t>(GpuTensorStorage::BufferUint8Ptr),
-    Image2dReadOnly  = static_cast<int32_t>(GpuTensorStorage::Image2dReadOnly),
-    Image2dWriteOnly = static_cast<int32_t>(GpuTensorStorage::Image2dWriteOnly),
-    Image3dReadOnly  = static_cast<int32_t>(GpuTensorStorage::Image3dReadOnly),
-    Image3dWriteOnly = static_cast<int32_t>(GpuTensorStorage::Image2dWriteOnly),
-};
+using GpuSamplerTensorStorage = GpuTensorStorage; // Sampler code reuses GpuTensorStorage directly instead of a mirrored enum
 
 struct GpuSampler
 {
@@ -2098,37 +2111,37 @@
         return static_cast<DataType>(static_cast<int32_t>(x) & 0x00ff);
     }
 
-    TensorComponent to_tensor_component(OperandType x)
+    TensorComponentType to_tensor_component(OperandType x) // Tensor operand type -> component type; other operand types assert and yield Unknown
     {
         switch(x)
         {
             case OperandType::TensorDim0:
-                return TensorComponent::Dim0;
+                return TensorComponentType::Dim0;
             case OperandType::TensorDim1:
-                return TensorComponent::Dim1;
+                return TensorComponentType::Dim1;
             case OperandType::TensorDim2:
-                return TensorComponent::Dim2;
+                return TensorComponentType::Dim2;
             case OperandType::TensorDim3:
-                return TensorComponent::Dim3;
+                return TensorComponentType::Dim3;
             case OperandType::TensorDim4:
-                return TensorComponent::Dim4;
+                return TensorComponentType::Dim4;
             case OperandType::TensorStride1:
-                return TensorComponent::Stride1;
+                return TensorComponentType::Stride1;
             case OperandType::TensorStride2:
-                return TensorComponent::Stride2;
+                return TensorComponentType::Stride2;
             case OperandType::TensorStride3:
-                return TensorComponent::Stride3;
+                return TensorComponentType::Stride3;
             case OperandType::TensorStride4:
-                return TensorComponent::Stride4;
+                return TensorComponentType::Stride4;
             case OperandType::TensorDim1xDim2:
-                return TensorComponent::Dim1xDim2;
+                return TensorComponentType::Dim1xDim2;
             case OperandType::TensorDim1xDim2xDim3:
-                return TensorComponent::Dim1xDim2xDim3;
+                return TensorComponentType::Dim1xDim2xDim3;
             case OperandType::TensorDataOffset:
-                return TensorComponent::OffsetFirstElement;
+                return TensorComponentType::OffsetFirstElement;
             default:
                 assert(false);
-                return TensorComponent::Unknown;
+                return TensorComponentType::Unknown;
         }
     }
 
@@ -2174,7 +2187,7 @@
     // Dispatch stage
     GpuOutputSampler                                  output_sampler{};       // GpuOutputSampler, required for the dispatch stage
     std::vector<std::pair<int32_t, GpuTensorStorage>> list_tensor_storages;   // List of tensor storages, required for the dispatch stage
-    std::vector<std::pair<int32_t, TensorComponent>>  list_tensor_components; // List of tensor components (width, stride,..), required for the dispatch stage)
+    std::vector<std::pair<int32_t, TensorComponentType>>  list_tensor_components; // List of tensor components (width, stride,..), required for the dispatch stage)
 };
 
 // This function should produce an object with the source
@@ -2251,7 +2264,7 @@
         {
             case TensorSamplerFormat::C_WH_1:
             case TensorSamplerFormat::C_W_H:
-                return _tensor->component(TensorComponent::C);
+                return _tensor->component(TensorComponentType::Dim0);
             default:
                 std::cout << "Unsupported tensor format" << std::endl;
                 assert(false);
@@ -2265,9 +2278,9 @@
         switch(format)
         {
             case TensorSamplerFormat::C_WH_1:
-                return _tensor->component(TensorComponent::WxH);
+                return _tensor->component(TensorComponentType::Dim1xDim2);
             case TensorSamplerFormat::C_W_H:
-                return _tensor->component(TensorComponent::W);
+                return _tensor->component(TensorComponentType::Dim1);
             default:
                 std::cout << "Unsupported tensor format" << std::endl;
                 assert(false);
@@ -2283,7 +2296,7 @@
             case TensorSamplerFormat::C_WH_1:
                 return "1";
             case TensorSamplerFormat::C_W_H:
-                return _tensor->component(TensorComponent::H);
+                return _tensor->component(TensorComponentType::Dim2);
             default:
                 std::cout << "Unsupported tensor format" << std::endl;
                 assert(false);
@@ -2298,7 +2311,7 @@
         {
             case TensorSamplerFormat::C_WH_1:
             case TensorSamplerFormat::C_W_H:
-                return _tensor->component(TensorComponent::Stride1);
+                return _tensor->component(TensorComponentType::Stride1);
             default:
                 std::cout << "Unsupported tensor format" << std::endl;
                 assert(false);
@@ -2314,7 +2327,7 @@
             case TensorSamplerFormat::C_WH_1:
                 return "0";
             case TensorSamplerFormat::C_W_H:
-                return _tensor->component(TensorComponent::Stride2);
+                return _tensor->component(TensorComponentType::Stride2);
             default:
                 std::cout << "Unsupported tensor format" << std::endl;
                 assert(false);
@@ -2329,7 +2342,7 @@
         {
             case TensorSamplerFormat::C_WH_1:
             case TensorSamplerFormat::C_W_H:
-                return _tensor->component(TensorComponent::Stride3);
+                return _tensor->component(TensorComponentType::Stride3);
             default:
                 std::cout << "Unsupported tensor format" << std::endl;
                 assert(false);
@@ -3941,9 +3954,9 @@
         assert(x_off->format().dt == DataType::Int32);
         assert(y_off->format().dt == DataType::Int32);
 
-        const std::string width  = tensor->component(TensorComponent::W);
-        const std::string height = tensor->component(TensorComponent::H);
-        const std::string wxh    = tensor->component(TensorComponent::WxH);
+        const std::string width  = tensor->component(TensorComponentType::Dim1);
+        const std::string height = tensor->component(TensorComponentType::Dim2);
+        const std::string wxh    = tensor->component(TensorComponentType::Dim1xDim2);
         /*
         int x_s;
         int y_s;
diff --git a/compute_kernel_writer/prototype/src/TensorOperand.cpp b/compute_kernel_writer/prototype/src/TensorOperand.cpp
index 00ecc38..c6725d3 100644
--- a/compute_kernel_writer/prototype/src/TensorOperand.cpp
+++ b/compute_kernel_writer/prototype/src/TensorOperand.cpp
@@ -25,6 +25,7 @@
 #include "ckw/TensorOperand.h"
 #include "ckw/Error.h"
 #include "ckw/Kernel.h"
+#include "ckw/TensorInfo.h"
 #include "ckw/TileOperand.h"
 #include "src/Prototype.h"
 
@@ -34,11 +35,11 @@
 namespace
 {
 
-inline TensorComponentOperand &get_or_create_component(std::unique_ptr<TensorComponentOperand> &ptr, const ::std::string &name, TensorComponent component)
+TensorComponentOperand &get_or_create_component(TensorOperand &tensor, std::unique_ptr<TensorComponentOperand> &ptr, TensorComponentType component)
 {
     if(ptr == nullptr)
     {
-        ptr = std::make_unique<TensorComponentOperand>(name, component);
+        ptr = std::make_unique<TensorComponentOperand>(tensor, component);
     }
 
     return *ptr;
@@ -50,8 +51,8 @@
 // TensorOperand
 // =================================================================================================
 
-TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info)
-    : OperandBase(name), _info(info)
+TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info, TensorStorageType storage_type) // storage_type is later reported by storage_type()
+    : OperandBase(name), _info(info), _storage_type(storage_type)
 {
 }
 
@@ -71,6 +72,11 @@
     return _info;
 }
 
+TensorStorageType TensorOperand::storage_type() const // Storage type supplied to the constructor
+{
+    return _storage_type;
+}
+
 DataType TensorOperand::data_type() const
 {
     return _info.data_type();
@@ -113,75 +119,90 @@
     return *this;
 }
 
-TileOperand &TensorOperand::stride1()
+TensorComponentOperand &TensorOperand::stride1() // Created on first call, then cached in _stride1
 {
-    return get_or_create_component(_stride1, name(), TensorComponent::Stride1);
+    return get_or_create_component(*this, _stride1, TensorComponentType::Stride1);
 }
 
-TileOperand &TensorOperand::stride2()
+TensorComponentOperand &TensorOperand::stride2() // Created on first call, then cached in _stride2
 {
-    return get_or_create_component(_stride2, name(), TensorComponent::Stride2);
+    return get_or_create_component(*this, _stride2, TensorComponentType::Stride2);
 }
 
-TileOperand &TensorOperand::stride3()
+TensorComponentOperand &TensorOperand::stride3() // Created on first call, then cached in _stride3
 {
-    return get_or_create_component(_stride3, name(), TensorComponent::Stride3);
+    return get_or_create_component(*this, _stride3, TensorComponentType::Stride3);
 }
 
-TileOperand &TensorOperand::stride4()
+TensorComponentOperand &TensorOperand::stride4() // Created on first call, then cached in _stride4
 {
-    return get_or_create_component(_stride4, name(), TensorComponent::Stride4);
+    return get_or_create_component(*this, _stride4, TensorComponentType::Stride4);
 }
 
-TileOperand &TensorOperand::dim0()
+TensorComponentOperand &TensorOperand::dim0() // Created on first call, then cached in _dim0
 {
-    return get_or_create_component(_dim0, name(), TensorComponent::Dim0);
+    return get_or_create_component(*this, _dim0, TensorComponentType::Dim0);
 }
 
-TileOperand &TensorOperand::dim1()
+TensorComponentOperand &TensorOperand::dim1() // Created on first call, then cached in _dim1
 {
-    return get_or_create_component(_dim1, name(), TensorComponent::Dim1);
+    return get_or_create_component(*this, _dim1, TensorComponentType::Dim1);
 }
 
-TileOperand &TensorOperand::dim2()
+TensorComponentOperand &TensorOperand::dim2() // Created on first call, then cached in _dim2
 {
-    return get_or_create_component(_dim2, name(), TensorComponent::Dim2);
+    return get_or_create_component(*this, _dim2, TensorComponentType::Dim2);
 }
 
-TileOperand &TensorOperand::dim3()
+TensorComponentOperand &TensorOperand::dim3() // Created on first call, then cached in _dim3
 {
-    return get_or_create_component(_dim3, name(), TensorComponent::Dim3);
+    return get_or_create_component(*this, _dim3, TensorComponentType::Dim3);
 }
 
-TileOperand &TensorOperand::dim4()
+TensorComponentOperand &TensorOperand::dim4() // Created on first call, then cached in _dim4
 {
-    return get_or_create_component(_dim4, name(), TensorComponent::Dim4);
+    return get_or_create_component(*this, _dim4, TensorComponentType::Dim4);
 }
 
-TileOperand &TensorOperand::dim1_dim2()
+TensorComponentOperand &TensorOperand::dim1_dim2() // Created on first call, then cached in _dim1_dim2
 {
-    return get_or_create_component(_dim1_dim2, name(), TensorComponent::Dim1xDim2);
+    return get_or_create_component(*this, _dim1_dim2, TensorComponentType::Dim1xDim2);
 }
 
-TileOperand &TensorOperand::dim1_dim2_dim3()
+TensorComponentOperand &TensorOperand::dim1_dim2_dim3() // Created on first call, then cached in _dim1_dim2_dim3
 {
-    return get_or_create_component(_dim1_dim2_dim3, name(), TensorComponent::Dim1xDim2xDim3);
+    return get_or_create_component(*this, _dim1_dim2_dim3, TensorComponentType::Dim1xDim2xDim3);
 }
 
-TileOperand &TensorOperand::offset_first_element_in_bytes()
+TensorComponentOperand &TensorOperand::offset_first_element_in_bytes() // Created on first call, then cached in _offset_first_element_in_bytes
 {
-    return get_or_create_component(_offset_first_element_in_bytes, name(), TensorComponent::OffsetFirstElement);
+    return get_or_create_component(*this, _offset_first_element_in_bytes, TensorComponentType::OffsetFirstElement);
 }
 
 // =================================================================================================
 // TensorComponentOperand
 // =================================================================================================
 
-TensorComponentOperand::TensorComponentOperand(const ::std::string &name, TensorComponent component)
-    : TileOperand(name, DataType::Int32), _component(component)
+TensorComponentOperand::TensorComponentOperand(TensorOperand &tensor, TensorComponentType component) // Int32 tile operand named after the tensor
+    : TileOperand(tensor.name(), DataType::Int32), _tensor(tensor), _component(component)
 {
 }
 
+TensorOperand &TensorComponentOperand::tensor() // Tensor this component was created for
+{
+    return _tensor;
+}
+
+const TensorOperand &TensorComponentOperand::tensor() const // Read-only access to the same tensor
+{
+    return _tensor;
+}
+
+TensorComponentType TensorComponentOperand::component_type() const // Component (stride, dimension or offset) selected at construction
+{
+    return _component;
+}
+
 prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
 {
     CKW_UNUSED(writer);
@@ -189,51 +210,51 @@
 
     switch(_component)
     {
-        case TensorComponent::OffsetFirstElement:
+        case TensorComponentType::OffsetFirstElement:
             type = prototype::OperandType::TensorDataOffset;
             break;
 
-        case TensorComponent::Stride1:
+        case TensorComponentType::Stride1:
             type = prototype::OperandType::TensorStride1;
             break;
 
-        case TensorComponent::Stride2:
+        case TensorComponentType::Stride2:
             type = prototype::OperandType::TensorStride2;
             break;
 
-        case TensorComponent::Stride3:
+        case TensorComponentType::Stride3:
             type = prototype::OperandType::TensorStride3;
             break;
 
-        case TensorComponent::Stride4:
+        case TensorComponentType::Stride4:
             type = prototype::OperandType::TensorStride4;
             break;
 
-        case TensorComponent::Dim0:
+        case TensorComponentType::Dim0:
             type = prototype::OperandType::TensorDim0;
             break;
 
-        case TensorComponent::Dim1:
+        case TensorComponentType::Dim1:
             type = prototype::OperandType::TensorDim1;
             break;
 
-        case TensorComponent::Dim2:
+        case TensorComponentType::Dim2:
             type = prototype::OperandType::TensorDim2;
             break;
 
-        case TensorComponent::Dim3:
+        case TensorComponentType::Dim3:
             type = prototype::OperandType::TensorDim3;
             break;
 
-        case TensorComponent::Dim4:
+        case TensorComponentType::Dim4:
             type = prototype::OperandType::TensorDim4;
             break;
 
-        case TensorComponent::Dim1xDim2:
+        case TensorComponentType::Dim1xDim2:
             type = prototype::OperandType::TensorDim1xDim2;
             break;
 
-        case TensorComponent::Dim1xDim2xDim3:
+        case TensorComponentType::Dim1xDim2xDim3:
             type = prototype::OperandType::TensorDim1xDim2xDim3;
             break;