blob: c7d83755c720f7021922acb5471221508f60567f [file] [log] [blame]
Ferran Balaguer737d9ff2019-08-01 09:58:08 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <backendsCommon/Workload.hpp>
9#include <backendsCommon/WorkloadData.hpp>
10
11#include <arm_compute/runtime/CL/CLFunctions.h>
12
13namespace armnn
14{
15
16arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn,
17 const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut,
18 const TensorInfo& output,
19 const QuantizedLstmInputParamsInfo& paramsInfo);
20
21class ClQuantizedLstmWorkload : public BaseWorkload<QuantizedLstmQueueDescriptor>
22{
23public:
24 ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
25 void Execute() const override;
26
27private:
28 mutable arm_compute::CLLSTMLayerQuantized m_QuantizedLstmLayer;
29
30 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
31 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
32 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
33 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
34 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
35 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
36 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
37 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
38 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
39 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
40 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
41 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
42
43 void FreeUnusedTensors();
44};
45
46} //namespace armnn
47
48