blob: 1caa354d4d307c599cc7ecd9dc3ed00b94b47dba [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
9
Ryan OShea49ed0df2022-09-21 16:09:41 +010010#include <algorithm>
11#include <iterator>
12#include <string>
13#include <vector>
14
15namespace armnnDelegate
16{
17 TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
21 int32_t operatorCode)
22 {
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& kTfLiteLHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 const TfLiteTensor& kTfLiteRHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
29
30 if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
31 {
32 return kTfLiteError;
33 }
34 if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 if (IsDynamicTensor(kTfLiteLHSInputTensor) || IsDynamicTensor(kTfLiteRHSInputTensor))
40 {
41 TF_LITE_MAYBE_KERNEL_LOG(
42 tfLiteContext,
43 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
44 operatorCode, nodeIndex);
45 return kTfLiteError;
46 }
47
48 const TfLiteTensor& kTfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
49 if (IsDynamicTensor(kTfLiteOutputTensor))
50 {
51 TF_LITE_MAYBE_KERNEL_LOG(
52 tfLiteContext,
53 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
54 operatorCode, nodeIndex);
55 return kTfLiteError;
56 }
57
58 const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteLHSInputTensor);
59 const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteRHSInputTensor);
60 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteOutputTensor, true);
61
62 armnn::BatchMatMulDescriptor descriptor;
63 auto* params = reinterpret_cast<TfLiteBatchMatMulParams *>(tfLiteNode->builtin_data);
64
65 // Tensorflow params are called adjoint, however they are actually just transposes behind the scene. They do
66 // not perform ajoint.
67 descriptor.m_TransposeX = params->adj_x;
68 descriptor.m_TransposeY = params->adj_y;
69
70 // Check if supported
71 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010072 armnn::BackendId setBackend;
Ryan OShea49ed0df2022-09-21 16:09:41 +010073 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
74 {
75 FORWARD_LAYER_SUPPORT_FUNC("BATCH_MATMUL",
76 tfLiteContext,
77 IsBatchMatMulSupported,
78 delegateData.m_Backends,
79 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010080 setBackend,
Ryan OShea49ed0df2022-09-21 16:09:41 +010081 armnnLHSInputTensorInfo,
82 armnnRHSInputTensorInfo,
83 outputTensorInfo,
84 descriptor);
85 };
86
87 if (!delegateData.m_Network)
88 {
89 validateFunc(outputTensorInfo, isSupported);
90 return isSupported ? kTfLiteOk : kTfLiteError;
91 }
92
Mike Kelly07169c82023-08-02 13:23:09 +010093 auto layerName = GetLayerName(armnn::LayerType::BatchMatMul, nodeIndex);
94 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +010095 layer->SetBackendId(setBackend);
Ryan OShea49ed0df2022-09-21 16:09:41 +010096 ARMNN_ASSERT(layer != nullptr);
97
98 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
99 outputSlot.SetTensorInfo(outputTensorInfo);
Ryan OShea49ed0df2022-09-21 16:09:41 +0100100
Ryan OShea4c231de2023-01-17 15:19:20 +0000101 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100102 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000103 {
104 return kTfLiteError;
105 }
106
107 return Connect(layer, tfLiteNode, delegateData);
Ryan OShea49ed0df2022-09-21 16:09:41 +0100108 }
Cathal Corbett53837672022-09-01 11:34:37 +0100109} // namespace armnnDelegate