blob: 5277efc9478afc8c9744289625384834de715965 [file] [log] [blame]
Teresa Charlin94916a52022-10-19 08:48:07 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ClBaseWorkload.hpp"
9
10#include <arm_compute/runtime/IFunction.h>
11#include <arm_compute/runtime/CL/CLTensor.h>
12#include <memory>
13
14namespace armnn
15{
16 arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
17 const TensorInfo& inputY,
18 const TensorInfo& output,
19 const BatchMatMulDescriptor& descriptor);
20
21 class ClBatchMatMulWorkload : public ClBaseWorkload<BatchMatMulQueueDescriptor>
22 {
23 public:
24 ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
25 const WorkloadInfo& info,
26 const arm_compute::CLCompileContext& clCompileContext);
27 virtual void Execute() const override;
28
29 private:
30 // ACL layers required to fully form a Batch Mat Mul layer.
31 std::unique_ptr<arm_compute::IFunction> m_GEMMLayer;
32 std::unique_ptr<arm_compute::IFunction> m_PermuteLayerX;
33 std::unique_ptr<arm_compute::IFunction> m_PermuteLayerY;
34
35 // Additional CL arm_compute::Tensors.
36 // Required to perform permutations.
37 arm_compute::CLTensor m_PermutedTensorX;
38 arm_compute::CLTensor m_PermutedTensorY;
39
40 };
41} //namespace armnn