blob: 9e1b4d897bb7b438c34c7f4cd54da458df5d753f [file] [log] [blame]
/*
* Copyright (c) 2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#if defined(ENABLE_EXPERIMENTAL_DYNAMIC_FUSION)
#include "src/gpu/cl/kernels/experimental/dynamic_fusion/ClCompositeKernel.h"
#include "src/core/utils/helpers/float_ops.h"
#include "src/gpu/cl/kernels/ClElementwiseKernel.h"
#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h"
#include "tests/CL/CLAccessor.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ConvolutionLayer.h"
#include "tests/validation/reference/ElementwiseOperations.h"
#include "tests/validation/reference/GEMM.h"
#include "tests/validation/reference/Permute.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/AccessWindowStatic.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include <chrono>
using namespace arm_compute::experimental::dynamic_fusion;
namespace arm_compute
{
namespace test
{
namespace validation
{
namespace
{
/** Macros which measure the wall-clock time of a code region and record it, in microseconds, into the map
 *  @p measurement_map under the key "\"<clock_name>\"" (the clock name wrapped in double quotes).
 *
 *  - TICK(clock_name) declares a local time point named clock_name##_tick, so it may be used at most once per
 *    clock_name in a scope and must precede the matching TOCK / TOCK_AVG.
 *  - TOCK(clock_name, measurement_map) records the time elapsed since the matching TICK.
 *  - TOCK_AVG(clock_name, measurement_map, num_iterations) records the elapsed time divided by @p num_iterations.
 *
 *  All std::chrono names are fully qualified so the macros do not depend on using-declarations at the
 *  expansion site.
 */
#define TICK(clock_name) \
    auto clock_name##_tick = std::chrono::high_resolution_clock::now();
#define TOCK(clock_name, measurement_map)                                    \
    auto clock_name##_tock = std::chrono::high_resolution_clock::now(); \
    (measurement_map)["\"" #clock_name "\""] = std::chrono::duration_cast<std::chrono::microseconds>(clock_name##_tock - clock_name##_tick);
#define TOCK_AVG(clock_name, measurement_map, num_iterations)                \
    auto clock_name##_tock = std::chrono::high_resolution_clock::now(); \
    (measurement_map)["\"" #clock_name "\""] = std::chrono::duration_cast<std::chrono::microseconds>((clock_name##_tock - clock_name##_tick) / (num_iterations));
template <typename T, typename U>
void fill(U &&tensor, int seed)
{
static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
DistributionType distribution{ T(-1.0f), T(1.0f) };
library->fill(tensor, distribution, seed);
// Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
library->fill_borders_with_garbage(tensor, distribution_inf, seed);
}
} // namespace
TEST_SUITE(CL)
TEST_SUITE(UNIT)
TEST_SUITE(DYNAMIC_FUSION)
TEST_SUITE(ClCompositeKernel)
TEST_SUITE(Validate)
TEST_CASE(MoveNet_SubGraph_1_Gemm, framework::DatasetMode::ALL)
{
    /* Computation:
     * out = add(addend, gemm_native(lhs, rhs, bias)) (non-broadcast)
     *
     * No bias tensor is bound: the GEMM component's bias argument stays g_arg_placeholder, and the
     * reference disables the bias term by passing beta = 0.
     */
    const auto data_type   = DataType::F32;
    const auto m           = 5U;
    const auto n           = 4U;
    const auto k           = 3U;
    const auto t_lhs_shape = TensorShape(k, m);
    const auto t_rhs_shape = TensorShape(n, k);
    const auto t_dst_shape = TensorShape(n, m);
    auto       t_lhs_info  = TensorInfo(t_lhs_shape, 1, data_type);
    auto       t_rhs_info  = TensorInfo(t_rhs_shape, 1, data_type);
    auto       t_dst_info  = TensorInfo(t_dst_shape, 1, data_type);

    const ClTensorDescriptor t_lhs_desc{ &t_lhs_info };
    const ClTensorDescriptor t_rhs_desc{ &t_rhs_info };
    const ClTensorDescriptor t_addend_desc{ &t_dst_info };
    const ClTensorDescriptor t_dst_desc{ &t_dst_info };

    ClKernelBlueprint bp;
    ArgumentID        tid_lhs{ g_arg_placeholder };
    ArgumentID        tid_rhs{ g_arg_placeholder };
    ArgumentID        tid_l0_bias{ g_arg_placeholder }; // Never registered: the GEMM has no bias input
    ArgumentID        tid_l1_addend{ g_arg_placeholder };
    ArgumentID        tid_dst{ g_arg_placeholder };

    // Each blueprint / build / tune call reports failure through its Status. Abort on the first error rather
    // than configuring and running a kernel generated from an incomplete blueprint.
    auto st = add_tensor_argument(bp, t_lhs_desc, tid_lhs);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_argument(bp, t_rhs_desc, tid_rhs);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_argument(bp, t_addend_desc, tid_l1_addend);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_argument(bp, t_dst_desc, tid_dst);
    ARM_COMPUTE_ASSERT(bool(st));

    const auto                 common_kernel_desc = ClKernelComponentDescriptor{};
    const GemmNativeDescriptor gemm_native_desc{ 1.0, 1.0, m, n, k };
    const GEMMKernelInfo       gemm_info{ m, n, k, 0, false, false, false, false, ActivationLayerInfo{}, 1, 1, gemm_native_desc.lhs_info, gemm_native_desc.rhs_info, 0, 0 };
    const EltwiseAddDescriptor eltwise_add_desc{ ConvertPolicy::WRAP };
    const TileDescriptor       store_tile_info{ Size2D(gemm_info.rhs_info.n0, gemm_info.lhs_info.m0), Size2D(gemm_info.n, gemm_info.m), ClippingStrategy::TOP_LEFT };

    ArgumentID tid_acc{ g_arg_placeholder };
    st = add_tensor_intermed(bp, tid_acc);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_kcomp_gemm_native(bp, common_kernel_desc, gemm_native_desc, tid_lhs, tid_rhs, tid_l0_bias, tid_acc);
    ARM_COMPUTE_ASSERT(bool(st));
    // Use the same descriptor as the reference below, so both apply the same convert policy.
    st = add_kcomp_eltwise_add(bp, common_kernel_desc, eltwise_add_desc, tid_l1_addend, tid_acc, tid_acc);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_kcomp_store(bp, common_kernel_desc, tid_acc, tid_dst, StoreType::StoreBlockBoundaryAware);
    ARM_COMPUTE_ASSERT(bool(st));

    ClKernelCode cl_code;
    st = set_tile_info(bp, store_tile_info);
    ARM_COMPUTE_ASSERT(bool(st));
    st = build(cl_code, ClCodeBuilderContext{ GpuInfo{ GPUTarget::G71 } }, bp);
    ARM_COMPUTE_ASSERT(bool(st));

    ClExecutionDescriptor exec_desc{};
    st = tune_static(exec_desc, cl_code);
    ARM_COMPUTE_ASSERT(bool(st));

    CLScheduler::get().default_reinit();
    ClCompositeKernel kernel;
    kernel.configure(CLKernelLibrary::get().get_compile_context(), cl_code);

    // Construct tensors
    CLTensor t_lhs{};
    CLTensor t_rhs{};
    CLTensor t_l1_addend{};
    CLTensor t_dst{};
    // Init tensors
    {
        t_lhs.allocator()->init(t_lhs_info);
        t_rhs.allocator()->init(t_rhs_info);
        t_l1_addend.allocator()->init(t_dst_info);
        t_dst.allocator()->init(t_dst_info);
    }
    // "Pack" tensors
    TensorBinding tensors({ { tid_lhs, &t_lhs },
        { tid_rhs, &t_rhs },
        { tid_l1_addend, &t_l1_addend },
        { tid_dst, &t_dst }
    });
    // Allocate and fill tensors
    {
        t_lhs.allocator()->allocate();
        t_rhs.allocator()->allocate();
        t_l1_addend.allocator()->allocate();
        t_dst.allocator()->allocate();
        fill<float>(CLAccessor(t_lhs), 0);
        fill<float>(CLAccessor(t_rhs), 1);
        fill<float>(CLAccessor(t_l1_addend), 2);
    }

    CLScheduler::get().enqueue_op(kernel, tensors, exec_desc, true);

    // Create reference (same seeds as the device tensors so both see identical data)
    SimpleTensor<float> ref_t_lhs{ t_lhs_shape, data_type, 1 };
    SimpleTensor<float> ref_t_rhs{ t_rhs_shape, data_type, 1 };
    SimpleTensor<float> ref_t_bias_placeholder{ t_dst_shape, data_type, 1 };
    SimpleTensor<float> ref_t_l1_addend{ t_dst_shape, data_type, 1 };

    // Fill reference
    fill<float>(ref_t_lhs, 0);
    fill<float>(ref_t_rhs, 1);
    fill<float>(ref_t_l1_addend, 2);
    const auto ref_t_dst = reference::arithmetic_operation(
                               ArithmeticOperation::ADD,
                               ref_t_l1_addend,
                               reference::gemm(ref_t_lhs, ref_t_rhs, ref_t_bias_placeholder, gemm_native_desc.alpha, 0.f /* To disable bias */),
                               data_type,
                               eltwise_add_desc.convert_policy);

    RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
    validate(CLAccessor(t_dst), ref_t_dst, tolerance_f32);
}
TEST_CASE(MoveNet_SubGraph_1_DirectConv2d, framework::DatasetMode::ALL)
{
    /* Computation:
     * out = add(addend, direct_conv2d(lhs, rhs, bias)) (non-broadcast)
     *
     * Device tensors are NHWC. The reference is computed on copies of the same data permuted with
     * PermutationVector(1, 2, 0), and compared against the NHWC device output.
     */
    ClCompositeKernel     kernel{};
    ClKernelBlueprint     bp{};
    ClKernelCode          cl_code{};
    ClExecutionDescriptor exec_desc{};
    Status                st{};

    const auto data_type = DataType::F32;
    const auto conv_info = PadStrideInfo(1U, 1U, 1U, 1U);
    const auto width     = 7U;
    const auto height    = 6U;
    const auto IFM       = 5U;
    const auto OFM       = 4U;
    const auto kernel_sz = 3U;

    const auto src_shape = TensorShape(IFM, width, height);
    const auto wei_shape = TensorShape(IFM, kernel_sz, kernel_sz, OFM);
    const auto bia_shape = TensorShape(OFM);
    const auto dst_shape = TensorShape(OFM, width, height);

    auto src_info = TensorInfo(src_shape, 1, data_type, DataLayout::NHWC);
    auto wei_info = TensorInfo(wei_shape, 1, data_type, DataLayout::NHWC);
    auto bia_info = TensorInfo(bia_shape, 1, data_type, DataLayout::NHWC);
    auto dst_info = TensorInfo(dst_shape, 1, data_type, DataLayout::NHWC);

    const auto src_desc    = ClTensorDescriptor(&src_info);
    const auto wei_desc    = ClTensorDescriptor(&wei_info);
    const auto bia_desc    = ClTensorDescriptor(&bia_info);
    const auto addend_desc = ClTensorDescriptor(&dst_info);
    const auto dst_desc    = ClTensorDescriptor(&dst_info);

    // Store tile size: n0 output channels by m0 rows
    const auto n0 = std::min(OFM, 4u);
    const auto m0 = (OFM > 16) ? ((data_type == DataType::F32) ? 2U : 4U) : 1U;

    const ClKernelComponentDescriptor common_kernel_desc{};
    const DirectConvolutionDescriptor direct_conv2d_desc{ conv_info };
    const EltwiseAddDescriptor        eltwise_add_desc{ ConvertPolicy::WRAP };
    const TileDescriptor              store_tile_info{ Size2D(n0, m0), Size2D(width, height), ClippingStrategy::TOP_LEFT };

    ArgumentID src_id{ g_arg_placeholder };
    ArgumentID wei_id{ g_arg_placeholder };
    ArgumentID bia_id{ g_arg_placeholder };
    ArgumentID acc_id{ g_arg_placeholder };
    ArgumentID addend_id{ g_arg_placeholder };
    ArgumentID dst_id{ g_arg_placeholder };

    // Each blueprint / build / tune call reports failure through its Status. Abort on the first error rather
    // than configuring and running a kernel generated from an incomplete blueprint.
    st = add_tensor_argument(bp, src_desc, src_id);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_argument(bp, wei_desc, wei_id);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_argument(bp, bia_desc, bia_id);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_intermed(bp, acc_id);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_argument(bp, addend_desc, addend_id);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_tensor_argument(bp, dst_desc, dst_id);
    ARM_COMPUTE_ASSERT(bool(st));

    st = add_kcomp_direct_conv(bp, common_kernel_desc, direct_conv2d_desc, src_id, wei_id, bia_id, acc_id);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_kcomp_eltwise_add(bp, common_kernel_desc, eltwise_add_desc, addend_id, acc_id, acc_id);
    ARM_COMPUTE_ASSERT(bool(st));
    st = add_kcomp_store(bp, common_kernel_desc, acc_id, dst_id, StoreType::TStoreIndirectWidthSelect);
    ARM_COMPUTE_ASSERT(bool(st));

    exec_desc.skip_sliding_window = true;

    st = set_tile_info(bp, store_tile_info);
    ARM_COMPUTE_ASSERT(bool(st));
    st = build(cl_code, ClCodeBuilderContext{ GpuInfo{ GPUTarget::G71 } }, bp);
    ARM_COMPUTE_ASSERT(bool(st));
    st = tune_static(exec_desc, cl_code);
    ARM_COMPUTE_ASSERT(bool(st));

    CLScheduler::get().default_reinit();
    kernel.configure(CLKernelLibrary::get().get_compile_context(), cl_code);

    // Construct tensors
    CLTensor src{};
    CLTensor wei{};
    CLTensor bia{};
    CLTensor addend{};
    CLTensor dst{};

    // Init tensors
    src.allocator()->init(src_info);
    wei.allocator()->init(wei_info);
    bia.allocator()->init(bia_info);
    addend.allocator()->init(dst_info);
    dst.allocator()->init(dst_info);

    // "Pack" tensors
    TensorBinding tensors({ { src_id, &src },
        { wei_id, &wei },
        { bia_id, &bia },
        { addend_id, &addend },
        { dst_id, &dst }
    });

    // Allocate and fill tensors
    src.allocator()->allocate();
    wei.allocator()->allocate();
    bia.allocator()->allocate();
    addend.allocator()->allocate();
    dst.allocator()->allocate();
    fill<float>(CLAccessor(src), 0);
    fill<float>(CLAccessor(wei), 1);
    fill<float>(CLAccessor(bia), 2);
    fill<float>(CLAccessor(addend), 3);

    CLScheduler::get().enqueue_op(kernel, tensors, exec_desc, true);

    // Create reference (same seeds as the device tensors so both see identical data)
    SimpleTensor<float> ref_src_nhwc{ src_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC };
    SimpleTensor<float> ref_wei_nhwc{ wei_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC };
    SimpleTensor<float> ref_bia_nhwc{ bia_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC };
    SimpleTensor<float> ref_addend_nhwc{ dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC };

    // Fill reference
    fill<float>(ref_src_nhwc, 0);
    fill<float>(ref_wei_nhwc, 1);
    fill<float>(ref_bia_nhwc, 2);
    fill<float>(ref_addend_nhwc, 3);

    auto ref_src    = reference::permute(ref_src_nhwc, PermutationVector(1U, 2U, 0U));
    auto ref_wei    = reference::permute(ref_wei_nhwc, PermutationVector(1U, 2U, 0U));
    auto ref_bia    = reference::permute(ref_bia_nhwc, PermutationVector(1U, 2U, 0U));
    auto ref_addend = reference::permute(ref_addend_nhwc, PermutationVector(1U, 2U, 0U));

    TensorShape dst_shape_nchw{ dst_shape };
    permute(dst_shape_nchw, PermutationVector(1U, 2U, 0U));

    const auto ref_dst = reference::arithmetic_operation(
                             ArithmeticOperation::ADD,
                             ref_addend,
                             reference::convolution_layer<float>(ref_src, ref_wei, ref_bia, dst_shape_nchw, conv_info),
                             data_type,
                             eltwise_add_desc.convert_policy);

    RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
    validate(CLAccessor(dst), ref_dst, tolerance_f32);
}
TEST_SUITE_END() // Validate
TEST_SUITE(Benchmark)
TEST_CASE(MoveNet_SubGraph_1_Gemm, framework::DatasetMode::ALL)
{
    using std::chrono::duration_cast;
    using std::chrono::microseconds;
    const int                                       num_iterations = 200;
    std::map<std::string, std::chrono::microseconds> measurements;
    /* Computation:
     * out = add(addend, gemm_native(lhs, rhs, bias))
     *
     * Each condition below computes the same result. Each records its startup time (first run, including
     * setup) and its average per-iteration latency over num_iterations runs. All three outputs are validated
     * against one shared reference at the end.
     */
    const auto data_type     = DataType::F32;
    const auto m             = 12U * 12U;
    const auto n             = 64U;
    const auto k             = 384U;
    const auto t_lhs_shape   = TensorShape(k, m);
    const auto t_rhs_shape   = TensorShape(n, k);
    const auto t_dst_shape   = TensorShape(n, m);
    auto       t_lhs_info    = TensorInfo(t_lhs_shape, 1, data_type);
    auto       t_rhs_info    = TensorInfo(t_rhs_shape, 1, data_type);
    auto       t_l0_dst_info = TensorInfo(t_dst_shape, 1, data_type); // Intermediate tensor for cond3
    auto       t_l1_rhs_info = TensorInfo(t_dst_shape, 1, data_type);
    auto       t_dst_info    = TensorInfo(t_dst_shape, 1, data_type);

    const auto                 common_kernel_desc = ClKernelComponentDescriptor{};
    const GemmNativeDescriptor gemm_native_desc{ 1.0, 0.0, m, n, k };
    const GEMMKernelInfo       gemm_info{ m, n, k, 0, false, false, false, false, ActivationLayerInfo{}, 1, 1, gemm_native_desc.lhs_info, gemm_native_desc.rhs_info, 0, 0 };
    const EltwiseAddDescriptor eltwise_add_desc{ ConvertPolicy::WRAP };
    const TileDescriptor       store_tile_info{ Size2D(gemm_info.rhs_info.n0, gemm_info.lhs_info.m0), Size2D(gemm_info.n, gemm_info.m), ClippingStrategy::TOP_LEFT };

    // Create reference
    SimpleTensor<float> ref_t_lhs{ t_lhs_shape, data_type, 1 };
    SimpleTensor<float> ref_t_rhs{ t_rhs_shape, data_type, 1 };
    SimpleTensor<float> ref_t_bias_placeholder{ t_dst_shape, data_type, 1 };
    SimpleTensor<float> ref_t_l1_addend{ t_dst_shape, data_type, 1 };

    // Fill reference
    fill<float>(ref_t_lhs, 0);
    fill<float>(ref_t_rhs, 1);
    fill<float>(ref_t_l1_addend, 2);
    const auto ref_t_dst = reference::arithmetic_operation(
                               ArithmeticOperation::ADD,
                               ref_t_l1_addend,
                               reference::gemm(ref_t_lhs, ref_t_rhs, ref_t_bias_placeholder, gemm_native_desc.alpha, 0.f /* To disable bias */),
                               data_type,
                               eltwise_add_desc.convert_policy);

    CLScheduler::get().default_reinit();

    /* Condition 0: Dynamic Fused Kernel */
    CLTensor cond0_t_dst{};
    {
        TICK(cond0_0_startup_time);

        ClKernelBlueprint bp;
        ArgumentID        tid_lhs{ g_arg_placeholder };
        ArgumentID        tid_rhs{ g_arg_placeholder };
        ArgumentID        tid_l0_bias{ g_arg_placeholder }; // Never registered: the GEMM has no bias input
        ArgumentID        tid_l1_addend{ g_arg_placeholder };
        ArgumentID        tid_dst{ g_arg_placeholder };

        const ClTensorDescriptor t_lhs_desc{ &t_lhs_info };
        const ClTensorDescriptor t_rhs_desc{ &t_rhs_info };
        const ClTensorDescriptor t_addend_desc{ &t_dst_info };
        const ClTensorDescriptor t_dst_desc{ &t_dst_info };

        ClKernelCode cl_code;
        TICK(cond0_build_time)
        // Abort on the first failing blueprint / build call instead of timing and running a broken kernel.
        auto st = add_tensor_argument(bp, t_lhs_desc, tid_lhs);
        ARM_COMPUTE_ASSERT(bool(st));
        st = add_tensor_argument(bp, t_rhs_desc, tid_rhs);
        ARM_COMPUTE_ASSERT(bool(st));
        st = add_tensor_argument(bp, t_addend_desc, tid_l1_addend);
        ARM_COMPUTE_ASSERT(bool(st));
        st = add_tensor_argument(bp, t_dst_desc, tid_dst);
        ARM_COMPUTE_ASSERT(bool(st));

        ArgumentID tid_acc{ g_arg_placeholder };
        st = add_tensor_intermed(bp, tid_acc);
        ARM_COMPUTE_ASSERT(bool(st));
        st = add_kcomp_gemm_native(bp, common_kernel_desc, gemm_native_desc, tid_lhs, tid_rhs, tid_l0_bias, tid_acc);
        ARM_COMPUTE_ASSERT(bool(st));
        // Same descriptor as the reference and the static conditions, so all apply the same convert policy
        st = add_kcomp_eltwise_add(bp, common_kernel_desc, eltwise_add_desc, tid_l1_addend, tid_acc, tid_acc);
        ARM_COMPUTE_ASSERT(bool(st));
        st = add_kcomp_store(bp, common_kernel_desc, tid_acc, tid_dst, StoreType::StoreBlockBoundaryAware);
        ARM_COMPUTE_ASSERT(bool(st));

        st = set_tile_info(bp, store_tile_info);
        ARM_COMPUTE_ASSERT(bool(st));
        st = build(cl_code, ClCodeBuilderContext{ GpuInfo{ GPUTarget::G71 } }, bp);
        ARM_COMPUTE_ASSERT(bool(st));
        TOCK(cond0_build_time, measurements)

        TICK(cond0_tune_time)
        ClExecutionDescriptor exec_desc{};
        st = tune_static(exec_desc, cl_code);
        ARM_COMPUTE_ASSERT(bool(st));
        TOCK(cond0_tune_time, measurements)

        TICK(cond0_configure_time)
        ClCompositeKernel kernel;
        kernel.configure(CLKernelLibrary::get().get_compile_context(), cl_code);
        TOCK(cond0_configure_time, measurements)

        // Construct tensors
        CLTensor t_lhs{};
        CLTensor t_rhs{};
        CLTensor t_l1_addend{};

        // Init tensors
        {
            t_lhs.allocator()->init(t_lhs_info);
            t_rhs.allocator()->init(t_rhs_info);
            t_l1_addend.allocator()->init(t_dst_info);
            cond0_t_dst.allocator()->init(t_dst_info);
        }
        // Allocate tensors
        {
            t_lhs.allocator()->allocate();
            t_rhs.allocator()->allocate();
            t_l1_addend.allocator()->allocate();
            cond0_t_dst.allocator()->allocate();
            fill<float>(CLAccessor(t_lhs), 0);
            fill<float>(CLAccessor(t_rhs), 1);
            fill<float>(CLAccessor(t_l1_addend), 2);
        }

        // "Pack" tensors
        TensorBinding tensors({ { tid_lhs, &t_lhs }, { tid_rhs, &t_rhs }, { tid_l1_addend, &t_l1_addend }, { tid_dst, &cond0_t_dst } });

        CLScheduler::get().enqueue_op(kernel, tensors, exec_desc, true);
        CLScheduler::get().sync();
        TOCK(cond0_0_startup_time, measurements)

        TICK(cond0_1_latency)
        for(int i = 0; i < num_iterations; ++i)
        {
            CLScheduler::get().enqueue_op(kernel, tensors, exec_desc, true);
        }
        CLScheduler::get().sync();
        TOCK_AVG(cond0_1_latency, measurements, num_iterations)
    }

    /* Condition 1: Dynamic Unfused Kernel */
    // NOTE(review): not implemented yet; nothing is measured or validated for this condition.

    /* Condition 2: Static Fused Kernel (current) */
    CLTensor cond2_t_dst{};
    {
        TICK(cond2_0_startup_time);
        arm_compute::opencl::kernels::ClGemmMatrixMultiplyNativeKernel l0_gemm_mm;

        TICK(cond2_configure_time);
        // The elementwise add is fused into the GEMM kernel as a post op
        experimental::PostOpList<ITensorInfo *> post_ops;
        post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(&t_dst_info, 1, eltwise_add_desc.convert_policy);
        GEMMKernelInfo cond2_gemm_info{ m, n, k, 0, false, false, false, false, ActivationLayerInfo{}, 1, 1, gemm_native_desc.lhs_info, gemm_native_desc.rhs_info, 0, 0, post_ops };
        l0_gemm_mm.configure(CLKernelLibrary::get().get_compile_context(), &t_lhs_info, &t_rhs_info, nullptr, &t_dst_info, gemm_native_desc.alpha, gemm_native_desc.beta, gemm_native_desc.lhs_info,
                             gemm_native_desc.rhs_info, cond2_gemm_info);
        TOCK(cond2_configure_time, measurements);

        // Construct tensors
        CLTensor t_lhs{};
        CLTensor t_rhs{};
        CLTensor t_l1_addend{};

        // Init tensors
        {
            t_lhs.allocator()->init(t_lhs_info);
            t_rhs.allocator()->init(t_rhs_info);
            t_l1_addend.allocator()->init(t_dst_info);
            cond2_t_dst.allocator()->init(t_dst_info);
        }
        // Allocate tensors
        {
            t_lhs.allocator()->allocate();
            t_rhs.allocator()->allocate();
            t_l1_addend.allocator()->allocate();
            cond2_t_dst.allocator()->allocate();
            fill<float>(CLAccessor(t_lhs), 0);
            fill<float>(CLAccessor(t_rhs), 1);
            fill<float>(CLAccessor(t_l1_addend), 2);
        }

        // "Pack" tensors
        ITensorPack tensors
        {
            { ACL_SRC_0, &t_lhs },
            { ACL_SRC_1, &t_rhs },
            { EXPERIMENTAL_ACL_POST_OP_ARG_FIRST, &t_l1_addend },
            { ACL_DST, &cond2_t_dst },
        };
        CLScheduler::get().enqueue_op(l0_gemm_mm, tensors, true);
        CLScheduler::get().sync();
        TOCK(cond2_0_startup_time, measurements);

        TICK(cond2_1_latency);
        for(int i = 0; i < num_iterations; ++i)
        {
            CLScheduler::get().enqueue_op(l0_gemm_mm, tensors, true);
        }
        CLScheduler::get().sync();
        TOCK_AVG(cond2_1_latency, measurements, num_iterations);
    }

    /* Condition 3: Static Unfused Kernel (current) */
    CLTensor cond3_t_dst{};
    {
        TICK(cond3_0_startup_time);
        arm_compute::opencl::kernels::ClGemmMatrixMultiplyNativeKernel l0_gemm_mm;
        arm_compute::opencl::kernels::ClSaturatedArithmeticKernel      l1_add;

        TICK(cond3_configure_time);
        GEMMKernelInfo cond3_gemm_info{ m, n, k, 0, false, false, false, false, ActivationLayerInfo{}, 1, 1, gemm_native_desc.lhs_info, gemm_native_desc.rhs_info, 0, 0 };
        l0_gemm_mm.configure(CLKernelLibrary::get().get_compile_context(), &t_lhs_info, &t_rhs_info, nullptr, &t_l0_dst_info, gemm_native_desc.alpha, gemm_native_desc.beta, gemm_native_desc.lhs_info,
                             gemm_native_desc.rhs_info, cond3_gemm_info);
        l1_add.configure(CLKernelLibrary::get().get_compile_context(), ArithmeticOperation::ADD, &t_l0_dst_info, &t_l1_rhs_info, &t_dst_info, eltwise_add_desc.convert_policy);
        TOCK(cond3_configure_time, measurements);

        // Construct tensors
        CLTensor t_lhs{};
        CLTensor t_rhs{};
        CLTensor t_l0_dst{};
        CLTensor t_l1_addend{};

        // Init tensors
        {
            t_lhs.allocator()->init(t_lhs_info);
            t_rhs.allocator()->init(t_rhs_info);
            t_l0_dst.allocator()->init(t_l0_dst_info);
            t_l1_addend.allocator()->init(t_dst_info);
            cond3_t_dst.allocator()->init(t_dst_info);
        }
        // Allocate tensors
        {
            t_lhs.allocator()->allocate();
            t_rhs.allocator()->allocate();
            t_l0_dst.allocator()->allocate();
            t_l1_addend.allocator()->allocate();
            cond3_t_dst.allocator()->allocate();
            fill<float>(CLAccessor(t_lhs), 0);
            fill<float>(CLAccessor(t_rhs), 1);
            fill<float>(CLAccessor(t_l1_addend), 2);
        }

        // "Pack" tensors: the GEMM output t_l0_dst feeds the separate add kernel
        ITensorPack tensors_l0
        {
            { ACL_SRC_0, &t_lhs },
            { ACL_SRC_1, &t_rhs },
            { ACL_DST, &t_l0_dst },
        };
        ITensorPack tensors_l1
        {
            { ACL_SRC_0, &t_l0_dst },
            { ACL_SRC_1, &t_l1_addend },
            { ACL_DST, &cond3_t_dst },
        };
        CLScheduler::get().enqueue_op(l0_gemm_mm, tensors_l0, true);
        CLScheduler::get().enqueue_op(l1_add, tensors_l1, true);
        CLScheduler::get().sync();
        TOCK(cond3_0_startup_time, measurements);

        TICK(cond3_1_latency);
        for(int i = 0; i < num_iterations; ++i)
        {
            CLScheduler::get().enqueue_op(l0_gemm_mm, tensors_l0, true);
            CLScheduler::get().enqueue_op(l1_add, tensors_l1, true);
        }
        CLScheduler::get().sync();
        TOCK_AVG(cond3_1_latency, measurements, num_iterations);
    }

    RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
    std::cout << "cond0 validation: " << std::endl;
    validate(CLAccessor(cond0_t_dst), ref_t_dst, tolerance_f32);
    std::cout << "cond2 validation: " << std::endl;
    validate(CLAccessor(cond2_t_dst), ref_t_dst, tolerance_f32);
    std::cout << "cond3 validation: " << std::endl;
    validate(CLAccessor(cond3_t_dst), ref_t_dst, tolerance_f32);

    /* Report */
    std::cout << "Performance comparison (gemm native + add)" << std::endl;
    std::cout << "cond0: dynamic fusion module" << std::endl;
    std::cout << "cond2: static fused with post ops" << std::endl;
    std::cout << "cond3: static unfused" << std::endl;
    // Bind by const reference: avoids copying each entry and no longer shadows the GEMM dimension `m`
    for(const auto &measurement : measurements)
    {
        std::cout << measurement.first << ": " << measurement.second.count() << "us" << std::endl;
    }
}
TEST_SUITE_END() // Benchmark
TEST_SUITE_END() // ClCompositeKernel
TEST_SUITE_END() // DYNAMIC_FUSION
TEST_SUITE_END() // UNIT
TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
#endif // defined(ENABLE_EXPERIMENTAL_DYNAMIC_FUSION)