blob: 27144f24009f28534290547c9f309a749c1f7243 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "NeonBaseWorkload.hpp"
9
Teresa Charlin1fe6c812022-11-01 15:59:50 +000010#include <arm_compute/runtime/NEON/functions/NEMatMul.h>
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010011
12namespace armnn
13{
Teresa Charlin1fe6c812022-11-01 15:59:50 +000014 arm_compute::Status NeonBatchMatMulValidate(const TensorInfo& inputInfoX,
15 const TensorInfo& inputInfoY,
16 const TensorInfo& outputInfo,
17 const BatchMatMulDescriptor& descriptor,
18 const bool isFastMathEnabled,
19 const ActivationDescriptor* activationDescriptor);
20
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010021
22 class NeonBatchMatMulWorkload : public NeonBaseWorkload<BatchMatMulQueueDescriptor>
23 {
24 public:
25 NeonBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
Teresa Charlin1fe6c812022-11-01 15:59:50 +000026 const WorkloadInfo& info,
27 const bool isFastMathEnabled);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010028 virtual void Execute() const override;
29
30 private:
Teresa Charlin1fe6c812022-11-01 15:59:50 +000031 mutable arm_compute::NEMatMul m_MatMulLayer;
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010032 };
33} //namespace armnn