//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Teresa Charlin588cbdf2022-01-19 15:55:37 +00008#include "ClBaseWorkload.hpp"
9
Matthew Bentham39ef3e52020-01-20 10:09:09 +000010#include <armnn/QuantizedLstmParams.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000011#include <armnn/backends/Workload.hpp>
12#include <armnn/backends/WorkloadData.hpp>
Ferran Balaguer737d9ff2019-08-01 09:58:08 +010013
Matthew Bentham39ef3e52020-01-20 10:09:09 +000014#include <arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h>
Ferran Balaguer737d9ff2019-08-01 09:58:08 +010015
16namespace armnn
17{
18
19arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn,
20 const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut,
21 const TensorInfo& output,
22 const QuantizedLstmInputParamsInfo& paramsInfo);
23
Teresa Charlin588cbdf2022-01-19 15:55:37 +000024class ClQuantizedLstmWorkload : public ClBaseWorkload<QuantizedLstmQueueDescriptor>
Ferran Balaguer737d9ff2019-08-01 09:58:08 +010025{
26public:
Sadik Armagane9444752020-12-02 11:28:58 +000027 ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor,
28 const WorkloadInfo& info,
29 const arm_compute::CLCompileContext& clCompileContext);
Ferran Balaguer737d9ff2019-08-01 09:58:08 +010030 void Execute() const override;
31
32private:
33 mutable arm_compute::CLLSTMLayerQuantized m_QuantizedLstmLayer;
34
35 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
36 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
37 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
38 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
39 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
40 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
41 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
42 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
43 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
44 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
45 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
46 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
47
48 void FreeUnusedTensors();
49};
50
51} //namespace armnn
52
53