blob: 59afc7022aec76a55cbad17fe517898842197c0d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#pragma once
7
telsoa01c577f2c2018-08-31 09:22:23 +01008#include <arm_compute/core/Types.h>
David Beck711fa312018-09-24 10:46:38 +01009#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010#include <backends/CpuTensorHandle.hpp>
11#include <backends/NeonTensorHandle.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010012#include <backends/NeonWorkloadUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013#include <backends/Workload.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010014#include <Half.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/cast.hpp>
17
18namespace armnn
19{
20
telsoa01c577f2c2018-08-31 09:22:23 +010021// Base class template providing an implementation of the Constant layer common to all data types.
22template <armnn::DataType... DataFormats>
23class NeonBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataFormats...>
telsoa014fcda012018-03-09 14:13:49 +000024{
25public:
26 NeonBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
telsoa01c577f2c2018-08-31 09:22:23 +010027 : TypedWorkload<ConstantQueueDescriptor, DataFormats...>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000028 , m_RanOnce(false)
29 {
30 }
31
32 virtual void Execute() const override
33 {
34 using namespace armcomputetensorutils;
35
36 // The intermediate tensor held by the corresponding layer output handler can be initialised with the
37 // given data on the first inference, then reused for subsequent inferences.
38 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
39 // may not have been configured at the time.
40 if (!m_RanOnce)
41 {
42 const ConstantQueueDescriptor& data = this->m_Data;
43
44 BOOST_ASSERT(data.m_LayerOutput != nullptr);
45 arm_compute::ITensor& output =
46 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
telsoa01c577f2c2018-08-31 09:22:23 +010047 arm_compute::DataType computeDataType =
48 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
telsoa014fcda012018-03-09 14:13:49 +000049
telsoa01c577f2c2018-08-31 09:22:23 +010050 switch (computeDataType)
telsoa014fcda012018-03-09 14:13:49 +000051 {
telsoa01c577f2c2018-08-31 09:22:23 +010052 case arm_compute::DataType::F16:
53 {
54 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
55 break;
56 }
57 case arm_compute::DataType::F32:
telsoa014fcda012018-03-09 14:13:49 +000058 {
59 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
60 break;
61 }
telsoa01c577f2c2018-08-31 09:22:23 +010062 case arm_compute::DataType::QASYMM8:
telsoa014fcda012018-03-09 14:13:49 +000063 {
64 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
65 break;
66 }
67 default:
68 {
69 BOOST_ASSERT_MSG(false, "Unknown data type");
70 break;
71 }
72 }
73
74 m_RanOnce = true;
75 }
76 }
77
78private:
79 mutable bool m_RanOnce;
80};
81
82} //namespace armnn