blob: 94b25fe7b5760a2a2b1a8f23f462bf411cd9eb35 [file] [log] [blame]
Ryan OShea49ed0df2022-09-21 16:09:41 +01001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Ryan OShea49ed0df2022-09-21 16:09:41 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
9
Ryan OShea49ed0df2022-09-21 16:09:41 +010010#include <algorithm>
11#include <iterator>
12#include <string>
13#include <vector>
14
15namespace armnnDelegate
16{
17 TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
21 int32_t operatorCode)
22 {
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& kTfLiteLHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 const TfLiteTensor& kTfLiteRHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
29
30 if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
31 {
32 return kTfLiteError;
33 }
34 if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 if (IsDynamicTensor(kTfLiteLHSInputTensor) || IsDynamicTensor(kTfLiteRHSInputTensor))
40 {
41 TF_LITE_MAYBE_KERNEL_LOG(
42 tfLiteContext,
43 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
44 operatorCode, nodeIndex);
45 return kTfLiteError;
46 }
47
48 const TfLiteTensor& kTfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
49 if (IsDynamicTensor(kTfLiteOutputTensor))
50 {
51 TF_LITE_MAYBE_KERNEL_LOG(
52 tfLiteContext,
53 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
54 operatorCode, nodeIndex);
55 return kTfLiteError;
56 }
57
58 const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteLHSInputTensor);
59 const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteRHSInputTensor);
60 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteOutputTensor, true);
61
62 armnn::BatchMatMulDescriptor descriptor;
63 auto* params = reinterpret_cast<TfLiteBatchMatMulParams *>(tfLiteNode->builtin_data);
64
65 // Tensorflow params are called adjoint, however they are actually just transposes behind the scene. They do
66 // not perform ajoint.
67 descriptor.m_TransposeX = params->adj_x;
68 descriptor.m_TransposeY = params->adj_y;
69
70 // Check if supported
71 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010072 armnn::BackendId setBackend;
Ryan OShea49ed0df2022-09-21 16:09:41 +010073 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
74 {
75 FORWARD_LAYER_SUPPORT_FUNC("BATCH_MATMUL",
76 tfLiteContext,
77 IsBatchMatMulSupported,
78 delegateData.m_Backends,
79 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010080 setBackend,
Ryan OShea49ed0df2022-09-21 16:09:41 +010081 armnnLHSInputTensorInfo,
82 armnnRHSInputTensorInfo,
83 outputTensorInfo,
84 descriptor);
85 };
86
87 if (!delegateData.m_Network)
88 {
89 validateFunc(outputTensorInfo, isSupported);
90 return isSupported ? kTfLiteOk : kTfLiteError;
91 }
92
93 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +010094 layer->SetBackendId(setBackend);
Ryan OShea49ed0df2022-09-21 16:09:41 +010095 ARMNN_ASSERT(layer != nullptr);
96
97 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
98 outputSlot.SetTensorInfo(outputTensorInfo);
Ryan OShea49ed0df2022-09-21 16:09:41 +010099
Ryan OShea4c231de2023-01-17 15:19:20 +0000100 // try to connect the Constant Inputs if there are any
101 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
102 {
103 return kTfLiteError;
104 }
105
106 return Connect(layer, tfLiteNode, delegateData);
Ryan OShea49ed0df2022-09-21 16:09:41 +0100107 }
Cathal Corbett53837672022-09-01 11:34:37 +0100108} // namespace armnnDelegate