blob: 16457820c4ef383d1f72d5974db10662dce4b2d0 [file] [log] [blame]
//
// Copyright © 2018-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include <armnn/TypesUtils.hpp>
#include "RefBaseWorkload.hpp"
namespace armnn
{
/// Reference-backend workload for the Debug layer, instantiated once per tensor DataType.
///
/// @tparam DataType  Tensor data type handled by this instantiation. It is also baked into the
///                   string returned by GetName(), e.g. "RefDebugFloat32Workload".
template <armnn::DataType DataType>
class RefDebugWorkload : public TypedWorkload<DebugQueueDescriptor, DataType>
{
public:
    /// Builds the workload from its queue descriptor and workload info.
    /// No debug callback is set at construction; m_Callback starts out as nullptr.
    RefDebugWorkload(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info)
        : TypedWorkload<DebugQueueDescriptor, DataType>(descriptor, info)
    {}

    /// @return "RefDebug" + GetDataTypeName(DataType) + "Workload".
    ///         The string is built once per template instantiation (function-local static)
    ///         and returned by const reference.
    const std::string& GetName() const override
    {
        static const std::string name = std::string("RefDebug") + GetDataTypeName(DataType) + "Workload";
        return name;
    }

    // m_Data is a member of a dependent base class. Unqualified lookup inside this template would
    // not find it without this using-declaration. Because it sits under `public:`, it is also
    // publicly accessible through this class.
    using TypedWorkload<DebugQueueDescriptor, DataType>::m_Data;

    // Also inherit the base-class constructors. An inherited constructor whose parameter list
    // matches the user-declared constructor above is hidden by that declaration.
    using TypedWorkload<DebugQueueDescriptor, DataType>::TypedWorkload;

    void Execute() const override;
    void ExecuteAsync(ExecutionData& executionData) override;

    /// Registers @p func as the debug callback.
    /// NOTE(review): defined out of line; presumably stores @p func in m_Callback. Confirm in the .cpp.
    void RegisterDebugCallback(const DebugCallbackFunction& func) override;

private:
    // NOTE(review): presumably the shared implementation behind Execute()/ExecuteAsync().
    // The definitions are not in this header. The parameter stays by value so the
    // declaration keeps matching the out-of-line definition.
    void Execute(std::vector<ITensorHandle*> inputs) const;

    // Default-initialized to nullptr (no callback) for every constructor, including the
    // inherited ones.
    DebugCallbackFunction m_Callback{nullptr};
};
// Concrete Debug workloads: one RefDebugWorkload instantiation per supported armnn::DataType.

// Floating-point tensor types.
using RefDebugBFloat16Workload = RefDebugWorkload<armnn::DataType::BFloat16>;
using RefDebugFloat16Workload  = RefDebugWorkload<armnn::DataType::Float16>;
using RefDebugFloat32Workload  = RefDebugWorkload<armnn::DataType::Float32>;

// Quantized tensor types.
using RefDebugQAsymmU8Workload = RefDebugWorkload<armnn::DataType::QAsymmU8>;
using RefDebugQAsymmS8Workload = RefDebugWorkload<armnn::DataType::QAsymmS8>;
using RefDebugQSymmS16Workload = RefDebugWorkload<armnn::DataType::QSymmS16>;
using RefDebugQSymmS8Workload  = RefDebugWorkload<armnn::DataType::QSymmS8>;

// Integer and boolean tensor types.
using RefDebugSigned32Workload = RefDebugWorkload<armnn::DataType::Signed32>;
using RefDebugSigned64Workload = RefDebugWorkload<armnn::DataType::Signed64>;
using RefDebugBooleanWorkload  = RefDebugWorkload<armnn::DataType::Boolean>;
} // namespace armnn