blob: b9817ba1ea12853397b5c8fc79b1f4b6aa91eccf [file] [log] [blame]
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefDetectionPostProcessWorkload.hpp"
7
8#include "Decoders.hpp"
9#include "DetectionPostProcess.hpp"
10#include "Profiling.hpp"
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
16RefDetectionPostProcessWorkload::RefDetectionPostProcessWorkload(
17 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info)
18 : BaseWorkload<DetectionPostProcessQueueDescriptor>(descriptor, info),
19 m_Anchors(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Anchors))) {}
20
21void RefDetectionPostProcessWorkload::Execute() const
22{
23 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefDetectionPostProcessWorkload_Execute");
24
25 const TensorInfo& boxEncodingsInfo = GetTensorInfo(m_Data.m_Inputs[0]);
26 const TensorInfo& scoresInfo = GetTensorInfo(m_Data.m_Inputs[1]);
Matthew Bentham4cefc412019-06-18 16:14:34 +010027 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010028
29 const TensorInfo& detectionBoxesInfo = GetTensorInfo(m_Data.m_Outputs[0]);
30 const TensorInfo& detectionClassesInfo = GetTensorInfo(m_Data.m_Outputs[1]);
31 const TensorInfo& detectionScoresInfo = GetTensorInfo(m_Data.m_Outputs[2]);
32 const TensorInfo& numDetectionsInfo = GetTensorInfo(m_Data.m_Outputs[3]);
33
34 auto boxEncodings = MakeDecoder<float>(boxEncodingsInfo, m_Data.m_Inputs[0]->Map());
35 auto scores = MakeDecoder<float>(scoresInfo, m_Data.m_Inputs[1]->Map());
36 auto anchors = MakeDecoder<float>(anchorsInfo, m_Anchors->Map(false));
37
38 float* detectionBoxes = GetOutputTensorData<float>(0, m_Data);
39 float* detectionClasses = GetOutputTensorData<float>(1, m_Data);
40 float* detectionScores = GetOutputTensorData<float>(2, m_Data);
41 float* numDetections = GetOutputTensorData<float>(3, m_Data);
42
43 DetectionPostProcess(boxEncodingsInfo, scoresInfo, anchorsInfo,
44 detectionBoxesInfo, detectionClassesInfo,
45 detectionScoresInfo, numDetectionsInfo, m_Data.m_Parameters,
46 *boxEncodings, *scores, *anchors, detectionBoxes,
47 detectionClasses, detectionScores, numDetections);
48}
49
50} //namespace armnn