blob: 580db490d68f412021c56e9d61fb2fb3dc7ff20a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Matthew Bentham39ef3e52020-01-20 10:09:09 +00008#include <armnn/QuantizedLstmParams.hpp>
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01009#include <backendsCommon/Workload.hpp>
10#include <backendsCommon/WorkloadData.hpp>
11
Matthew Bentham39ef3e52020-01-20 10:09:09 +000012#include <arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h>
Ferran Balaguer737d9ff2019-08-01 09:58:08 +010013
14namespace armnn
15{
16
17arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn,
18 const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut,
19 const TensorInfo& output,
20 const QuantizedLstmInputParamsInfo& paramsInfo);
21
22class ClQuantizedLstmWorkload : public BaseWorkload<QuantizedLstmQueueDescriptor>
23{
24public:
25 ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
26 void Execute() const override;
27
28private:
29 mutable arm_compute::CLLSTMLayerQuantized m_QuantizedLstmLayer;
30
31 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
32 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
33 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
34 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
35 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
36 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
37 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
38 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
39 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
40 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
41 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
42 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
43
44 void FreeUnusedTensors();
45};
46
47} //namespace armnn
48
49