blob: c2b0b1b0b9e2dc7b7ed73b84bbe03a5a4247a106 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "TosaRefLayerSupport.hpp"
#include <tosaCommon/TosaMappings.hpp>

#include <armnn/Types.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <tosaCommon/TosaLayerSupportRules.hpp>
#include <LayerSupportCommon.hpp>

#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <vector>
16
17namespace armnn
18{
19
20static bool IsTosaLayerSupported(TosaSerializationOperator* op,
21 const std::vector<TosaSerializationTensor*>& inputs,
22 const std::vector<TosaSerializationTensor*>& outputs,
23 Optional<string&> reasonIfUnsupported)
24{
25 switch(op->GetOp())
26 {
27 case tosa::Op_ADD:
28 {
29 bool supported = true;
30
31 std::array<Attribute, 1> supportedAttributes =
Matthew Sloyanda824cc2022-10-10 12:43:20 +010032 {
33 Attribute_NONE
34 };
Francis Murtagh9270d9e2022-08-12 13:54:17 +010035
36 // Check Attribute from operator (GetAttribute)
37 supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
38 std::string("TOSA Reference addition: operator has an unsupported attribute.").c_str());
39
Matthew Sloyanda824cc2022-10-10 12:43:20 +010040 std::array<DType, 9> supportedTypes =
41 {
42 DType_BOOL,
43 DType_UINT8,
44 DType_UINT16,
45 DType_INT4,
46 DType_INT8,
47 DType_INT16,
48 DType_INT32,
49 DType_FP16,
50 DType_FP32
51 };
Francis Murtagh9270d9e2022-08-12 13:54:17 +010052
53 for (auto tensor : inputs)
54 {
55 // Check Dtype from tensor (GetDtype)
56 supported &= CheckSupportRule(TosaTypeAnyOf(tensor, supportedTypes),
57 reasonIfUnsupported,
58 std::string("TOSA Reference addition: " + tensor->GetName() +
59 " is not a supported type.").c_str());
60
61 // Check Shape from tensor (GetShape)
62 supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(tensor),
63 reasonIfUnsupported,
64 std::string("Tosa Reference addition: " + tensor->GetName() + " Shape.Size()"
65 " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
66 }
67
68 // Check Dtype from tensor (GetDtype)
69 supported &= CheckSupportRule(TosaTypeAnyOf(outputs[0], supportedTypes),
70 reasonIfUnsupported,
71 std::string("TOSA Reference addition: " + outputs[0]->GetName() +
72 " is not a supported type.").c_str());
73
74 // Check Shape from tensor (GetShape)
75 supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(outputs[0]),
76 reasonIfUnsupported,
77 std::string("Tosa Reference addition: " + outputs[0]->GetName() + " Shape.Size()"
78 " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
79
80 return supported;
81 }
82 default:
83 SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend.");
84 return false;
85 }
86}
87
88bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
89 const std::vector<TensorInfo>& infos,
90 const BaseDescriptor& descriptor,
91 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
92 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
93 Optional<std::string&> reasonIfUnsupported) const
94{
95 IgnoreUnused(lstmParamsInfo);
96 IgnoreUnused(quantizedLstmInputParamsInfo);
97
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010098 std::vector<const TensorInfo*> inputInfos;
99 std::vector<const TensorInfo*> outputInfos;
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100100
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100101 switch (type)
102 {
103 case LayerType::Addition:
104 // Setup inputs and outputs
105 inputInfos.push_back(&infos[0]);
106 inputInfos.push_back(&infos[1]);
107 outputInfos.push_back(&infos[2]);
108 break;
109 case LayerType::Input:
110 case LayerType::Output:
111 return true;
112 default:
113 break;
114 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100115
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000116 auto mappings = GetTosaMapping(type, inputInfos, outputInfos, descriptor, false);
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100117 if (mappings->GetName() == "")
118 {
119 // There currently isn't a TOSA mapping for this layer, as the default was returned.
120 return false;
121 }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100122
123 // Loop through block and get each tensor and operator
124 for (long unsigned int i = 0; i < mappings->GetOperators().size(); ++i)
125 {
126 // While looping over operators check for op_UNKNOWN which is unsupported
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100127 if (mappings->GetOperators()[i]->GetOp() == tosa::Op_UNKNOWN) { return false; }
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100128
129 // Loop over operators and get GetInput/OutputTensorNames, loop over resulting names and
130 // use GetTensorByName to pass pointers to tensors on to the IsTosaLayerSupported()
131 std::vector<TosaSerializationTensor*> inputTensorsVect;
132 for (const auto& name : mappings->GetOperators()[i]->GetInputTensorNames())
133 {
134 inputTensorsVect.push_back(mappings->GetTensorByName(name));
135 }
136
137 std::vector<TosaSerializationTensor*> outputTensorsVect;
138 for (const auto& name : mappings->GetOperators()[i]->GetOutputTensorNames())
139 {
140 outputTensorsVect.push_back(mappings->GetTensorByName(name));
141 }
142
143 if (!IsTosaLayerSupported(mappings->GetOperators()[i],
144 inputTensorsVect,
145 outputTensorsVect,
146 reasonIfUnsupported))
147 {
148 return false;
149 }
150 }
151 return true;
152}
153
154} // namespace armnn