//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
13 TfLiteOpaqueContext* tfLiteContext,
14 TfLiteOpaqueNode* tfLiteNode,
15 int nodeIndex,
16 int32_t operatorCode)
17{
18 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
19 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
20
21 // Gather input indices and use to get input tensor.
22 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
23 const int* inputTensors;
24 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
25 {
26 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
27 tfLiteContext,
28 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
29 nodeIndex);
30 return kTfLiteError;
31 }
32
33 const TfLiteOpaqueTensor* kTfLiteLHSInputTensor =
34 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
35 const TfLiteOpaqueTensor* kTfLiteRHSInputTensor =
36 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
37
38 if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
39 {
40 return kTfLiteError;
41 }
42 if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
43 {
44 return kTfLiteError;
45 }
46
John Mcloughlin0422cf22023-04-27 16:55:00 +010047 // Gather output indices and use to get output tensors.
48 int numOutputs = 0;
49 const int* outputTensors;
50 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
51 {
52 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
53 tfLiteContext,
54 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
55 nodeIndex);
56 return kTfLiteError;
57 }
58
59 const TfLiteOpaqueTensor* kTfLiteOutputTensor =
60 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
61 if (IsDynamicTensor(kTfLiteOutputTensor))
62 {
63 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
64 tfLiteContext,
65 "TfLiteArmnnOpaqueDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
66 operatorCode, nodeIndex);
67 return kTfLiteError;
68 }
69
70 const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteLHSInputTensor);
71 const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteRHSInputTensor);
72 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteOutputTensor, true);
73
74 armnn::BatchMatMulDescriptor descriptor;
75 auto* params = reinterpret_cast<TfLiteBatchMatMulParams *>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
76
77 // Tensorflow params are called adjoint, however they are actually just transposes behind the scene. They do
78 // not perform ajoint.
79 descriptor.m_TransposeX = params->adj_x;
80 descriptor.m_TransposeY = params->adj_y;
81
82 // Check if supported
83 bool isSupported = false;
84 armnn::BackendId setBackend;
85 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
86 {
87 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("BATCH_MATMUL",
88 tfLiteContext,
89 IsBatchMatMulSupported,
90 delegateData.m_Backends,
91 isSupported,
92 setBackend,
93 armnnLHSInputTensorInfo,
94 armnnRHSInputTensorInfo,
95 outputTensorInfo,
96 descriptor);
97 };
98
99 if (!delegateData.m_Network)
100 {
101 validateFunc(outputTensorInfo, isSupported);
102 return isSupported ? kTfLiteOk : kTfLiteError;
103 }
104
105 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor);
106 layer->SetBackendId(setBackend);
107 ARMNN_ASSERT(layer != nullptr);
108
109 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
110 outputSlot.SetTensorInfo(outputTensorInfo);
111
112 // try to connect the Constant Inputs if there are any
113 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
114 {
115 return kTfLiteError;
116 }
117
118 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
119}
120
121} // namespace armnnOpaqueDelegate