Port operations to CKW prototype

Splits the monolithic ckw/Types.h header into dedicated headers under
ckw/types/ (DataType, Operators, Functions, ConvertPolicy,
GpuTargetLanguage, TensorSamplerTypes), renames create_tensor_argument
and create_tile_argument to declare_tensor_argument and
declare_tile_argument, and ports new operations to the prototype
KernelWriter: casts with ConvertPolicy, unary expressions,
unary/binary/ternary elementwise functions, if/else-if/else statements,
for-loops and return.

Resolves: COMPMID-6334
Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com>
Change-Id: I500d30f09daec4087eb3e7aecd1de77dc8fd53b4
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9828
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/compute_kernel_writer/prototype/examples/add_exp_store.cpp b/compute_kernel_writer/prototype/examples/add_exp_store.cpp
index a9be049..9529268 100644
--- a/compute_kernel_writer/prototype/examples/add_exp_store.cpp
+++ b/compute_kernel_writer/prototype/examples/add_exp_store.cpp
@@ -27,7 +27,6 @@
 #include "ckw/TensorOperand.h"
 #include "ckw/TensorTileSampler.h"
 #include "ckw/TileOperand.h"
-#include "ckw/Types.h"
 
 #include "common/ExampleComponentArgument.h"
 #include "common/ExampleKernelWriter.h"
@@ -110,7 +109,7 @@
     auto &dst_tile = dst->tile();
 
     // Perform the operation.
-    writer->op_binary_expression(dst_tile, lhs_tile, rhs_tile, BinaryOp::Add);
+    writer->op_binary_expression(dst_tile, lhs_tile, BinaryOp::Add, rhs_tile);
 }
 
 void op_exp(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
@@ -138,7 +137,7 @@
     auto &dst_tile = dst->tile();
 
     // Perform the operation.
-    writer->op_scalar_function(dst_tile, src_tile, ScalarUnaryFunction::Exp);
+    writer->op_unary_elementwise_function(dst_tile, UnaryFunction::Exp, src_tile);
 }
 
 void op_store(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
@@ -164,9 +163,9 @@
     const TensorInfo src1_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 1);
     const TensorInfo dst_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 2);
 
-    ExampleComponentArgument src0(writer->create_tensor_argument("src0", src0_info));
-    ExampleComponentArgument src1(writer->create_tensor_argument("src1", src1_info));
-    ExampleComponentArgument dst(writer->create_tensor_argument("dst", dst_info));
+    ExampleComponentArgument src0(writer->declare_tensor_argument("src0", src0_info));
+    ExampleComponentArgument src1(writer->declare_tensor_argument("src1", src1_info));
+    ExampleComponentArgument dst(writer->declare_tensor_argument("dst", dst_info));
 
     ExampleComponentArgument ans;
 
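
A minimal sketch, not part of this patch, of the updated calling convention
used above: the operator/function enum now sits between the destination and
the remaining source operands, and the type headers now live under ckw/types/:

    #include "ckw/types/Functions.h"
    #include "ckw/types/Operators.h"

    // dst_tile, lhs_tile, rhs_tile and src_tile as obtained in op_add/op_exp above.
    writer->op_binary_expression(dst_tile, lhs_tile, BinaryOp::Add, rhs_tile);     // dst = lhs + rhs
    writer->op_unary_elementwise_function(dst_tile, UnaryFunction::Exp, src_tile); // dst = exp(src)
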
diff --git a/compute_kernel_writer/prototype/include/ckw/Kernel.h b/compute_kernel_writer/prototype/include/ckw/Kernel.h
index 57a8a40..527206f 100644
--- a/compute_kernel_writer/prototype/include/ckw/Kernel.h
+++ b/compute_kernel_writer/prototype/include/ckw/Kernel.h
@@ -26,7 +26,7 @@
 #define CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
 
 #include "ckw/OperandBase.h"
-#include "ckw/Types.h"
+#include "ckw/types/GpuTargetLanguage.h"
 
 #include <map>
 #include <memory>
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
index 3b15391..2bf443c 100644
--- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
@@ -30,6 +30,9 @@
 #include "ckw/TensorOperand.h"
 #include "ckw/TileInfo.h"
 #include "ckw/TileOperand.h"
+#include "ckw/types/ConvertPolicy.h"
+#include "ckw/types/Functions.h"
+#include "ckw/types/Operators.h"
 
 #include <memory>
 
@@ -83,23 +86,23 @@
     // Tensor and tile declaration
     // =============================================================================================
 
-    /** Define a tensor argument.
+    /** Declare a tensor argument.
      *
      * @param[in] name The name of the tensor.
      * @param[in] info The tensor info.
      *
      * @return The @ref TensorOperand object.
      */
-    TensorOperand &create_tensor_argument(const char *name, const TensorInfo &info);
+    TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info);
 
-    /** Define a compile-time constant scalar argument.
+    /** Declare a compile-time constant scalar argument.
      *
      * @param[in] name  The name of the tile.
      * @param[in] value The value of the tile.
      *
      * @return The @ref TileOperand object.
      */
-    TileOperand &create_tile_argument(const char *name, int32_t value);
+    TileOperand &declare_tile_argument(const std::string &name, int32_t value);
 
     /** Declare a new tile.
      *
@@ -111,7 +114,7 @@
      * @return The @ref TileOperand object.
      */
     template <typename... TArgs>
-    TileOperand &declare_tile(const char *name, TArgs &&...args)
+    TileOperand &declare_tile(const std::string &name, TArgs &&...args)
     {
         const auto var_name = generate_variable_name(name);
         auto       operand  = new TileOperand(var_name, ::std::forward<TArgs>(args)...);
@@ -144,29 +147,103 @@
     // Data processing
     // =============================================================================================
 
-    /** Write assignment: `<dst> = <src>`.
+    /** Write assignment: `<dst> = <src>;`.
      *
-     * @param[in] dst The destination tile.
-     * @param[in] src The source tile.
+     * @param[out] dst The destination tile.
+     * @param[in]  src The source tile.
      */
-    void op_assign(TileOperand &dst, const TileOperand &src);
+    void op_assign(const TileOperand &dst, const TileOperand &src);
 
-    /** Write binary expression: `<dst> = <lhs> <op> <rhs>`.
+    /** Write the cast: `<dst> = convert_<dst.type><_sat>(<src>);`.
      *
-     * @param[in] dst The destination tile.
-     * @param[in] lhs The LHS operand.
-     * @param[in] rhs The RHS operand.
-     * @param[in] op  The binary operator.
+     * @param[out] dst      The destination tile.
+     * @param[in]  src      The source tile.
+     * @param[in]  policy   The policy governing the behavior of the cast.
      */
-    void op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op);
+    void op_cast_expression(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy);
 
-    /** Write function applied to scalar value: `<dst> = <func>(<src>)`.
+    /** Write the unary expression: `<dst> = <op> <src>;`.
      *
-     * @param[in] dst  The destination tile.
-     * @param[in] src  The source tile.
-     * @param[in] func The function to be applied to the source tile.
+     * @param[out]  dst The destination tile.
+     * @param[in]   op  The unary operator.
+     * @param[in]   src The source tile.
      */
-    void op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction func);
+    void op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src);
+
+    /** Write binary expression: `<dst> = <lhs> <op> <rhs>;`.
+     *
+     * @param[out] dst  The destination tile.
+     * @param[in]  lhs  The LHS tile.
+     * @param[in]  op   The binary operator.
+     * @param[in]  rhs  The RHS tile.
+     */
+    void op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs);
+
+    /** Write the unary elementwise function: `<dst> = <func>(<src>);`.
+     *
+     * @param[out] dst  The destination tile.
+     * @param[in]  func The function to be applied to the source tile.
+     * @param[in]  src  The source tile.
+     */
+    void op_unary_elementwise_function(const TileOperand &dst, UnaryFunction func, const TileOperand &src);
+
+    /** Write the binary elementwise function: `<dst> = <func>(<first>, <second>);`.
+     *
+     * @param[out] dst      The destination tile.
+     * @param[in]  func     The function to be applied to the source tiles.
+     * @param[in]  first    The first argument tile.
+     * @param[in]  second   The second argument tile.
+     */
+    void op_binary_elementwise_function(const TileOperand &dst, BinaryFunction func, const TileOperand &first, const TileOperand &second);
+
+    /** Write the ternary elementwise function: `<dst> = <func>(<first>, <second>, <third>);`.
+     *
+     * @param[out] dst      The destination tile.
+     * @param[in]  func     The function to be applied to the source tiles.
+     * @param[in]  first    The first argument tile.
+     * @param[in]  second   The second argument tile.
+     * @param[in]  third    The third argument tile.
+     */
+    void op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction func, const TileOperand &first, const TileOperand &second, const TileOperand &third);
+
+    /** Write if-statement: `if(<lhs> <op> <rhs>) { <body> }`.
+     *
+     * @param[in] lhs   The LHS tile of the condition.
+     * @param[in] op    The relational binary operator.
+     * @param[in] rhs   The RHS tile of the condition.
+     * @param[in] body  The body of the if-statement.
+     */
+    void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
+
+    /** Write else-if-statement: `else if(<lhs> <op> <rhs>) { <body> }`.
+     *
+     * @param[in] lhs   The LHS tile of the condition.
+     * @param[in] op    The relational binary operator.
+     * @param[in] rhs   The RHS tile of the condition.
+     * @param[in] body  The body of the else-if-statement.
+     */
+    void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
+
+    /** Write an else-statement: `else { <body> }`.
+     *
+     * @param[in] body The body of the else-statement.
+     */
+    void op_else(const std::function<void()> &body);
+
+    /** Write a for-loop: `for(; <var> <cond_op> <cond_value>; <var> <update_op> <update_value>) { <body> }`.
+     *
+     * @param[in] var_name          The name of the loop variable used in the condition.
+     * @param[in] cond_op           The relational binary operator used in the condition.
+     * @param[in] cond_value_name   The value which the loop variable is compared against.
+     * @param[in] update_op         The assignment operator used to update the loop variable.
+     * @param[in] update_value_name The value the loop variable is updated with at every iteration.
+     * @param[in] body              The body of the for-loop.
+     */
+    void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body);
+
+    /** Write the return statement: `return;`
+     */
+    void op_return();
 
     // =============================================================================================
     // Misc
@@ -174,8 +251,8 @@
 
     /** Set `dst` to the global ID of dimension `dim`.
      *
-     * @param[in] dst The tile to be written to.
-     * @param[in] dim The global ID dimension.
+     * @param[out] dst The tile to be written to.
+     * @param[in]  dim The global ID dimension.
      */
     void op_get_global_id(TileOperand &dst, int32_t dim);
 
@@ -193,7 +270,7 @@
      *
      * @return The full variable name.
      */
-    ::std::string generate_variable_name(const char *name) const;
+    ::std::string generate_variable_name(const std::string &name) const;
 
     /** Register the operand to the kernel.
      *
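
A hedged usage sketch of the control-flow API declared above, assuming a
`KernelWriter &writer`; the tile names and TileInfo arguments are illustrative
(op_if expects scalar condition tiles):

    // Hypothetical scalar Int32 tiles.
    TileOperand &x     = writer.declare_tile("x", TileInfo(DataType::Int32));
    TileOperand &limit = writer.declare_tile("limit", TileInfo(DataType::Int32));
    TileOperand &step  = writer.declare_tile("step", TileInfo(DataType::Int32));

    // if(x == limit) { return; } else { ... }
    writer.op_if(x, BinaryOp::Equal, limit, [&]()
    {
        writer.op_return();
    });
    writer.op_else([&]()
    {
        // ...
    });

    // for(; x < limit; x += step) { ... }
    writer.op_for_loop(x, BinaryOp::Less, limit, AssignmentOp::Increment, step, [&]()
    {
        // loop body
    });
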
diff --git a/compute_kernel_writer/prototype/include/ckw/OperandBase.h b/compute_kernel_writer/prototype/include/ckw/OperandBase.h
index a9e313f..06d9f82 100644
--- a/compute_kernel_writer/prototype/include/ckw/OperandBase.h
+++ b/compute_kernel_writer/prototype/include/ckw/OperandBase.h
@@ -25,7 +25,7 @@
 #ifndef CKW_PROTOTYPE_INCLUDE_CKW_OPERANDBASE_H
 #define CKW_PROTOTYPE_INCLUDE_CKW_OPERANDBASE_H
 
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
 #include <string>
 
 namespace ckw
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
index 8071588..8eaa6ae 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
@@ -25,7 +25,7 @@
 #ifndef CKW_PROTOTYPE_INCLUDE_CKW_TENSORINFO_H
 #define CKW_PROTOTYPE_INCLUDE_CKW_TENSORINFO_H
 
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
 
 #include <array>
 #include <cstdint>
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
index 7a663f0..3a2509e 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
@@ -29,7 +29,7 @@
 #include "ckw/TensorInfo.h"
 #include "ckw/TensorTileSampler.h"
 #include "ckw/TileOperand.h"
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
 
 #include <memory>
 
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h b/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
index 2ea65bc..e1bf0c5 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
@@ -25,7 +25,7 @@
 #ifndef CKW_PROTOTYPE_INCLUDE_CKW_TENSORTILESAMPLER_H
 #define CKW_PROTOTYPE_INCLUDE_CKW_TENSORTILESAMPLER_H
 
-#include "ckw/Types.h"
+#include "ckw/types/TensorSamplerTypes.h"
 #include <functional>
 
 namespace ckw
diff --git a/compute_kernel_writer/prototype/include/ckw/TileInfo.h b/compute_kernel_writer/prototype/include/ckw/TileInfo.h
index c60880d..de9e47a 100644
--- a/compute_kernel_writer/prototype/include/ckw/TileInfo.h
+++ b/compute_kernel_writer/prototype/include/ckw/TileInfo.h
@@ -25,7 +25,7 @@
 #ifndef CKW_PROTOTYPE_INCLUDE_CKW_TILEINFO_H
 #define CKW_PROTOTYPE_INCLUDE_CKW_TILEINFO_H
 
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
 
 #include <array>
 #include <cstdint>
diff --git a/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h b/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h
new file mode 100644
index 0000000..2a19850
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_INCLUDE_CKW_CONVERTPOLICY_H
+#define CKW_INCLUDE_CKW_CONVERTPOLICY_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+enum class ConvertPolicy : int32_t
+{
+    None     = 0, // No policy specified.
+    Saturate = 1, // Saturated.
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_CONVERTPOLICY_H
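
For illustration, the policy chooses between the plain and the saturating
OpenCL conversion builtins in the prototype writer (see op_cast_expression in
src/Prototype.h below). A sketch, assuming a hypothetical Fp32 tile `src` and
Int8 tile `dst`:

    writer.op_cast_expression(dst, src, ConvertPolicy::Saturate);
    // Per tile row, emits code of the form: dst = convert_char_sat(src);

    writer.op_cast_expression(dst, src, ConvertPolicy::None);
    // Per tile row, emits code of the form: dst = convert_char(src);
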
diff --git a/compute_kernel_writer/prototype/include/ckw/types/DataType.h b/compute_kernel_writer/prototype/include/ckw/types/DataType.h
new file mode 100644
index 0000000..3447dd6
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/DataType.h
@@ -0,0 +1,50 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_DATATYPE_H
+#define CKW_INCLUDE_CKW_DATATYPE_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+/** Compute Kernel Writer data types. This data type is used by the code variables and tensor arguments. */
+enum class DataType : int32_t
+{
+    Unknown = 0x00,
+    Fp32    = 0x11,
+    Fp16    = 0x12,
+    Int32   = 0x21,
+    Int16   = 0x22,
+    Int8    = 0x24,
+    Uint32  = 0x31,
+    Uint16  = 0x32,
+    Uint8   = 0x34,
+    Bool    = 0x41
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_DATATYPE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Functions.h b/compute_kernel_writer/prototype/include/ckw/types/Functions.h
new file mode 100644
index 0000000..68146cb
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/Functions.h
@@ -0,0 +1,61 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_FUNCTIONS_H
+#define CKW_INCLUDE_CKW_FUNCTIONS_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+enum class UnaryFunction : int32_t
+{
+    Exp            = 0x0000,
+    Tanh           = 0x0001,
+    Sqrt           = 0x0002,
+    Erf            = 0x0003,
+    Fabs           = 0x0004,
+    IsGreaterEqual = 0x0005,
+    Log            = 0x0006,
+    Round          = 0x0007,
+
+    // Misc
+    SizeOf = 0x0008,
+};
+
+enum class BinaryFunction : int32_t
+{
+    Min  = 0x0000,
+    Max  = 0x0001,
+};
+
+enum class TernaryFunction : int32_t
+{
+    Select = 0x0000,
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_FUNCTIONS_H
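
A sketch of how these functions lower to OpenCL builtins in the prototype
writer further down in this patch; `dst`, `a`, `b` and `c` are hypothetical
tiles of matching shape:

    writer.op_binary_elementwise_function(dst, BinaryFunction::Max, a, b);
    // Fp16/Fp32 tiles lower to: dst = fmax(a, b);
    // integer tiles lower to:   dst = max(a, b);

    writer.op_ternary_elementwise_function(dst, TernaryFunction::Select, a, b, c);
    // Lowers to: dst = select(a, b, c);
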
diff --git a/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h b/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h
new file mode 100644
index 0000000..6c08617
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
+#define CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+enum class GpuTargetLanguage : int32_t
+{
+    Unknown,
+    OpenCL
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Operators.h b/compute_kernel_writer/prototype/include/ckw/types/Operators.h
new file mode 100644
index 0000000..78027f1
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/Operators.h
@@ -0,0 +1,74 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_OPERATORS_H
+#define CKW_INCLUDE_CKW_OPERATORS_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+enum class UnaryOp : int32_t
+{
+    LogicalNot = 0x0000, // !
+};
+
+/** Binary operations.
+ */
+enum class BinaryOp : int32_t
+{
+    // Elementwise
+    Add = 0x0000, // +
+    Sub = 0x0001, // -
+    Mul = 0x0002, // *
+    Div = 0x0003, // /
+    Mod = 0x0004, // %
+    // Relational
+    Equal        = 0x1000, // ==
+    Less         = 0x1001, // <
+    LessEqual    = 0x1002, // <=
+    Greater      = 0x1003, // >
+    GreaterEqual = 0x1004, // >=
+    // Algebra
+    MatMul_Nt_Nt = 0x2000, // X
+    MatMul_Nt_T  = 0x2001, // X
+    MatMul_T_Nt  = 0x2002, // X
+    MatMul_T_T   = 0x2003, // X
+    Dot          = 0x2004, // .
+    // Logical
+    LogicalAnd = 0x3000, // &&
+    LogicalOr  = 0x3001, // ||
+};
+
+enum class AssignmentOp : int32_t
+{
+    // Unary
+    Increment = 0x0000, // +=
+    Decrement = 0x0001, // -=
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_OPERATORS_H
diff --git a/compute_kernel_writer/prototype/include/ckw/Types.h b/compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h
similarity index 72%
rename from compute_kernel_writer/prototype/include/ckw/Types.h
rename to compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h
index bb5d7ce..836bd13 100644
--- a/compute_kernel_writer/prototype/include/ckw/Types.h
+++ b/compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h
@@ -22,76 +22,14 @@
  * SOFTWARE.
  */
 
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TYPES_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_TYPES_H
+#ifndef CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
+#define CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
 
-#include <array>
 #include <cstdint>
 
 namespace ckw
 {
 
-/** Compute Kernel Writer data types. This data type is used by the code variables and tensor arguments. */
-enum class DataType
-{
-    Unknown = 0x00,
-    Fp32    = 0x11,
-    Fp16    = 0x12,
-    Int32   = 0x21,
-    Int16   = 0x22,
-    Int8    = 0x24,
-    Uint32  = 0x31,
-    Uint16  = 0x32,
-    Uint8   = 0x34,
-    Bool    = 0x41
-};
-
-enum class GpuTargetLanguage
-{
-    Unknown,
-    OpenCL
-};
-
-/* Binary operations
-*/
-enum class BinaryOp : int32_t
-{
-    // Elementwise
-    Add = 0x0000, // +
-    Sub = 0x0001, // -
-    Mul = 0x0002, // *
-    Div = 0x0003, // /
-    Mod = 0x0004, // %
-    // Relational
-    Equal        = 0x1000, // ==
-    Less         = 0x1001, // <
-    LessEqual    = 0x1002, // <=
-    Greater      = 0x1003, // >
-    GreaterEqual = 0x1004, // >=
-    // Algebra
-    MatMul_Nt_Nt = 0x2000, // X
-    MatMul_Nt_T  = 0x2001, // X
-    MatMul_T_Nt  = 0x2002, // X
-    MatMul_T_T   = 0x2003, // X
-    Dot          = 0x2004, // .
-    // Logical
-    LogicalAnd = 0x3000, // &&
-    LogicalOr  = 0x3001, // ||
-    LogicalNot = 0x3002  // !
-};
-
-enum class AssignmentOp : int32_t
-{
-    // Unary
-    Increment = 0x0000, // +=
-    Decrement = 0x0001, // -=
-};
-
-enum class ScalarUnaryFunction : int32_t
-{
-    Exp,
-};
-
 enum class TensorSamplerFormat : int32_t
 {
     Unknown = 0,
@@ -137,4 +75,4 @@
 
 } // namespace ckw
 
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_TYPES_H
+#endif //CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp
index bbf5c44..692d504 100644
--- a/compute_kernel_writer/prototype/src/Kernel.cpp
+++ b/compute_kernel_writer/prototype/src/Kernel.cpp
@@ -23,7 +23,7 @@
  */
 
 #include "ckw/Kernel.h"
-#include "ckw/Types.h"
+#include "ckw/types/GpuTargetLanguage.h"
 #include "src/Prototype.h"
 
 namespace ckw
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 5d79985..73458ef 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -85,7 +85,7 @@
 // Tensor and tile declaration
 // =================================================================================================
 
-TensorOperand &KernelWriter::create_tensor_argument(const char *name, const TensorInfo &info)
+TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
 {
     const auto var_name = generate_variable_name(name);
 
@@ -97,7 +97,7 @@
     return *operand;
 }
 
-TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value)
+TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
 {
     const auto var_name = generate_variable_name(name);
 
@@ -107,7 +107,7 @@
     return *operand;
 }
 
-std::string KernelWriter::generate_variable_name(const char *name) const
+std::string KernelWriter::generate_variable_name(const std::string &name) const
 {
     std::stringstream var_name;
 
@@ -181,7 +181,7 @@
 // Data processing
 // =================================================================================================
 
-void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src)
+void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
 {
     auto impl_dst = dst.create_impl_operand(_impl.get());
     auto impl_src = src.create_impl_operand(_impl.get());
@@ -189,7 +189,15 @@
     _impl->op_assign(impl_dst, impl_src);
 }
 
-void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op)
+void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy)
+{
+    auto impl_dst = dst.create_impl_operand(_impl.get());
+    auto impl_src = src.create_impl_operand(_impl.get());
+
+    _impl->op_cast_expression(impl_dst, impl_src, policy);
+}
+
+void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs)
 {
     auto impl_lhs = lhs.create_impl_operand(_impl.get());
     auto impl_rhs = rhs.create_impl_operand(_impl.get());
@@ -198,12 +206,81 @@
     _impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs);
 }
 
-void KernelWriter::op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction opcode)
+void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src)
 {
     auto impl_dst = dst.create_impl_operand(_impl.get());
     auto impl_src = src.create_impl_operand(_impl.get());
 
-    _impl->op_scalar_function(impl_dst, impl_src, opcode);
+    _impl->op_unary_expression(impl_dst, op, impl_src);
+}
+
+void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src)
+{
+    auto impl_dst = dst.create_impl_operand(_impl.get());
+    auto impl_src = src.create_impl_operand(_impl.get());
+
+    _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src);
+}
+
+void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second)
+{
+    auto impl_dst    = dst.create_impl_operand(_impl.get());
+    auto impl_first  = first.create_impl_operand(_impl.get());
+    auto impl_second = second.create_impl_operand(_impl.get());
+
+    _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second);
+}
+
+void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+{
+    auto impl_dst    = dst.create_impl_operand(_impl.get());
+    auto impl_first  = first.create_impl_operand(_impl.get());
+    auto impl_second = second.create_impl_operand(_impl.get());
+    auto impl_third  = third.create_impl_operand(_impl.get());
+
+    _impl->op_ternary_elementwise_function(impl_dst, opcode, impl_first, impl_second, impl_third);
+}
+
+void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+    auto impl_lhs = lhs.create_impl_operand(_impl.get());
+    auto impl_rhs = rhs.create_impl_operand(_impl.get());
+
+    _impl->op_if_header(impl_lhs, op, impl_rhs);
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
+}
+
+void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+    auto impl_lhs = lhs.create_impl_operand(_impl.get());
+    auto impl_rhs = rhs.create_impl_operand(_impl.get());
+
+    _impl->op_else_if_header(impl_lhs, op, impl_rhs);
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
+}
+
+void KernelWriter::op_else(const std::function<void()> &body)
+{
+    _impl->op_else_header();
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
+}
+
+void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
+{
+    auto impl_var_name          = var_name.create_impl_operand(_impl.get());
+    auto impl_cond_value_name   = cond_value_name.create_impl_operand(_impl.get());
+    auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get());
+
+    _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name);
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
 }
 
 // =================================================================================================
@@ -215,6 +292,11 @@
     _impl->op_get_global_id(prototype::Operand(dst.name()), dim);
 }
 
+void KernelWriter::op_return()
+{
+    _impl->op_return();
+}
+
 // =================================================================================================
 // Code generation
 // =================================================================================================
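
The new control-flow writers above all follow the same pattern: emit a header,
then wrap the user-supplied body callback between compound_statement_begin()
and compound_statement_end(), so constructs nest naturally. A schematic sketch
(tile names are illustrative):

    writer.op_if(x, BinaryOp::Less, limit, [&]()
    {
        writer.op_binary_expression(x, x, BinaryOp::Add, step);
    });

    // Roughly generates:
    //   if(x < limit)
    //   {
    //       x = x + step;
    //   }
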
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
index fdb4ab1..18f284b 100644
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ b/compute_kernel_writer/prototype/src/Prototype.h
@@ -31,6 +31,7 @@
 #include <chrono>
 #include <cmath>
 #include <cstdint>  // int32_t
+#include <functional>
 #include <iostream> // cout (to be removed)
 #include <map>
 #include <memory>
@@ -41,7 +42,12 @@
 
 #include "ckw/Error.h"
 #include "ckw/TensorInfo.h"
-#include "ckw/Types.h"
+#include "ckw/types/ConvertPolicy.h"
+#include "ckw/types/DataType.h"
+#include "ckw/types/Functions.h"
+#include "ckw/types/GpuTargetLanguage.h"
+#include "ckw/types/Operators.h"
+#include "ckw/types/TensorSamplerTypes.h"
 
 namespace ckw
 {
@@ -1548,6 +1554,18 @@
     }
 }
 
+inline std::string to_string(UnaryOp op)
+{
+    switch(op)
+    {
+        case UnaryOp::LogicalNot:
+            return "!";
+        default:
+            assert(false);
+            return "";
+    }
+}
+
 inline std::string to_string(BinaryOp op)
 {
     switch(op)
@@ -1576,8 +1594,6 @@
             return "&&";
         case BinaryOp::LogicalOr:
             return "||";
-        case BinaryOp::LogicalNot:
-            return "!";
         default:
             assert(false);
             return "";
@@ -2407,12 +2423,6 @@
     bool return_tensor_component_by_value{ false };
 };
 
-enum class ConvertPolicy
-{
-    Wrap,    /**< Wrap around */
-    Saturate /**< Saturate */
-};
-
 enum class RoundingMode
 {
     None,
@@ -2445,36 +2455,44 @@
     virtual void compound_statement_end() = 0;
 
     // Operations
-    virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0;
+    virtual void op_get_global_id(const Operand &dst_var, int32_t dim)                                                                                                                                                                                                   = 0;
 
-    virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0;
+    virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim)                                                                                                                                                  = 0;
 
-    virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0;
+    virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor)                                                                                                                                                                                    = 0;
 
-    virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0;
+    virtual void op_get_global_size(const Operand &dst_var, int32_t dim)                                                                                                                                                                                                 = 0;
 
-    virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+    virtual void op_unary_expression(const Operand &dst, UnaryOp op, const Operand &src)                                                                                                                                                                                 = 0;
 
-    virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0;
+    virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs)                                                                                                                                                           = 0;
 
-    virtual void op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) = 0;
+    virtual void op_assign(const Operand &dst_name, const Operand &src_name)                                                                                                                                                                                             = 0;
 
-    virtual void op_if(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+    virtual void op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name)                                                                                                                                                     = 0;
 
-    virtual void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0;
+    virtual void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name)                                                                                                                     = 0;
 
-    virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
+    virtual void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name)                                                                                        = 0;
+
+    virtual void op_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs)                                                                                                                                                                                       = 0;
+
+    virtual void op_else_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs)                                                                                                                                                                                  = 0;
+
+    virtual void op_else_header()                                                                                                                                                                                                                                        = 0;
+
+    virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value)                                                                                                           = 0;
+
+    virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32))                                                             = 0;
 
     virtual void op_load_immediate(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32), const Operand &dilation_y = Operand("1", OperandType::ScalarInt32)) = 0;
 
-    virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
+    virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32))                                                                    = 0;
 
-    virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0;
+    virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy)                                                                                                                                                                        = 0;
 
-    virtual void op_return() = 0;
+    virtual void op_return()                                                                                                                                                                                                                                             = 0;
 
-    // virtual void op_else() = 0;
-    // virtual void op_elseif() = 0;
     // Utils
     // It is the process of converting
     virtual void util_get_indirect_buffer(const Operand &dst, const TensorOperand &tensor, const Operand &x,
@@ -2929,10 +2947,10 @@
     std::string to_ls_buffer_address(const std::string &x, const std::string &y, const std::string &z,
                                      const std::string &b) const
     {
-        auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
+        auto tensor_storage            = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
         assert(tensor_storage == GpuTensorStorage::BufferUint8Ptr);
-        const std::string ptr_buf  = _mapper.tensor_argument()->storage(tensor_storage);
-        const std::string dst_type = get_cl_data_type(_dst->format().dt, 1);
+        const std::string ptr_buf      = _mapper.tensor_argument()->storage(tensor_storage);
+        const std::string dst_type     = get_cl_data_type(_dst->format().dt, 1);
 
         std::string address;
         address += "(__global ";
@@ -3135,7 +3153,6 @@
 
         auto              tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
         const std::string image2d_obj    = _mapper.tensor_argument()->storage(tensor_storage);
-        // const DataType dt              = _dst->format().dt;
         const std::string post_fix = _dst->format().dt == DataType::Fp32 ? "f" : "h";
 
         switch(type)
@@ -3242,7 +3259,7 @@
 };
 
 // This utility method needs to go in utils.h
-inline bool is_tile_scalar(IVectorTile *x)
+inline bool is_tile_scalar(const IVectorTile *x)
 {
     return x->format().w == 1 && x->format().h == 1;
 }
@@ -3415,11 +3432,11 @@
 
     void op_get_global_batch(const Operand &o_dst, const TensorOperand &o_tensor) override
     {
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            dst = operands.unpack(o_dst);
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *dst = operands.unpack(o_dst);
 
         TensorOperandUnpacker tensor_operands(_data->arguments);
-        auto                  tensor      = tensor_operands.unpack(o_tensor);
+        IGpuTensorArgument   *tensor      = tensor_operands.unpack(o_tensor);
         auto                  gpu_sampler = o_tensor.sampler();
 
         GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3450,13 +3467,39 @@
         _data->code += ");\n";
     }
 
+    void op_unary_expression(const Operand &dst_name, UnaryOp op, const Operand &src_name) override
+    {
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *src = operands.unpack(src_name);
+        const IVectorTile *dst = operands.unpack(dst_name);
+
+        const int32_t     dst_w = dst->format().w;
+        const int32_t     dst_h = dst->format().h;
+        const int32_t     src_w = src->format().w;
+        const std::string dt    = dst->underlying_source_variables()[0].type.str;
+
+        const bool broadcast_src_x = dst_w != 1 && src_w == 1;
+
+        const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+
+        // Broadcasting on Y is automatic
+        for(int32_t y = 0; y < dst_h; ++y)
+        {
+            _data->code += dst->vector(y).str;
+            _data->code += " = ";
+            _data->code += to_string(op);
+            _data->code += src_prefix + src->vector(y).str;
+            _data->code += ";\n";
+        }
+    }
+
     void op_binary_expression(const Operand &dst_name, const Operand &lhs_name, BinaryOp op,
                               const Operand &rhs_name) override
     {
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            lhs = operands.unpack(lhs_name);
-        auto            rhs = operands.unpack(rhs_name);
-        auto            dst = operands.unpack(dst_name);
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *lhs = operands.unpack(lhs_name);
+        const IVectorTile *rhs = operands.unpack(rhs_name);
+        const IVectorTile *dst = operands.unpack(dst_name);
 
         const int32_t dst_w = dst->format().w;
         const int32_t dst_h = dst->format().h;
@@ -3488,12 +3531,12 @@
             return;
         }
 
-        bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
-        bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
+        const bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
+        const bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
 
-        std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
-        std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
-        std::string op_str     = to_string(op);
+        const std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
+        const std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
+        const std::string op_str     = to_string(op);
 
         // Broadcasting on Y is automatic
         for(int32_t y = 0; y < dst_h; ++y)
@@ -3511,21 +3554,20 @@
 
     void op_cast_expression(const Operand &o_dst, const Operand &o_src, ConvertPolicy policy) override
     {
-        CKW_UNUSED(policy);
-
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            src = operands.unpack(o_src);
-        auto            dst = operands.unpack(o_dst);
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *src = operands.unpack(o_src);
+        const IVectorTile *dst = operands.unpack(o_dst);
 
         // const int32_t dst_w  = dst->format().w;
         const int32_t     dst_h = dst->format().h;
-        const std::string dt    = dst->scalar(0, 0).type.str;
+        const std::string dt    = dst->underlying_source_variables()[0].type.str;
+        const std::string sat   = (policy == ConvertPolicy::Saturate ? "_sat" : "");
 
         // Broadcasting on Y is automatic
         for(int32_t y = 0; y < dst_h; ++y)
         {
             _data->code += dst->vector(y).str;
-            _data->code += " = convert_" + dt + "(";
+            _data->code += " = convert_" + dt + sat + "(";
             _data->code += src->vector(y).str;
             _data->code += ");\n";
         }
@@ -3533,19 +3575,18 @@
 
     void op_assign(const Operand &dst_name, const Operand &src_name) override
     {
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            src = operands.unpack(src_name);
-        auto            dst = operands.unpack(dst_name);
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *src = operands.unpack(src_name);
+        const IVectorTile *dst = operands.unpack(dst_name);
 
-        const int32_t dst_w = dst->format().w;
-        const int32_t dst_h = dst->format().h;
-        const int32_t src_w = src->format().w;
-        // const int32_t src_h  = src->format().h;
-        const std::string dt = dst->scalar(0, 0).type.str;
+        const int32_t     dst_w = dst->format().w;
+        const int32_t     dst_h = dst->format().h;
+        const int32_t     src_w = src->format().w;
+        const std::string dt    = dst->underlying_source_variables()[0].type.str;
 
-        bool broadcast_src_x = dst_w != 1 && src_w == 1;
+        const bool broadcast_src_x = dst_w != 1 && src_w == 1;
 
-        std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+        const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
 
         // Broadcasting on Y is automatic
         for(int32_t y = 0; y < dst_h; ++y)
@@ -3558,21 +3599,20 @@
     }
 
     void
-    op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) override
+    op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) override
     {
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            src = operands.unpack(src_name);
-        auto            dst = operands.unpack(dst_name);
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *src = operands.unpack(src_name);
+        const IVectorTile *dst = operands.unpack(dst_name);
 
-        const int32_t dst_w = dst->format().w;
-        const int32_t dst_h = dst->format().h;
-        const int32_t src_w = src->format().w;
-        // const int32_t src_h  = src->format().h;
-        const std::string dt = dst->scalar(0, 0).type.str;
+        const int32_t     dst_w = dst->format().w;
+        const int32_t     dst_h = dst->format().h;
+        const int32_t     src_w = src->format().w;
+        const std::string dt    = dst->underlying_source_variables()[0].type.str;
 
-        bool broadcast_src_x = dst_w != 1 && src_w == 1;
+        const bool broadcast_src_x = dst_w != 1 && src_w == 1;
 
-        std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+        const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
 
         // Broadcasting on Y is automatic
         for(int32_t y = 0; y < dst_h; ++y)
@@ -3582,12 +3622,35 @@
 
             switch(func)
             {
-                case ScalarUnaryFunction::Exp:
+                case UnaryFunction::Exp:
                     _data->code += "exp(";
                     break;
-
+                case UnaryFunction::Tanh:
+                    _data->code += "tanh(";
+                    break;
+                case UnaryFunction::Sqrt:
+                    _data->code += "sqrt(";
+                    break;
+                case UnaryFunction::Erf:
+                    _data->code += "erf(";
+                    break;
+                case UnaryFunction::Fabs:
+                    _data->code += "fabs(";
+                    break;
+                case UnaryFunction::IsGreaterEqual:
+                    _data->code += "isgreaterequal(";
+                    break;
+                case UnaryFunction::Log:
+                    _data->code += "log(";
+                    break;
+                case UnaryFunction::SizeOf:
+                    _data->code += "sizeof(";
+                    break;
+                case UnaryFunction::Round:
+                    _data->code += "round(";
+                    break;
                 default:
-                    CKW_ASSERT(false);
+                    CKW_ASSERT_MSG(false, "Unexpected UnaryFunction used.");
             }
 
             _data->code += src_prefix + src->vector(y).str;
@@ -3595,11 +3658,105 @@
         }
     }
 
-    void op_if(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
+    void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name) override
     {
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            lhs = operands.unpack(o_lhs);
-        auto            rhs = operands.unpack(o_rhs);
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *first  = operands.unpack(first_name);
+        const IVectorTile *second = operands.unpack(second_name);
+        const IVectorTile *dst    = operands.unpack(dst_name);
+
+        const int32_t     dst_w        = dst->format().w;
+        const int32_t     dst_h        = dst->format().h;
+        const int32_t     first_w      = first->format().w;
+        const int32_t     second_w     = second->format().w;
+        const auto        datatype     = dst->underlying_source_variables()[0].type;
+        const std::string datatype_str = datatype.str;
+
+        const bool broadcast_first_x  = dst_w != 1 && first_w == 1;
+        const bool broadcast_second_x = dst_w != 1 && second_w == 1;
+
+        const std::string first_prefix  = broadcast_first_x ? "(" + datatype_str + ")" : "";
+        const std::string second_prefix = broadcast_second_x ? "(" + datatype_str + ")" : "";
+
+        const bool is_float = (datatype.dt == DataType::Fp32 || datatype.dt == DataType::Fp16);
+
+        // Broadcasting on Y is automatic
+        for(int32_t y = 0; y < dst_h; ++y)
+        {
+            _data->code += dst->vector(y).str;
+            _data->code += " = ";
+
+            switch(func)
+            {
+                case BinaryFunction::Min:
+                    _data->code += is_float ? "fmin(" : "min(";
+                    break;
+                case BinaryFunction::Max:
+                    _data->code += is_float ? "fmax(" : "max(";
+                    break;
+                default:
+                    CKW_ASSERT_MSG(false, "Unexpected BinaryFunction used.");
+            }
+
+            _data->code += first_prefix + first->vector(y).str;
+            _data->code += ", ";
+            _data->code += second_prefix + second->vector(y).str;
+            _data->code += ");\n";
+        }
+    }
+
+    void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name) override
+    {
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *first  = operands.unpack(first_name);
+        const IVectorTile *second = operands.unpack(second_name);
+        const IVectorTile *third  = operands.unpack(third_name);
+        const IVectorTile *dst    = operands.unpack(dst_name);
+
+        const int32_t     dst_w    = dst->format().w;
+        const int32_t     dst_h    = dst->format().h;
+        const int32_t     first_w  = first->format().w;
+        const int32_t     second_w = second->format().w;
+        const int32_t     third_w  = third->format().w;
+        const std::string dt       = dst->underlying_source_variables()[0].type.str;
+
+        const bool broadcast_first_x  = dst_w != 1 && first_w == 1;
+        const bool broadcast_second_x = dst_w != 1 && second_w == 1;
+        const bool broadcast_third_x  = dst_w != 1 && third_w == 1;
+
+        const std::string first_prefix  = broadcast_first_x ? "(" + dt + ")" : "";
+        const std::string second_prefix = broadcast_second_x ? "(" + dt + ")" : "";
+        const std::string third_prefix  = broadcast_third_x ? "(" + dt + ")" : "";
+
+        // Broadcasting on Y is automatic
+        for(int32_t y = 0; y < dst_h; ++y)
+        {
+            _data->code += dst->vector(y).str;
+            _data->code += " = ";
+
+            switch(func)
+            {
+                case TernaryFunction::Select:
+                    _data->code += "select(";
+                    break;
+                default:
+                    CKW_ASSERT_MSG(false, "Unexpected TernaryFunction used.");
+            }
+
+            _data->code += first_prefix + first->vector(y).str;
+            _data->code += ", ";
+            _data->code += second_prefix + second->vector(y).str;
+            _data->code += ", ";
+            _data->code += third_prefix + third->vector(y).str;
+            _data->code += ");\n";
+        }
+    }
+
+    void op_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
+    {
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *lhs = operands.unpack(o_lhs);
+        const IVectorTile *rhs = operands.unpack(o_rhs);
 
         assert(is_tile_scalar(lhs));
         assert(is_tile_scalar(rhs));
@@ -3613,13 +3770,23 @@
         _data->code += ")\n";
     }
 
-    void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value_name,
-                     AssignmentOp update_op, const Operand &update_value_name) override
+    void op_else_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
     {
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            var          = operands.unpack(var_name);
-        auto            cond_value   = operands.unpack(cond_value_name);
-        auto            update_value = operands.unpack(update_value_name);
+        _data->code += "else ";
+        op_if_header(o_lhs, op, o_rhs);
+    }
+
+    void op_else_header() override
+    {
+        _data->code += "else\n";
+    }
+
+    void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value_name, AssignmentOp update_op, const Operand &update_value_name) override
+    {
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *var          = operands.unpack(var_name);
+        const IVectorTile *cond_value   = operands.unpack(cond_value_name);
+        const IVectorTile *update_value = operands.unpack(update_value_name);
 
         const int32_t dst_w = var->format().w;
         const int32_t dst_h = var->format().h;
@@ -3646,15 +3813,17 @@
                            const Operand &dilation_y) override
     {
         OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            dst   = operands.unpack(o_dst);
-        auto            x     = operands.unpack(o_x);
-        auto            y     = operands.unpack(o_y);
-        auto            z     = operands.unpack(o_z);
-        auto            dil_y = operands.unpack(dilation_y);
-        auto            b     = operands.unpack(o_batch_idx);
+
+        // Not const as it requires changes to 'load_writer'.
+        IVectorTile *dst   = operands.unpack(o_dst);
+        IVectorTile *x     = operands.unpack(o_x);
+        IVectorTile *y     = operands.unpack(o_y);
+        IVectorTile *z     = operands.unpack(o_z);
+        IVectorTile *dil_y = operands.unpack(dilation_y);
+        IVectorTile *b     = operands.unpack(o_batch_idx);
 
         TensorOperandUnpacker tensor_operands(_data->arguments);
-        auto                  tensor      = tensor_operands.unpack(o_tensor);
+        IGpuTensorArgument   *tensor      = tensor_operands.unpack(o_tensor);
         auto                  gpu_sampler = o_tensor.sampler();
 
         GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3682,14 +3851,16 @@
                           const Operand &o_batch_idx) override
     {
         OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            dst   = operands.unpack(o_dst);
-        auto            x     = operands.unpack(o_x);
-        auto            y_ind = operands.unpack(o_indirect_h);
-        auto            z     = operands.unpack(o_z);
-        auto            b     = operands.unpack(o_batch_idx);
+
+        // Not const as it requires changes to 'load_writer'.
+        IVectorTile *dst   = operands.unpack(o_dst);
+        IVectorTile *x     = operands.unpack(o_x);
+        IVectorTile *y_ind = operands.unpack(o_indirect_h);
+        IVectorTile *z     = operands.unpack(o_z);
+        IVectorTile *b     = operands.unpack(o_batch_idx);
 
         TensorOperandUnpacker tensor_operands(_data->arguments);
-        auto                  tensor      = tensor_operands.unpack(o_tensor);
+        IGpuTensorArgument   *tensor      = tensor_operands.unpack(o_tensor);
         auto                  gpu_sampler = o_tensor.sampler();
 
         GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3712,14 +3883,16 @@
                             const Operand &batch_index_name) override
     {
         OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            src = operands.unpack(src_name);
-        auto            x   = operands.unpack(x_name);
-        auto            y   = operands.unpack(y_name);
-        auto            z   = operands.unpack(z_name);
-        auto            b   = operands.unpack(batch_index_name);
+
+        // Not const as it requires changes to 'load_writer'.
+        IVectorTile *src = operands.unpack(src_name);
+        IVectorTile *x   = operands.unpack(x_name);
+        IVectorTile *y   = operands.unpack(y_name);
+        IVectorTile *z   = operands.unpack(z_name);
+        IVectorTile *b   = operands.unpack(batch_index_name);
 
         TensorOperandUnpacker tensor_operands(_data->arguments);
-        auto                  tensor      = tensor_operands.unpack(tensor_name);
+        IGpuTensorArgument   *tensor      = tensor_operands.unpack(tensor_name);
         auto                  gpu_sampler = tensor_name.sampler();
 
         GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3747,15 +3920,15 @@
     void util_get_indirect_buffer(const Operand &o_dst, const TensorOperand &o_tensor, const Operand &o_x,
                                   const Operand &o_y, const Operand &o_x_off, const Operand &o_y_off) override
     {
-        OperandUnpacker operands(_data->tiles, _data->arguments);
-        auto            dst   = operands.unpack(o_dst);
-        auto            x     = operands.unpack(o_x);
-        auto            y     = operands.unpack(o_y);
-        auto            x_off = operands.unpack(o_x_off);
-        auto            y_off = operands.unpack(o_y_off);
+        OperandUnpacker    operands(_data->tiles, _data->arguments);
+        const IVectorTile *dst   = operands.unpack(o_dst);
+        const IVectorTile *x     = operands.unpack(o_x);
+        const IVectorTile *y     = operands.unpack(o_y);
+        const IVectorTile *x_off = operands.unpack(o_x_off);
+        const IVectorTile *y_off = operands.unpack(o_y_off);
 
         TensorOperandUnpacker tensor_operands(_data->arguments);
-        auto                  tensor = tensor_operands.unpack(o_tensor);
+        IGpuTensorArgument   *tensor = tensor_operands.unpack(o_tensor);
 
         assert(dst->format().w == 1);
         assert(x->format().w == 1);
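
For illustration, the operations above share one x-broadcast rule: when the
destination tile is wider than a width-1 source, the source is prefixed with a
vector cast to the destination type, while y-broadcast falls out of the
per-row loop. A schematic sketch, assuming a hypothetical 1x4 Fp32 `dst` and a
1x1 `src`:

    writer.op_assign(dst, src);
    // Roughly generates: dst = (float4)src;
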
diff --git a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp b/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
index 143d550..28e54df 100644
--- a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
+++ b/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
@@ -24,7 +24,7 @@
 
 #include "ckw/TensorTileSampler.h"
 #include "ckw/TileOperand.h"
-#include "ckw/Types.h"
+#include "ckw/types/TensorSamplerTypes.h"
 
 namespace ckw
 {