blob: 49fba052388f79b04f7ef1ed4bd24acc1ce3a6c0 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "DelegateUtils.hpp"
9#include <algorithm>
10#include <iterator>
11#include <string>
12#include <vector>
13
14namespace armnnDelegate
15{
16 TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t operatorCode)
21 {
22 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
24
25 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26 const TfLiteTensor& kTfLiteLHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27 const TfLiteTensor& kTfLiteRHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
28
29 if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
30 {
31 return kTfLiteError;
32 }
33 if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
34 {
35 return kTfLiteError;
36 }
37
38 if (IsDynamicTensor(kTfLiteLHSInputTensor) || IsDynamicTensor(kTfLiteRHSInputTensor))
39 {
40 TF_LITE_MAYBE_KERNEL_LOG(
41 tfLiteContext,
42 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
43 operatorCode, nodeIndex);
44 return kTfLiteError;
45 }
46
47 const TfLiteTensor& kTfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
48 if (IsDynamicTensor(kTfLiteOutputTensor))
49 {
50 TF_LITE_MAYBE_KERNEL_LOG(
51 tfLiteContext,
52 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
53 operatorCode, nodeIndex);
54 return kTfLiteError;
55 }
56
57 const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteLHSInputTensor);
58 const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteRHSInputTensor);
59 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteOutputTensor, true);
60
61 armnn::BatchMatMulDescriptor descriptor;
62 auto* params = reinterpret_cast<TfLiteBatchMatMulParams *>(tfLiteNode->builtin_data);
63
64 // Tensorflow params are called adjoint, however they are actually just transposes behind the scene. They do
65 // not perform ajoint.
66 descriptor.m_TransposeX = params->adj_x;
67 descriptor.m_TransposeY = params->adj_y;
68
69 // Check if supported
70 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010071 armnn::BackendId setBackend;
Ryan OShea49ed0df2022-09-21 16:09:41 +010072 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
73 {
74 FORWARD_LAYER_SUPPORT_FUNC("BATCH_MATMUL",
75 tfLiteContext,
76 IsBatchMatMulSupported,
77 delegateData.m_Backends,
78 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010079 setBackend,
Ryan OShea49ed0df2022-09-21 16:09:41 +010080 armnnLHSInputTensorInfo,
81 armnnRHSInputTensorInfo,
82 outputTensorInfo,
83 descriptor);
84 };
85
86 if (!delegateData.m_Network)
87 {
88 validateFunc(outputTensorInfo, isSupported);
89 return isSupported ? kTfLiteOk : kTfLiteError;
90 }
91
92 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +010093 layer->SetBackendId(setBackend);
Ryan OShea49ed0df2022-09-21 16:09:41 +010094 ARMNN_ASSERT(layer != nullptr);
95
96 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
97 outputSlot.SetTensorInfo(outputTensorInfo);
Ryan OShea49ed0df2022-09-21 16:09:41 +010098
Ryan OShea4c231de2023-01-17 15:19:20 +000099 // try to connect the Constant Inputs if there are any
100 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
101 {
102 return kTfLiteError;
103 }
104
105 return Connect(layer, tfLiteNode, delegateData);
Ryan OShea49ed0df2022-09-21 16:09:41 +0100106 }
Cathal Corbett53837672022-09-01 11:34:37 +0100107} // namespace armnnDelegate