//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <Layer.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <backendsCommon/MemImportWorkload.hpp>
#include <backendsCommon/MakeWorkloadHelper.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <algorithm>
#include "TosaRefWorkloadFactory.hpp"
#include "TosaRefBackendId.hpp"
#include "workloads/TosaRefWorkloads.hpp"
#include "TosaRefTensorHandle.hpp"
#include "TosaRefWorkloadFactory.hpp"
namespace armnn
{
namespace
{
static const BackendId s_Id{TosaRefBackendId()};
}
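
// Instantiates either the Float32 or the Uint8 specialisation of a workload,
// selected by the data type recorded in the WorkloadInfo; every other data
// type maps to NullWorkload and therefore produces no workload.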
template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
std::unique_ptr<IWorkload> TosaRefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
const WorkloadInfo& info) const
{
return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload>
(descriptor, info);
}
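
// Returns true if any input or output tensor described by the WorkloadInfo
// uses the given data type.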
template <DataType ArmnnType>
bool IsDataType(const WorkloadInfo& info)
{
auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
if (it != std::end(info.m_InputTensorInfos))
{
return true;
}
it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
if (it != std::end(info.m_OutputTensorInfos))
{
return true;
}
return false;
}
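
// The factory can share an existing TosaRef memory manager, or create its own
// when default constructed.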
TosaRefWorkloadFactory::TosaRefWorkloadFactory(const std::shared_ptr<TosaRefMemoryManager>& memoryManager)
: m_MemoryManager(memoryManager)
{
}

TosaRefWorkloadFactory::TosaRefWorkloadFactory()
: m_MemoryManager(new TosaRefMemoryManager())
{
}
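
// Illustrative usage (a minimal sketch, not exercised by this file; the shape
// and data type below are example values): the runtime normally owns the
// factory, but it can also be constructed directly to create backend-specific
// tensor handles, e.g.
//
//     TosaRefWorkloadFactory factory;
//     TensorInfo info({1, 2, 2, 1}, DataType::Float32);
//     std::unique_ptr<ITensorHandle> handle = factory.CreateTensorHandle(info, true);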
const BackendId& TosaRefWorkloadFactory::GetBackendId() const
{
return s_Id;
}
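
// Layer support queries are routed through the common IWorkloadFactory helper,
// which consults the layer support object registered for this backend id.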
bool TosaRefWorkloadFactory::IsLayerSupported(const Layer& layer,
Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
}

bool TosaRefWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
Optional<DataType> dataType,
std::string& outReasonIfUnsupported,
const ModelOptions& modelOptions)
{
return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
}
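
// Tensor handle creation: memory-managed handles defer allocation to the
// backend's memory manager; unmanaged handles are created with
// MemorySource::Malloc import flags instead.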
std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
const bool isMemoryManaged) const
{
if (isMemoryManaged)
{
return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
}
else
{
return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
}
}

std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
DataLayout dataLayout,
const bool isMemoryManaged) const
{
// For TosaRef it is okay to make the TensorHandle memory managed as it can also store a pointer
// to unmanaged memory. This also ensures memory alignment.
IgnoreUnused(dataLayout);
if (isMemoryManaged)
{
return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
}
else
{
return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
}
}
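
// During optimisation the TosaRef backend substitutes the subgraphs it accepts
// with PreCompiled layers, so that is the only layer type this factory needs
// to turn into a workload; everything else returns nullptr.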
std::unique_ptr<IWorkload> TosaRefWorkloadFactory::CreateWorkload(LayerType type,
const QueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
switch(type)
{
case LayerType::PreCompiled:
{
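// The descriptor carries the pre-compiled object produced when the subgraph
// was optimised; the workload executes it on the TOSA Reference Model.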
auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
return std::make_unique<TosaRefPreCompiledWorkload>(*precompiledQueueDescriptor, info);
}
default:
return nullptr;
}
}

} // namespace armnn