blob: 5af03b3790ebefe1c963e81290399829624e7a1c [file] [log] [blame]
Francis Murtaghc4fb0dd2023-03-16 17:01:56 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Teresa Charlin42362962023-04-28 14:23:33 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus VisitTransposeOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t tfliteTransposeOperatorCode)
18{
19 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
20 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21
22 // Gather input indices and use to get input tensor.
23 const int* inputTensors;
24 int numInputs;
25 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
26 {
27 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
28 tfLiteContext,
29 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
30 nodeIndex);
31 return kTfLiteError;
32 }
33 const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
34 if (!IsValid(tfLiteContext, tfLiteInputTensor0, tfliteTransposeOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38 const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
39 if (!IsValid(tfLiteContext, tfLiteInputTensor1, tfliteTransposeOperatorCode, nodeIndex))
40 {
41 return kTfLiteError;
42 }
43
44 // Gather output indices and use to get output tensors.
45 const int* outputTensors;
46 int numOutputs;
47 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
48 {
49 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
50 tfLiteContext,
51 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
52 nodeIndex);
53 return kTfLiteError;
54 }
55
56 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
57 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfliteTransposeOperatorCode, nodeIndex))
58 {
59 return kTfLiteError;
60 }
61
62 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
63 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
64
65 auto* permTensorDataPtr = static_cast<int32_t*>(TfLiteOpaqueTensorData(tfLiteInputTensor1));
66 unsigned int numEl = TfLiteOpaqueTensorDim(tfLiteInputTensor1, 0);
67
68 ARMNN_ASSERT( numEl <= static_cast<int>(armnn::MaxNumOfTensorDimensions) );
69 // Ensure only single dimension to the permutation tensor
70 ARMNN_ASSERT( TfLiteOpaqueTensorNumDims(tfLiteInputTensor1) == 1 );
71
72 armnn::TransposeDescriptor descriptor(armnn::PermutationVector(
73 reinterpret_cast<const armnn::PermutationVector::ValueType *> (permTensorDataPtr),
74 static_cast<armnn::PermutationVector::SizeType>(numEl)));
75
76 bool isSupported = false;
77 armnn::BackendId setBackend;
78 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
79 {
80 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("TRANSPOSE",
81 tfLiteContext,
82 IsTransposeSupported,
83 delegateData.m_Backends,
84 isSupported,
85 setBackend,
86 inputTensorInfo0,
87 outputTensorInfo,
88 descriptor);
89 };
90
91 if (!delegateData.m_Network)
92 {
93 validateFunc(outputTensorInfo, isSupported);
94 return isSupported ? kTfLiteOk : kTfLiteError;
95 }
96
Mike Kellya2806502023-08-03 10:42:11 +010097 auto layerName = GetName(armnn::LayerType::Transpose, nodeIndex);
98 armnn::IConnectableLayer* transposeLayer = delegateData.m_Network->AddTransposeLayer(descriptor, layerName.c_str());
Teresa Charlin42362962023-04-28 14:23:33 +010099 transposeLayer->SetBackendId(setBackend);
100 ARMNN_ASSERT(transposeLayer != nullptr);
101 // Permutation vector given to descriptor object
102 ARMNN_ASSERT(transposeLayer->GetNumInputSlots() == 1);
103
104 armnn::IOutputSlot& outputSlot = transposeLayer->GetOutputSlot(0);
105 outputSlot.SetTensorInfo(outputTensorInfo);
106
107 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100108 if (ProcessInputs(transposeLayer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Teresa Charlin42362962023-04-28 14:23:33 +0100109 {
110 return kTfLiteError;
111 }
112
113 return Connect(transposeLayer, tfLiteContext, tfLiteNode, delegateData);
114}
115} // namespace armnnOpaqueDelegate