blob: 5da6e5ac6ada4128b584b25c21ebd1b1e89f9822 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
13 TfLiteOpaqueContext* tfLiteContext,
14 TfLiteOpaqueNode* tfLiteNode,
15 int nodeIndex,
16 int32_t operatorCode)
17{
18 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
19 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
20
21 // Gather input indices and use to get input tensor.
22 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
23 const int* inputTensors;
24 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
25 {
26 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
27 tfLiteContext,
28 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
29 nodeIndex);
30 return kTfLiteError;
31 }
32
33 const TfLiteOpaqueTensor* kTfLiteLHSInputTensor =
34 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
35 const TfLiteOpaqueTensor* kTfLiteRHSInputTensor =
36 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
37
38 if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
39 {
40 return kTfLiteError;
41 }
42 if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
43 {
44 return kTfLiteError;
45 }
46
47 if (IsDynamicTensor(kTfLiteLHSInputTensor) || IsDynamicTensor(kTfLiteRHSInputTensor))
48 {
49 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
50 tfLiteContext,
51 "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
52 operatorCode, nodeIndex);
53 return kTfLiteError;
54 }
55
56 // Gather output indices and use to get output tensors.
57 int numOutputs = 0;
58 const int* outputTensors;
59 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
60 {
61 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
62 tfLiteContext,
63 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
64 nodeIndex);
65 return kTfLiteError;
66 }
67
68 const TfLiteOpaqueTensor* kTfLiteOutputTensor =
69 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
70 if (IsDynamicTensor(kTfLiteOutputTensor))
71 {
72 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
73 tfLiteContext,
74 "TfLiteArmnnOpaqueDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
75 operatorCode, nodeIndex);
76 return kTfLiteError;
77 }
78
79 const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteLHSInputTensor);
80 const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteRHSInputTensor);
81 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(kTfLiteOutputTensor, true);
82
83 armnn::BatchMatMulDescriptor descriptor;
84 auto* params = reinterpret_cast<TfLiteBatchMatMulParams *>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
85
86 // Tensorflow params are called adjoint, however they are actually just transposes behind the scene. They do
87 // not perform ajoint.
88 descriptor.m_TransposeX = params->adj_x;
89 descriptor.m_TransposeY = params->adj_y;
90
91 // Check if supported
92 bool isSupported = false;
93 armnn::BackendId setBackend;
94 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
95 {
96 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("BATCH_MATMUL",
97 tfLiteContext,
98 IsBatchMatMulSupported,
99 delegateData.m_Backends,
100 isSupported,
101 setBackend,
102 armnnLHSInputTensorInfo,
103 armnnRHSInputTensorInfo,
104 outputTensorInfo,
105 descriptor);
106 };
107
108 if (!delegateData.m_Network)
109 {
110 validateFunc(outputTensorInfo, isSupported);
111 return isSupported ? kTfLiteOk : kTfLiteError;
112 }
113
114 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor);
115 layer->SetBackendId(setBackend);
116 ARMNN_ASSERT(layer != nullptr);
117
118 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
119 outputSlot.SetTensorInfo(outputTensorInfo);
120
121 // try to connect the Constant Inputs if there are any
122 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
123 {
124 return kTfLiteError;
125 }
126
127 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
128}
129
130} // namespace armnnOpaqueDelegate