blob: f4a09d4aed6189561d49c3b30fa61480cdc1c108 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#pragma once
7
telsoa01c577f2c2018-08-31 09:22:23 +01008#include <arm_compute/core/Types.h>
telsoa014fcda012018-03-09 14:13:49 +00009#include <backends/ArmComputeTensorUtils.hpp>
10#include <backends/CpuTensorHandle.hpp>
11#include <backends/NeonTensorHandle.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010012#include <backends/NeonWorkloadUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013#include <backends/Workload.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010014#include <Half.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/cast.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010017#include "Half.hpp"
telsoa014fcda012018-03-09 14:13:49 +000018
19namespace armnn
20{
21
telsoa01c577f2c2018-08-31 09:22:23 +010022// Base class template providing an implementation of the Constant layer common to all data types.
23template <armnn::DataType... DataFormats>
24class NeonBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataFormats...>
telsoa014fcda012018-03-09 14:13:49 +000025{
26public:
27 NeonBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
telsoa01c577f2c2018-08-31 09:22:23 +010028 : TypedWorkload<ConstantQueueDescriptor, DataFormats...>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000029 , m_RanOnce(false)
30 {
31 }
32
33 virtual void Execute() const override
34 {
35 using namespace armcomputetensorutils;
36
37 // The intermediate tensor held by the corresponding layer output handler can be initialised with the
38 // given data on the first inference, then reused for subsequent inferences.
39 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
40 // may not have been configured at the time.
41 if (!m_RanOnce)
42 {
43 const ConstantQueueDescriptor& data = this->m_Data;
44
45 BOOST_ASSERT(data.m_LayerOutput != nullptr);
46 arm_compute::ITensor& output =
47 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
telsoa01c577f2c2018-08-31 09:22:23 +010048 arm_compute::DataType computeDataType =
49 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
telsoa014fcda012018-03-09 14:13:49 +000050
telsoa01c577f2c2018-08-31 09:22:23 +010051 switch (computeDataType)
telsoa014fcda012018-03-09 14:13:49 +000052 {
telsoa01c577f2c2018-08-31 09:22:23 +010053 case arm_compute::DataType::F16:
54 {
55 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
56 break;
57 }
58 case arm_compute::DataType::F32:
telsoa014fcda012018-03-09 14:13:49 +000059 {
60 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
61 break;
62 }
telsoa01c577f2c2018-08-31 09:22:23 +010063 case arm_compute::DataType::QASYMM8:
telsoa014fcda012018-03-09 14:13:49 +000064 {
65 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
66 break;
67 }
68 default:
69 {
70 BOOST_ASSERT_MSG(false, "Unknown data type");
71 break;
72 }
73 }
74
75 m_RanOnce = true;
76 }
77 }
78
79private:
80 mutable bool m_RanOnce;
81};
82
83} //namespace armnn