blob: 19d3a2af0f10eb3cdb72412b14269580e45cdec4 [file] [log] [blame]
James Conroy4f1f8992020-04-29 20:01:10 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/TypesUtils.hpp>
9
10#include <backendsCommon/Workload.hpp>
11#include <backendsCommon/WorkloadData.hpp>
12
13namespace armnn
14{
15
16class RefQLstmWorkload : public BaseWorkload<QLstmQueueDescriptor>
17{
18public:
19 explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
20
21 virtual void Execute() const override;
22
23private:
24 std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeightsTensor;
25 std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeightsTensor;
26 std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeightsTensor;
27 std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeightsTensor;
28
29 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeightsTensor;
30 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeightsTensor;
31 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeightsTensor;
32 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeightsTensor;
33
34 std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeightsTensor;
35 std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeightsTensor;
36 std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeightsTensor;
37
38 std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBiasTensor;
39 std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBiasTensor;
40 std::unique_ptr<ScopedCpuTensorHandle> m_CellBiasTensor;
41 std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor;
42
43 std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor;
44 std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor;
45
46 std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeightsTensor;
47 std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeightsTensor;
48 std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeightsTensor;
49 std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeightsTensor;
50
51 float m_LayerNormEpsilon = static_cast<float>(1e-8);
52};
53
54} //namespace armnn