blob: 6561850d79c2a17353347e599d856a8686139a06 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Matthew Bentham39ef3e52020-01-20 10:09:09 +00008#include <armnn/QuantizedLstmParams.hpp>
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01009#include <backendsCommon/Workload.hpp>
10#include <backendsCommon/WorkloadData.hpp>
11
Matthew Bentham39ef3e52020-01-20 10:09:09 +000012#include <arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h>
Ferran Balaguer737d9ff2019-08-01 09:58:08 +010013
14namespace armnn
15{
16
17arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn,
18 const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut,
19 const TensorInfo& output,
20 const QuantizedLstmInputParamsInfo& paramsInfo);
21
22class ClQuantizedLstmWorkload : public BaseWorkload<QuantizedLstmQueueDescriptor>
23{
24public:
Sadik Armagane9444752020-12-02 11:28:58 +000025 ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor,
26 const WorkloadInfo& info,
27 const arm_compute::CLCompileContext& clCompileContext);
Ferran Balaguer737d9ff2019-08-01 09:58:08 +010028 void Execute() const override;
29
30private:
31 mutable arm_compute::CLLSTMLayerQuantized m_QuantizedLstmLayer;
32
33 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
34 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
35 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
36 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
37 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
38 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
39 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
40 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
41 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
42 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
43 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
44 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
45
46 void FreeUnusedTensors();
47};
48
49} //namespace armnn
50
51