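Port operations to CKW prototype

Forward more of the prototype implementation's operations through the
public KernelWriter API:
- Rename create_tensor_argument/create_tile_argument to
  declare_tensor_argument/declare_tile_argument, and take the names as
  std::string instead of const char *.
- Take destination tile operands by const reference.
- Add op_cast_expression, op_unary_expression and unary, binary and
  ternary elementwise functions; op_scalar_function is removed.
- Reorder op_binary_expression parameters to (dst, lhs, op, rhs).
- Add control flow: op_if, op_else_if, op_else and op_for_loop, which
  emit the header and wrap the body callback in a compound statement,
  plus op_return.

Resolves: COMPMID-6334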
Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com>
Change-Id: I500d30f09daec4087eb3e7aecd1de77dc8fd53b4
Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9828
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 5d79985..73458ef 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -85,7 +85,7 @@
 // Tensor and tile declaration
 // =================================================================================================
 
-TensorOperand &KernelWriter::create_tensor_argument(const char *name, const TensorInfo &info)
+TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
 {
     const auto var_name = generate_variable_name(name);
 
@@ -97,7 +97,7 @@
     return *operand;
 }
 
-TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value)
+TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
 {
     const auto var_name = generate_variable_name(name);
 
@@ -107,7 +107,7 @@
     return *operand;
 }
 
-std::string KernelWriter::generate_variable_name(const char *name) const
+std::string KernelWriter::generate_variable_name(const std::string &name) const
 {
     std::stringstream var_name;
 
@@ -181,7 +181,7 @@
 // Data processing
 // =================================================================================================
 
-void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src)
+void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
 {
     auto impl_dst = dst.create_impl_operand(_impl.get());
     auto impl_src = src.create_impl_operand(_impl.get());
@@ -189,7 +189,15 @@
     _impl->op_assign(impl_dst, impl_src);
 }
 
-void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op)
+void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy)
+{
+    auto impl_dst = dst.create_impl_operand(_impl.get());
+    auto impl_src = src.create_impl_operand(_impl.get());
+
+    _impl->op_cast_expression(impl_dst, impl_src, policy);
+}
+
+void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs)
 {
     auto impl_lhs = lhs.create_impl_operand(_impl.get());
     auto impl_rhs = rhs.create_impl_operand(_impl.get());
@@ -198,12 +206,81 @@
     _impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs);
 }
 
-void KernelWriter::op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction opcode)
+void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src)
 {
     auto impl_dst = dst.create_impl_operand(_impl.get());
     auto impl_src = src.create_impl_operand(_impl.get());
 
-    _impl->op_scalar_function(impl_dst, impl_src, opcode);
+    _impl->op_unary_expression(impl_dst, op, impl_src);
+}
+
+void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src)
+{
+    auto impl_dst = dst.create_impl_operand(_impl.get());
+    auto impl_src = src.create_impl_operand(_impl.get());
+
+    _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src);
+}
+
+void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second)
+{
+    auto impl_dst    = dst.create_impl_operand(_impl.get());
+    auto impl_first  = first.create_impl_operand(_impl.get());
+    auto impl_second = second.create_impl_operand(_impl.get());
+
+    _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second);
+}
+
+void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+{
+    auto impl_dst    = dst.create_impl_operand(_impl.get());
+    auto impl_first  = first.create_impl_operand(_impl.get());
+    auto impl_second = second.create_impl_operand(_impl.get());
+    auto impl_third  = third.create_impl_operand(_impl.get());
+
+    _impl->op_ternary_elementwise_function(impl_dst, opcode, impl_first, impl_second, impl_third);
+}
+
+void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+    auto impl_lhs = lhs.create_impl_operand(_impl.get());
+    auto impl_rhs = rhs.create_impl_operand(_impl.get());
+
+    _impl->op_if_header(impl_lhs, op, impl_rhs);
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
+}
+
+void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+    auto impl_lhs = lhs.create_impl_operand(_impl.get());
+    auto impl_rhs = rhs.create_impl_operand(_impl.get());
+
+    _impl->op_else_if_header(impl_lhs, op, impl_rhs);
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
+}
+
+void KernelWriter::op_else(const std::function<void()> &body)
+{
+    _impl->op_else_header();
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
+}
+
+void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
+{
+    auto impl_var_name          = var_name.create_impl_operand(_impl.get());
+    auto impl_cond_value_name   = cond_value_name.create_impl_operand(_impl.get());
+    auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get());
+
+    _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name);
+    _impl->compound_statement_begin();
+    body();
+    _impl->compound_statement_end();
 }
 
 // =================================================================================================
@@ -215,6 +292,11 @@
     _impl->op_get_global_id(prototype::Operand(dst.name()), dim);
 }
 
+void KernelWriter::op_return()
+{
+    _impl->op_return();
+}
+
 // =================================================================================================
 // Code generation
 // =================================================================================================