blob: 257c410d141fd3a9d3344073c32bf959283e1bd6 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
John Mcloughlin0422cf22023-04-27 16:55:00 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
13 TfLiteOpaqueContext* tfLiteContext,
14 TfLiteOpaqueNode* tfLiteNode,
15 int nodeIndex,
16 int32_t operatorCode)
17{
18 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
19 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
20
21 // Gather input indices and use to get input tensor.
22 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
23 const int* inputTensors;
24 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
25 {
26 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
27 tfLiteContext,
28 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
29 nodeIndex);
30 return kTfLiteError;
31 }
32
33 const TfLiteOpaqueTensor* kTfLiteLHSInputTensor =
34 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
35 const TfLiteOpaqueTensor* kTfLiteRHSInputTensor =
36 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
37
38 if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
39 {
40 return kTfLiteError;
41 }
42 if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
43 {
44 return kTfLiteError;
45 }
46
John Mcloughlin0422cf22023-04-27 16:55:00 +010047 // Gather output indices and use to get output tensors.
48 int numOutputs = 0;
49 const int* outputTensors;
50 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
51 {
52 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
53 tfLiteContext,
54 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
55 nodeIndex);
56 return kTfLiteError;
57 }
58
59 const TfLiteOpaqueTensor* kTfLiteOutputTensor =
60 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
61 if (IsDynamicTensor(kTfLiteOutputTensor))
62 {
63 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
64 tfLiteContext,
65 "TfLiteArmnnOpaqueDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
66 operatorCode, nodeIndex);
67 return kTfLiteError;
68 }
69
70 const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteLHSInputTensor);
71 const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteRHSInputTensor);
72 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteOutputTensor, true);
73
74 armnn::BatchMatMulDescriptor descriptor;
75 auto* params = reinterpret_cast<TfLiteBatchMatMulParams *>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
76
77 // Tensorflow params are called adjoint, however they are actually just transposes behind the scene. They do
78 // not perform ajoint.
79 descriptor.m_TransposeX = params->adj_x;
80 descriptor.m_TransposeY = params->adj_y;
81
82 // Check if supported
83 bool isSupported = false;
84 armnn::BackendId setBackend;
85 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
86 {
87 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("BATCH_MATMUL",
88 tfLiteContext,
89 IsBatchMatMulSupported,
90 delegateData.m_Backends,
91 isSupported,
92 setBackend,
93 armnnLHSInputTensorInfo,
94 armnnRHSInputTensorInfo,
95 outputTensorInfo,
96 descriptor);
97 };
98
99 if (!delegateData.m_Network)
100 {
101 validateFunc(outputTensorInfo, isSupported);
102 return isSupported ? kTfLiteOk : kTfLiteError;
103 }
104
Mike Kellya2806502023-08-03 10:42:11 +0100105 auto layerName = GetName(armnn::LayerType::BatchMatMul, nodeIndex);
106 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str());
John Mcloughlin0422cf22023-04-27 16:55:00 +0100107 layer->SetBackendId(setBackend);
108 ARMNN_ASSERT(layer != nullptr);
109
110 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
111 outputSlot.SetTensorInfo(outputTensorInfo);
112
113 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100114 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
John Mcloughlin0422cf22023-04-27 16:55:00 +0100115 {
116 return kTfLiteError;
117 }
118
119 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
120}
121
122} // namespace armnnOpaqueDelegate