blob: f4242ec8a432323d081fe38f94877c002b1cffc4 [file] [log] [blame]
James Conroy4f1f8992020-04-29 20:01:10 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <armnn/TypesUtils.hpp>

#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>

#include <memory>
#include <vector>
12
13namespace armnn
14{
15
16class RefQLstmWorkload : public BaseWorkload<QLstmQueueDescriptor>
17{
18public:
19 explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
20
Finn Williamsb8181f72021-04-07 10:23:21 +010021 void Execute() const override;
22 void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
James Conroy4f1f8992020-04-29 20:01:10 +010023
24private:
Finn Williamsb8181f72021-04-07 10:23:21 +010025 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
James Conroy4f1f8992020-04-29 20:01:10 +010026 std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeightsTensor;
27 std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeightsTensor;
28 std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeightsTensor;
29 std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeightsTensor;
30
31 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeightsTensor;
32 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeightsTensor;
33 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeightsTensor;
34 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeightsTensor;
35
36 std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeightsTensor;
37 std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeightsTensor;
38 std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeightsTensor;
39
40 std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBiasTensor;
41 std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBiasTensor;
42 std::unique_ptr<ScopedCpuTensorHandle> m_CellBiasTensor;
43 std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor;
44
45 std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor;
46 std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor;
47
48 std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeightsTensor;
49 std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeightsTensor;
50 std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeightsTensor;
51 std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeightsTensor;
52
53 float m_LayerNormEpsilon = static_cast<float>(1e-8);
54};
55
56} //namespace armnn