Add temporary tile support for dynamic fusion

* Multiple intermediate tensors can share the same tile.
  - A simple operator can reuse the input tensor for the result
    if the input tensor has the same shape and data type, and is
    only consumed by that operator.
  - A special case is when a simple operator and an output operator
    consume the same tensor. However, as the output operator
    doesn't change the content of the input tensor, it doesn't
    count as "consuming" the input tensor.
* These temporary tiles are declared automatically by the template
  writer. Individual operators don't need to generate output tile
  declarations.
* Cast is now a simple operator.

Resolves: COMPMID-5778
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I232647ac976645e2d266a62e055b9eb48c356a8e
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8877
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.cpp b/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.cpp
index 0d25749..81c3f0c 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.cpp
+++ b/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.cpp
@@ -133,8 +133,9 @@
 
     _finalized = true;
 
-    std::set<const ITensorInfo *> input_tensors;
     std::set<const ITensorInfo *> output_tensors;
+    std::map<const ITensorInfo *, std::vector<const ITensorInfo *>> possible_tile_map;
+    std::map<const ITensorInfo *, int32_t> tile_usages;
 
     for(auto component : _components)
     {
@@ -156,26 +157,139 @@
             }
             else if(_interm_tensors.find(tensor) == _interm_tensors.end())
             {
-                input_tensors.insert(tensor);
+                _input_tensors.insert(tensor);
+
+                tile_usages[tensor] = 0;
+                possible_tile_map.emplace(tensor, std::vector<const ITensorInfo *>());
             }
         }
 
         for(auto tensor : dst_tensors)
         {
-            ARM_COMPUTE_ERROR_ON(input_tensors.find(tensor) != input_tensors.end());
+            ARM_COMPUTE_ERROR_ON(_input_tensors.find(tensor) != _input_tensors.end());
             ARM_COMPUTE_ERROR_ON(output_tensors.find(tensor) != output_tensors.end());
             ARM_COMPUTE_ERROR_ON(_interm_tensors.find(tensor) != _interm_tensors.end());
             output_tensors.insert(tensor);
+
+            tile_usages[tensor] = 0;
+            possible_tile_map.emplace(tensor, std::vector<const ITensorInfo *>());
+        }
+
+        // Check if the output can overwrite the input tile.
+        const auto component_type = component->type();
+        if(component_type == GpuComponentType::Simple || component_type == GpuComponentType::Output)
+        {
+            ARM_COMPUTE_ERROR_ON(dst_tensors.size() != 1);
+
+            const auto dst_tensor = dst_tensors[0];
+            const auto &dst_shape = dst_tensor->tensor_shape();
+            const auto &dst_type = dst_tensor->data_type();
+
+            tile_usages[dst_tensor] = 0;
+
+            for(auto src_tensor : src_tensors)
+            {
+                const auto &src_shape = src_tensor->tensor_shape();
+                const auto &src_type = src_tensor->data_type();
+
+                if(src_shape == dst_shape && src_type == dst_type)
+                {
+                    const auto tile_usages_it = tile_usages.find(src_tensor);
+                    ARM_COMPUTE_ERROR_ON(tile_usages_it == tile_usages.end());
+
+                    if(component_type == GpuComponentType::Simple || tile_usages_it->second > 0)
+                    {
+                        // Increase the number of tile usages unless this component is an output
+                        // and the tile has not been shared with any component.
+                        // (Reason: output component doesn't change the content of the tile)
+                        ++tile_usages_it->second;
+                    }
+
+                    possible_tile_map[dst_tensor].push_back(src_tensor);
+                }
+            }
+        }
+        else
+        {
+            // Outputs of complex and unfusable components need dedicated tile.
+            for(auto tensor : dst_tensors)
+            {
+                tile_usages[tensor] = 0;
+            }
+        }
+    }
+
+    // Find the smallest list of tiles that the intermediate tensors need to write to.
+    for(auto tensor : _input_tensors)
+    {
+        _tile_map[tensor] = tensor;
+    }
+
+    for(auto component : _components)
+    {
+        const auto dst_tensors = component->tensors().get_const_dst_tensors();
+
+        for(auto tensor : dst_tensors)
+        {
+            const auto target_tiles = possible_tile_map.at(tensor);
+            _tile_map[tensor] = tensor;
+
+            for(auto target : target_tiles)
+            {
+                const auto num_usage = tile_usages[target];
+
+                if(num_usage <= 1)
+                {
+                    // The target tile is consumed by only this operator, so we can reuse it
+                    // for the destination tensor data.
+                    _tile_map[tensor] = _tile_map.at(target);
+                    break;
+                }
+            }
+        }
+    }
+
+    for(auto tensor : output_tensors)
+    {
+        _tile_map[tensor] = tensor;
+    }
+
+    // All intermediate tensors that cannot be shared with any previous tensor
+    // will need to be declared as tile variable.
+    for(auto tensor_tile : _tile_map)
+    {
+        if(tensor_tile.first == tensor_tile.second &&
+           _interm_tensors.find(tensor_tile.first) != _interm_tensors.end())
+        {
+            _tiles.push_back(tensor_tile.first);
         }
     }
 
     std::set_union(
-        input_tensors.begin(), input_tensors.end(),
+        _input_tensors.begin(), _input_tensors.end(),
         output_tensors.begin(), output_tensors.end(),
         std::back_inserter(_argument_tensors));
     _any_output_tensor = *output_tensors.begin();
 }
 
+std::vector<const ITensorInfo *> GpuKernelComponentGroup::get_tiles() const
+{
+    ARM_COMPUTE_ERROR_ON_MSG(!_finalized, "The component group must have been finalized.");
+    return _tiles;
+}
+
+const ITensorInfo *GpuKernelComponentGroup::get_tile_for_tensor(const ITensorInfo *tensor) const
+{
+    ARM_COMPUTE_ERROR_ON_MSG(!_finalized, "The component group must have been finalized.");
+
+    if(_tile_map.find(tensor) != _tile_map.end())
+    {
+        return _tile_map.at(tensor);
+    }
+
+    return tensor;
+}
+
 const ITensorInfo *GpuKernelComponentGroup::get_any_dst_tensor() const
 {
     ARM_COMPUTE_ERROR_ON_MSG(!_finalized, "The component group must have been finalized.");
@@ -203,6 +317,12 @@
     return _interm_tensors.find(tensor) != _interm_tensors.end();
 }
 
+bool GpuKernelComponentGroup::is_input_tensor(const ITensorInfo *tensor) const
+{
+    ARM_COMPUTE_ERROR_ON_MSG(!_finalized, "The component group must have been finalized.");
+    return _input_tensors.find(tensor) != _input_tensors.end();
+}
+
 size_t GpuKernelComponentGroup::size() const
 {
     return _components.size();
diff --git a/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h b/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h
index 386aefd..c939aec 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h
@@ -30,6 +30,7 @@
 #include <cstdlib>
 #include <vector>
 #include <set>
+#include <map>
 
 namespace arm_compute
 {
@@ -109,6 +110,22 @@
      * @return false  Otherwise
      */
     bool is_intermediate_tensor(const ITensorInfo *tensor) const;
+    /** Check if an @ref ITensorInfo is an input tensor of the group.
+     *
+     * @param[in] tensor @ref ITensorInfo to be looked up.
+     *
+     * @return true if @p tensor is an input tensor of the group, otherwise false.
+     */
+    bool is_input_tensor(const ITensorInfo *tensor) const;
+    /** Get the list of temporary tiles that need to be declared */
+    std::vector<const ITensorInfo *> get_tiles() const;
+    /** Get the shared tile that can be used to store temporary data of the specified tensor.
+     *
+     * @param[in] tensor @ref ITensorInfo to be looked up.
+     *
+     * @return @ref ITensorInfo that is used to store temporary data of @p tensor.
+     **/
+    const ITensorInfo *get_tile_for_tensor(const ITensorInfo *tensor) const;
     /** Get the number of components within the group */
     size_t size() const;
     /** Check if the component group is empty */
@@ -126,9 +143,13 @@
     std::vector<ComponentPtr> _components{};
 
     bool _finalized{ false };
+
     std::vector<const ITensorInfo *> _argument_tensors{};
+    std::set<const ITensorInfo *> _input_tensors{};
     std::set<const ITensorInfo *> _interm_tensors{};
     const ITensorInfo *_any_output_tensor{ nullptr };
+    std::vector<const ITensorInfo *> _tiles{};
+    std::map<const ITensorInfo *, const ITensorInfo *> _tile_map{};
 };
 } // namespace dynamic_fusion
 } // namespace experimental
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h
index d0f75b1..84d6f07 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentCast.h
@@ -120,7 +120,7 @@
     /** Get component type */
     GpuComponentType type() const override
     {
-        return GpuComponentType::Complex;
+        return GpuComponentType::Simple;
     }
 
 private:
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuCast.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuCast.cpp
index 9e5e735..3a5b64a 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuCast.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuCast.cpp
@@ -38,7 +38,7 @@
 {
 namespace
 {
-constexpr GpuOperatorType operator_type = GpuOperatorType::Complex;
+constexpr GpuOperatorType operator_type = GpuOperatorType::Simple;
 }
 Status GpuCast::is_supported_op(const GpuWorkloadContext &context,
                                 const ITensorInfo        *src,
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
index 13c0b14..2eafe62 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.cpp
@@ -24,6 +24,7 @@
 #include "GpuKernelVariableTable.h"
 #include "arm_compute/core/CL/CLHelpers.h"
 #include "arm_compute/core/ITensorInfo.h"
+#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
 
 namespace arm_compute
 {
@@ -31,44 +32,48 @@
 {
 namespace dynamic_fusion
 {
-void GpuKernelVariableTable::declare_variable(const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, bool is_interm, const std::string &alias)
+void GpuKernelVariableTable::declare_variable(const GpuKernelComponentGroup &comp_group, const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, const std::string &alias)
 {
     ARM_COMPUTE_ERROR_ON_MSG(!tensor->has_valid_id(), "Tensor info with valid id expected");
+
     // Do not re-declare if the variable associated with the tensor has already been declared
-    if(get_variable(tensor).has_valid_id())
+    auto it = _vars.find(tensor->id());
+
+    if(it != _vars.end())
     {
-        ARM_COMPUTE_ERROR_ON(!(get_variable(tensor).kernel_argument_info == argument_info));
+        ARM_COMPUTE_ERROR_ON(!(it->second.kernel_argument_info == argument_info));
         return;
     }
-    // Declare variable associated with the tensor
-    std::stringstream ss;
-    ss << alias << "_t" << tensor->id();
-    const auto     uniq_name = ss.str();
-    TensorVariable var{ tensor->id(), uniq_name, argument_info };
 
-    if(is_interm)
+    const auto target = comp_group.get_tile_for_tensor(tensor);
+
+    if(target != tensor)
     {
-        _interm_var = var;
-        _interm_tensors.insert(tensor->id());
+        // If the tensor uses a shared tile, don't declare another variable.
+        it = _vars.find(target->id());
+
+        ARM_COMPUTE_ERROR_ON_MSG(
+            it == _vars.end(),
+            "The variable used for this tensor must have been declared.");
+
+        _vars[tensor->id()] = it->second;
     }
     else
     {
+        // Declare variable associated with the tensor
+        std::stringstream ss;
+        ss << alias << "_t" << tensor->id();
+        const auto     uniq_name = ss.str();
+        TensorVariable var{ tensor->id(), uniq_name, argument_info };
+
         _vars.emplace(tensor->id(), var);
     }
 }
 
 GpuKernelVariableTable::TensorVariable GpuKernelVariableTable::get_variable(const ITensorInfo *tensor) const
 {
-    const TensorVariable empty_var{};
-    if(_vars.find(tensor->id()) != _vars.end())
-    {
-        return _vars.at(tensor->id());
-    }
-    if(_interm_tensors.find(tensor->id()) != _interm_tensors.end())
-    {
-        return _interm_var;
-    }
-    return empty_var;
+    const auto var = _vars.at(tensor->id());
+    return var;
 }
 
 GpuKernelVariableTable::VariableList GpuKernelVariableTable::get_variable_list(const std::vector<const ITensorInfo *> &tensors) const
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
index 4eee396..82b7513 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/GpuKernelVariableTable.h
@@ -39,9 +39,10 @@
 {
 namespace dynamic_fusion
 {
-/** A table of all the variables used in the kernel
- * Since fusion is restricted to a linear sequence of components in a kernel, only a single "intermediate variable" (the accumulator) is allowed.
- * Each kernel has exactly one variable table
+class GpuKernelComponentGroup;
+
+/** A table of all the variables used in the kernel.
+ * Each kernel has exactly one variable table.
  */
 class GpuKernelVariableTable
 {
@@ -69,15 +70,12 @@
 public:
     /** Declare a @ref TensorVariable for a corresponding tensor info.
      *
-     * @note: Later re-declaration of the intermediate variable will overwrite the previous association to the @ref ITensorInfo
-     *        Therefore, the order of declaration is important. It's assumed that the components declaring the variable is already in correct order
-     *
+     * @param[in] comp_group    Component group the tensor belongs to
      * @param[in] tensor        Tensor info with which the new variable is associated
      * @param[in] argument_info Kernel argument information
-     * @param[in] is_interm     If the new variable is an intermediate variable
      * @param[in] alias         Alias for the variable. Will be used as part of the variable name
      */
-    void declare_variable(const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, bool is_interm = false, const std::string &alias = "unnamed");
+    void declare_variable(const GpuKernelComponentGroup &comp_group, const ITensorInfo *tensor, GpuKernelArgumentInfo argument_info, const std::string &alias = "unnamed");
     /** Get the @ref TensorVariable associated with @p tensor
      *
      * @param[in] tensor Tensor info to be queried
@@ -95,9 +93,7 @@
     VariableList get_variable_list(const std::vector<const ITensorInfo *> &tensors) const;
 
 private:
-    std::map<ITensorInfo::Id, TensorVariable> _vars{}; /**< Non-intermediate (function parameter) variables*/
-    TensorVariable            _interm_var{};           /**< Intermediate variable */
-    std::set<ITensorInfo::Id> _interm_tensors{};       /**< Tensors associated with the single intermediate variable */
+    std::map<ITensorInfo::Id, TensorVariable> _vars{};
 };
 
 /** A tag value will substitute a tag in a string template during its instantiation */
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp
index 8adf056..53e74b4 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateActivation.cpp
@@ -67,14 +67,14 @@
 // IN(src)              {{src}}
 // OUT(dst, accum)      {{dst}}
 
-TILE({{DATA_TYPE}}, M0, N0, {{dst}});
+TILE({{DATA_TYPE}}, M0, N0, {{src}});
 TILE(uint, M0, 1, g_dst_indirect_y);
 {
     {{src}}_offset_first_element_in_bytes += g_ind_2 * {{src}}_stride_z;
 
-    T_LOAD({{DATA_TYPE}}, M0, N0, {{TENSOR_TYPE}}, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, {{dst}});
+    T_LOAD({{DATA_TYPE}}, M0, N0, {{TENSOR_TYPE}}, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, {{src}});
 
-    T_ACTIVATION({{DATA_TYPE}}, M0, N0, {{ACT}}, {{A_VAL}}, {{B_VAL}}, {{dst}}, {{dst}});
+    T_ACTIVATION({{DATA_TYPE}}, M0, N0, {{ACT}}, {{A_VAL}}, {{B_VAL}}, {{src}}, {{dst}});
 }
 
 LOOP_UNROLLING(int, i, 0, 1, M0,
@@ -91,7 +91,7 @@
 // IN/OUT(src, accum)   {{src}}
 
 {
-    T_ACTIVATION({{DATA_TYPE}}, M0, N0, {{ACT}}, {{A_VAL}}, {{B_VAL}}, {{src}}, {{src}});
+    T_ACTIVATION({{DATA_TYPE}}, M0, N0, {{ACT}}, {{A_VAL}}, {{B_VAL}}, {{src}}, {{dst}});
 }
 )_";
     }
@@ -104,15 +104,15 @@
 void ClTemplateActivation::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
     vtable.declare_variable(
+        comp_group,
         _src,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_src),
         "src");
 
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
index 6ab3a68..dcb43f9 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateCast.cpp
@@ -54,20 +54,26 @@
     ARM_COMPUTE_UNUSED(comp_group);
 
     const std::string kernel_name = get_name();
+    const auto is_root = (comp_group.get_root_component()->id() == this->id());
 
     std::string code = R"_(
-//------------------ START KERNEL {{meta_kernel_id}} ---------------------
+//------------------ START KERNEL {{meta_kernel_id}} CAST ---------------------
+)_";
+
+    if(is_root)
+    {
+        code += R"_(
 // IN_0(src)            {{src}}
 // OUT(dst, accum)      {{dst}}
 
-TILE({{DATA_TYPE_OUT}}, M0, N0, {{dst}});
 TILE(uint, M0, 1, g_dst_indirect_y);
 {
     {{src}}_offset_first_element_in_bytes += get_global_id(2) * {{src}}_stride_z;
 
-    TILE({{DATA_TYPE_IN}}, M0, N0, in_data);
-    T_LOAD({{DATA_TYPE_IN}}, M0, N0, BUFFER, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, in_data);
+    TILE({{DATA_TYPE_IN}}, M0, N0, {{tmp}});
+    T_LOAD({{DATA_TYPE_IN}}, M0, N0, BUFFER, {{src}}, g_ind_0, g_ind_1, 1, {{src}}_stride_y, {{tmp}});
 )_";
+    }
 
     code += R"_(
     LOOP_UNROLLING(int, m0, 0, 1, M0,
@@ -77,20 +83,20 @@
     if(kernel_name == "cast_down" && is_data_type_quantized(_src->data_type()))
     {
         code += R"_(
-    in_data[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80;
+    {{tmp}}[m0].v ^= (VEC_DATA_TYPE({{DATA_TYPE_IN}}, N0))0x80;
 )_";
     }
 
     if(kernel_name == "cast_down" && (is_data_type_float(_src->data_type()) || _attributes.convert_policy() == ConvertPolicy::SATURATE))
     {
         code += R"_(
-    {{dst}}[m0].v = CONVERT_SAT(in_data[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
+    {{dst}}[m0].v = CONVERT_SAT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
 )_";
     }
     else
     {
         code += R"_(
-    {{dst}}[m0].v = CONVERT(in_data[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
+    {{dst}}[m0].v = CONVERT({{tmp}}[m0].v, VEC_DATA_TYPE({{DATA_TYPE_OUT}}, N0));
 )_";
     }
 
@@ -98,7 +104,9 @@
     })
 )_";
 
-    code += R"_(
+    if(is_root)
+    {
+        code += R"_(
     LOOP_UNROLLING(int, i, 0, 1, M0,
     {
         g_dst_indirect_y[i].v = (uint)min((int)(g_ind_1 + i), (int)({{arg_dst}}_w) - 1);
@@ -106,7 +114,11 @@
         g_dst_indirect_y[i].v += (int)(g_ind_2 / {{arg_dst}}_h) * (int)({{arg_dst}}_w * {{arg_dst}}_h);
     })
 }
-//------------------ END KERNEL {{meta_kernel_id}} ---------------------
+)_";
+    }
+
+    code += R"_(
+//------------------ END KERNEL {{meta_kernel_id}} CAST ---------------------
 )_";
 
     return code;
@@ -115,27 +127,28 @@
 void ClTemplateCast::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
     vtable.declare_variable(
+        comp_group,
         _src,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_src),
         "src");
 
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
 TagLUT ClTemplateCast::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
-    ARM_COMPUTE_UNUSED(comp_group);
+    const auto is_root = (comp_group.get_root_component()->id() == this->id());
 
     TagLUT lut{};
 
     // Arguments and global shared variables
     lut["src"] = vtable.get_variable(_src);
     lut["dst"] = vtable.get_variable(_dst);
+    lut["tmp"] = (is_root) ? lut["src"].value + "_in_data" : lut["src"];
 
     const auto dst_argument = vtable.get_variable(comp_group.get_any_dst_tensor());
     lut["arg_dst"]          = dst_argument.uniq_name;
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp
index 6fa77aa..ab7cc9f 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDepthwiseConv2d.cpp
@@ -81,7 +81,6 @@
     code += R"_(
 // OUT(dst, accum)      {{dst}}
 
-TILE({{ACC_DATA_TYPE}}, M0, N0, {{dst}});
 TILE(uint, M0, 1, g_dst_indirect_y);
 
 {
@@ -206,9 +205,9 @@
                                                        GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer;
 
     vtable.declare_variable(
+        comp_group,
         _src,
         GpuKernelArgumentInfo(input_type),
-        comp_group.is_intermediate_tensor(_src),
         "src");
 
     const GpuKernelArgumentInfo::Type weight_type = _settings.export_weights_to_cl_image() ?
@@ -216,23 +215,23 @@
                                                         GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer;
 
     vtable.declare_variable(
+        comp_group,
         _weight,
         GpuKernelArgumentInfo(weight_type),
-        comp_group.is_intermediate_tensor(_weight),
         "weight");
 
     if(_bias != nullptr && _bias->has_valid_id()) // optional bias
     {
         vtable.declare_variable(
+            comp_group,
             _bias,
             GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Vector),
-            comp_group.is_intermediate_tensor(_bias),
             "bias");
     }
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp
index 26399c5..c6e14f9 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateDirectConv2d.cpp
@@ -86,7 +86,6 @@
     code += R"_(
 // OUT(dst, accum)      {{dst}}
 
-TILE({{ACC_DATA_TYPE}}, M0, N0, {{dst}});
 TILE(uint, M0, 1, g_dst_indirect_y);
 
 {
@@ -227,30 +226,30 @@
 void ClTemplateDirectConv2d::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
     vtable.declare_variable(
+        comp_group,
         _src,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_src),
         "src");
 
     const GpuKernelArgumentInfo::Type weight_type = _settings.export_to_cl_image() ? GpuKernelArgumentInfo::Type::Tensor_4D_t_Image : GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer;
     vtable.declare_variable(
+        comp_group,
         _weight,
         GpuKernelArgumentInfo(weight_type),
-        comp_group.is_intermediate_tensor(_weight),
         "weight");
 
     if(_bias && _bias->has_valid_id()) // optional bias
     {
         vtable.declare_variable(
+            comp_group,
             _bias,
             GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Vector),
-            comp_group.is_intermediate_tensor(_bias),
             "bias");
     }
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(common_tensor_type),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
index 39cec6e..df8deee 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.cpp
@@ -65,94 +65,94 @@
     std::string code;
     const bool  is_broadcast = _lhs->tensor_shape() != _rhs->tensor_shape();
     const bool  is_root      = (comp_group.get_root_component()->id() == this->id());
+    const bool  is_lhs_input = comp_group.is_input_tensor(_lhs);
+    const bool  is_rhs_input = comp_group.is_input_tensor(_rhs);
+
+    code =
+R"_(
+    //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
+)_";
 
     if(is_root)
     {
-        code =
+        code +=
 R"_(
-    //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
-)_"
-    // IN_0(LHS)            {{lhs}}
-    // IN_1(RHS)            {{rhs}}
-    // OUT(dst, accum)      {{dst}}
-    // dst = lhs + rhs (mix-precision, broadcast, boundary aware)
-R"_(
-    TILE({{DATA_TYPE}}, M0, N0, {{dst}});
     TILE(uint, M0, 1, g_dst_indirect_y);
+)_";
+    }
+
+    if(is_lhs_input)
     {
-        TILE({{DATA_TYPE}}, M0, N0, lhs_tile);
-        TILE({{DATA_TYPE}}, M0, N0, rhs_tile);
-)_"
-        // Assuming un-collapsed window
+        code +=
+R"_(
+    TILE({{DATA_TYPE}}, M0, N0, {{lhs}});
+)_";
+    }
+
+    if(is_rhs_input)
+    {
+        code +=
+R"_(
+    TILE({{DATA_TYPE}}, M0, N0, {{rhs}});
+)_";
+    }
+
+    code +=
+R"_(
+    {
+)_";
+
+    if(is_lhs_input)
+    {
+        code +=
 R"_(
         {{lhs}}_offset_first_element_in_bytes += g_ind_2 * {{lhs}}_stride_z;
+        T_LOAD({{DATA_TYPE}}, {{lhs_m0}}, {{lhs_n0}}, BUFFER, {{lhs}}, {{lhs_start_ind_0}}, {{lhs_start_ind_1}}, 1, {{lhs}}_stride_y, {{lhs}});
+)_";
+    }
+
+    if(is_rhs_input)
+    {
+        code +=
+R"_(
         {{rhs}}_offset_first_element_in_bytes += g_ind_2 * {{rhs}}_stride_z;
-
-        T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{lhs}}, g_ind_0, g_ind_1, 1, {{lhs}}_stride_y, lhs_tile);
-        T_LOAD({{DATA_TYPE}}, {{rhs_m0}}, {{rhs_n0}}, BUFFER, {{rhs}}, {{rhs_start_ind_0}}, {{rhs_start_ind_1}}, 1, {{rhs}}_stride_y, rhs_tile);
-)_";
-        if(is_broadcast)
-        {
-            code +=
-R"_(
-        T_ELTWISE_BROADCAST_{{ELTWISE_OP}}_X({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}});
-)_";
-        }
-        else
-        {
-            code +=
-R"_(
-        T_ELTWISE_{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}});
-)_";
-        }
-    code +=
-    // Calculate the destination indirect Y
-R"_(
-    LOOP_UNROLLING(int, i, 0, 1, M0,
-    {
-        g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{out}}_w * {{out}}_h) - 1);
-        g_dst_indirect_y[i].v += g_ind_2 * (int)({{out}}_w * {{out}}_h);
-    })
-    }
-    //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
+        T_LOAD({{DATA_TYPE}}, {{rhs_m0}}, {{rhs_n0}}, BUFFER, {{rhs}}, {{rhs_start_ind_0}}, {{rhs_start_ind_1}}, 1, {{rhs}}_stride_y, {{rhs}});
 )_";
     }
 
-    else // non-root
+    if(is_broadcast)
     {
-        code =
-R"_(
-    //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
-)_"
-    // IN_0/Out(Accumulator)   {{acc}}
-    // IN_1(Operand)        {{operand}}
-    // acc = operand + acc (mix-precision, broadcast, boundary aware)
-R"_(
-    {
-        TILE(DATA_TYPE, M0, N0, operand_tile);
-        T_LOAD({{DATA_TYPE}}, {{rhs_m0}}, {{rhs_n0}}, BUFFER, {{operand}}, {{rhs_start_ind_0}}, {{rhs_start_ind_1}}, 1, {{operand}}_stride_y, operand_tile);
+        code +=
+            R"_(
+        T_ELTWISE_BROADCAST_{{ELTWISE_OP}}_X({{DATA_TYPE}}, M0, N0, {{lhs}}, {{rhs}}, {{dst}});
 )_";
+    }
+    else
+    {
+        code +=
+            R"_(
+        T_ELTWISE_{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, {{lhs}}, {{rhs}}, {{dst}});
+)_";
+    }
 
-        if(is_broadcast)
-        {
-            code +=
+    if(is_root)
+    {
+        // Calculate the destination indirect Y
+        code +=
 R"_(
-        T_ELTWISE_BROADCAST_{{ELTWISE_OP}}_X({{DATA_TYPE}}, M0, N0, {{acc}}, operand_tile, {{acc}});
-)_";
-        }
-        else
+        LOOP_UNROLLING(int, i, 0, 1, M0,
         {
-            code +=
-R"_(
-        T_ELTWISE_{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, {{acc}}, operand_tile, {{acc}});
+            g_dst_indirect_y[i].v = (uint)min(g_ind_1 + i, (int)({{arg_dst}}_w * {{arg_dst}}_h) - 1);
+            g_dst_indirect_y[i].v += g_ind_2 * (int)({{arg_dst}}_w * {{arg_dst}}_h);
+        })
 )_";
-        }
+    }
+
     code +=
 R"_(
     }
     //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP ---------------------
 )_";
-    }
 
     return code;
 }
@@ -160,62 +160,38 @@
 void ClTemplateElementwiseBinary::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
     vtable.declare_variable(
+        comp_group,
         _lhs,
         GpuKernelArgumentInfo(common_tensor_type),
-        comp_group.is_intermediate_tensor(_lhs),
         "lhs");
 
     vtable.declare_variable(
+        comp_group,
         _rhs,
         GpuKernelArgumentInfo(common_tensor_type),
-        comp_group.is_intermediate_tensor(_rhs),
         "rhs");
 
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(common_tensor_type),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
 TagLUT ClTemplateElementwiseBinary::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
     TagLUT             lut{};
-    const ITensorInfo *accumulator = _lhs;
-    const ITensorInfo *operand     = _rhs;
 
     // Local build options
     lut["meta_kernel_id"] = id();
     lut["DATA_TYPE"]      = get_cl_type_from_data_type(_lhs->data_type());
     // Arguments and global shared variables
-    const bool is_root = (comp_group.get_root_component()->id() == this->id());
-    if(is_root)
-    {
-        lut["lhs"] = vtable.get_variable(_lhs);
-        lut["rhs"] = vtable.get_variable(_rhs);
-        lut["dst"] = vtable.get_variable(_dst);
-        lut["out"] = vtable.get_variable(comp_group.get_any_dst_tensor());
-    }
-    else
-    {
-        // Determine which tensor is the accumulator
-        if(comp_group.is_intermediate_tensor(_lhs))
-        {
-            accumulator = _lhs;
-            operand     = _rhs;
-        }
-        else if(comp_group.is_intermediate_tensor(_rhs))
-        {
-            accumulator = _rhs;
-            operand     = _lhs;
-        }
-        else
-        {
-            ARM_COMPUTE_ERROR("Invalid elementwise component linking");
-        }
-        lut["acc"]     = vtable.get_variable(accumulator);
-        lut["operand"] = vtable.get_variable(operand);
-    }
+
+    lut["lhs"] = vtable.get_variable(_lhs);
+    lut["rhs"] = vtable.get_variable(_rhs);
+    lut["dst"] = vtable.get_variable(_dst);
+    lut["arg_dst"] = vtable.get_variable(comp_group.get_any_dst_tensor());
+
     switch(_attributes.operation())
     {
         case Attributes::ElementwiseOp::ADD:
@@ -224,22 +200,65 @@
         default:
             ARM_COMPUTE_ERROR("Arithmetic Operation not supported");
     }
-    ARM_COMPUTE_ERROR_ON_MSG(detail::have_different_dimensions(accumulator->tensor_shape(), _dst->tensor_shape(), 0), "Only the operand can be broadcast to match the accumulator's shape");
-    const bool is_broadcast = (operand->tensor_shape() != _dst->tensor_shape());
+
+    ARM_COMPUTE_ERROR_ON(
+        comp_group.is_intermediate_tensor(_lhs) &&
+        detail::have_different_dimensions(_lhs->tensor_shape(), _dst->tensor_shape(), 0));
+    ARM_COMPUTE_ERROR_ON(
+        comp_group.is_intermediate_tensor(_rhs) &&
+        detail::have_different_dimensions(_rhs->tensor_shape(), _dst->tensor_shape(), 0));
 
     // Set broadcast parameters
     // PRE: All tensors are broadcast-compatible
-    if(is_broadcast)
+    if(_lhs->tensor_shape() != _dst->tensor_shape())
     {
+        const auto is_broadcast_x = _lhs->dimension(0) == 1U && _dst->dimension(0) != 1U;
+        const auto is_broadcast_y = _lhs->dimension(1) == 1U && _dst->dimension(1) != 1U;
+        const auto is_broadcast_z = _lhs->dimension(2) == 1U && _dst->dimension(2) != 1U;
+
         // Note that n0 maps to input tensor dimension 0, m0 maps to input dimensions 1 and 2 because of our collapse strategy
-        if(operand->dimension(0) == 1U && operand->dimension(1) == 1U && operand->dimension(2) == 1U) // Broadcast in X, Y, Z: collapsed rhs win [M0xN0] = [1x1]
+        if(is_broadcast_x && is_broadcast_y && is_broadcast_z) // Broadcast in X, Y, Z: collapsed lhs win [M0xN0] = [1x1]
+        {
+            lut["lhs_m0"]          = "1";
+            lut["lhs_n0"]          = "1";
+            lut["lhs_start_ind_1"] = "0";
+            lut["lhs_start_ind_0"] = "0";
+        }
+        else if(is_broadcast_y && is_broadcast_z) // Broadcast in Y and Z: collapsed lhs win [M0xN0] = [1xN]
+        {
+            lut["lhs_m0"]          = "1";
+            lut["lhs_n0"]          = "N0";
+            lut["lhs_start_ind_1"] = "0";
+            lut["lhs_start_ind_0"] = "g_ind_0";
+        }
+        else
+        {
+            ARM_COMPUTE_ERROR("Only support lhs broadcasting in all X, Y, Z dimensions, or just in Y and Z dimensions");
+        }
+    }
+    else
+    {
+        lut["lhs_m0"]          = "M0";
+        lut["lhs_n0"]          = "N0";
+        lut["lhs_start_ind_1"] = "g_ind_1";
+        lut["lhs_start_ind_0"] = "g_ind_0";
+    }
+
+    if(_rhs->tensor_shape() != _dst->tensor_shape())
+    {
+        const auto is_broadcast_x = _rhs->dimension(0) == 1U && _dst->dimension(0) != 1U;
+        const auto is_broadcast_y = _rhs->dimension(1) == 1U && _dst->dimension(1) != 1U;
+        const auto is_broadcast_z = _rhs->dimension(2) == 1U && _dst->dimension(2) != 1U;
+
+        // Note that n0 maps to input tensor dimension 0, m0 maps to input dimensions 1 and 2 because of our collapse strategy
+        if(is_broadcast_x && is_broadcast_y && is_broadcast_z) // Broadcast in X, Y, Z: collapsed rhs win [M0xN0] = [1x1]
         {
             lut["rhs_m0"]          = "1";
             lut["rhs_n0"]          = "1";
             lut["rhs_start_ind_1"] = "0";
             lut["rhs_start_ind_0"] = "0";
         }
-        else if(operand->dimension(1) == 1U && operand->dimension(2) == 1U) // Broadcast in Y and Z: collapsed rhs win [M0xN0] = [1xN]
+        else if(is_broadcast_y && is_broadcast_z) // Broadcast in Y and Z: collapsed rhs win [M0xN0] = [1xN]
         {
             lut["rhs_m0"]          = "1";
             lut["rhs_n0"]          = "N0";
@@ -258,6 +277,7 @@
         lut["rhs_start_ind_1"] = "g_ind_1";
         lut["rhs_start_ind_0"] = "g_ind_0";
     }
+
     return lut;
 }
 
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp
index 05bdd27..8f1ed95 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.cpp
@@ -190,21 +190,21 @@
 void ClTemplateLogits1DMaxShiftExpSum::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
     vtable.declare_variable(
+        comp_group,
         _src,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_src),
         "src");
 
     vtable.declare_variable(
+        comp_group,
         _sum,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_sum),
         "sum");
 
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
@@ -274,4 +274,4 @@
 
 } // namespace dynamic_fusion
 } // namespace experimental
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp
index a2c04d9..bcb6492 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateResize.cpp
@@ -54,7 +54,6 @@
 
     std::string code = R"_(
 //------------------ START KERNEL {{meta_kernel_id}} ---------------------
-TILE({{DST_DATA_TYPE}}, 1, N0, {{dst}});
 TILE(uint, 1, 1, g_dst_indirect_y);
 {
     const int yo = g_ind_2 % {{arg_dst}}_h;
@@ -180,15 +179,15 @@
 void ClTemplateResize::declare_variables(GpuKernelVariableTable &vtable, const IGpuTemplateComponentWriter::ComponentGroup &comp_group) const
 {
     vtable.declare_variable(
+        comp_group,
         _src,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_src),
         "src");
 
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp
index ef4f2f2..217214c 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateStore.cpp
@@ -62,14 +62,14 @@
 void ClTemplateStore::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
 {
     vtable.declare_variable(
+        comp_group,
         _src,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_src),
         "src");
     vtable.declare_variable(
+        comp_group,
         _dst,
         GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_4D_t_Buffer),
-        comp_group.is_intermediate_tensor(_dst),
         "dst");
 }
 
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
index eed481f..2ab6316 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateWriter.cpp
@@ -191,6 +191,26 @@
     code += write_global_section();
     code += "    //------------------ END KERNEL_BUILDER_COORDINATE ---------------------\n";
 
+    {
+        const auto tiles = _components.get_tiles();
+        std::stringstream tiles_ss;
+
+        tiles_ss << "    //------------------ START TILE DECLARATION ---------------------\n";
+
+        for(auto tile : tiles)
+        {
+            const auto var = _vtable.get_variable(tile);
+            const auto data_type = get_cl_type_from_data_type(tile->data_type());
+            const auto var_name = var.uniq_name;
+
+            tiles_ss << "    TILE(" << data_type << ", M0, N0, " << var_name << ");\n";
+        }
+
+        tiles_ss << "    //------------------ END TILE DECLARATION ---------------------\n";
+
+        code += tiles_ss.str();
+    }
+
     for(const auto &component_code : component_codes)
     {
         code += component_code;