blob: 1bd825b77471b136cc7c20d958c623386e0eb6be [file] [log] [blame]
Derek Lambertif674aa02019-08-01 15:56:25 +01001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
Derek Lambertif674aa02019-08-01 15:56:25 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Derek Lambertif674aa02019-08-01 15:56:25 +01008#include <algorithm>
9
10namespace armnn
11{
12
Derek Lambertif674aa02019-08-01 15:56:25 +010013inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
14{
15 if (!weightsType)
16 {
17 return weightsType;
18 }
19
20 switch(weightsType.value())
21 {
22 case armnn::DataType::Float16:
23 case armnn::DataType::Float32:
24 return weightsType;
Sadik Armagandb73c982020-04-01 17:35:30 +010025 case armnn::DataType::QAsymmS8:
Teresa Charlin9d07e002020-11-14 13:43:46 +000026 case armnn::DataType::QAsymmU8:
27 case armnn::DataType::QSymmS8:
28 case armnn::DataType::QSymmS16:
Sadik Armagandb73c982020-04-01 17:35:30 +010029 return armnn::DataType::Signed32;
Derek Lambertif674aa02019-08-01 15:56:25 +010030 default:
Colm Donelanb4ef1632024-02-01 15:00:43 +000031 throw InvalidArgumentException("GetBiasTypeFromWeightsType(): Unsupported data type.");
Derek Lambertif674aa02019-08-01 15:56:25 +010032 }
33 return armnn::EmptyOptional();
34}
35
Derek Lambertif674aa02019-08-01 15:56:25 +010036template<typename F>
37bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
38{
39 bool supported = rule();
40 if (!supported && reason)
41 {
42 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
43 }
44 return supported;
45}
46
/// Base type for layer-support rules: a nullary predicate whose outcome is stored in m_Res.
/// Derived rules compute that outcome in their constructors; a default-constructed Rule passes.
struct Rule
{
    bool m_Res = true;

    /// @return The stored outcome of this rule.
    bool operator()() const
    {
        const bool outcome = m_Res;
        return outcome;
    }
};
56
57template<typename T>
Derek Lamberti901ea112019-12-10 22:07:09 +000058bool AllTypesAreEqualImpl(T)
Derek Lambertif674aa02019-08-01 15:56:25 +010059{
60 return true;
61}
62
63template<typename T, typename... Rest>
64bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
65{
66 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
67
68 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
69}
70
71struct TypesAreEqual : public Rule
72{
73 template<typename ... Ts>
74 TypesAreEqual(const Ts&... ts)
75 {
76 m_Res = AllTypesAreEqualImpl(ts...);
77 }
78};
79
80struct QuantizationParametersAreEqual : public Rule
81{
82 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
83 {
84 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
85 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
86 }
87};
88
89struct TypeAnyOf : public Rule
90{
91 template<typename Container>
92 TypeAnyOf(const TensorInfo& info, const Container& c)
93 {
94 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Jim Flynn4b2f3472021-10-13 21:20:07 +010095 {
96 return dt == info.GetDataType();
97 });
Derek Lambertif674aa02019-08-01 15:56:25 +010098 }
99};
100
101struct TypeIs : public Rule
102{
103 TypeIs(const TensorInfo& info, DataType dt)
104 {
105 m_Res = dt == info.GetDataType();
106 }
107};
108
Derek Lambertid466a542020-01-22 15:37:29 +0000109struct TypeNotPerAxisQuantized : public Rule
110{
111 TypeNotPerAxisQuantized(const TensorInfo& info)
112 {
113 m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
114 }
115};
116
Derek Lambertif674aa02019-08-01 15:56:25 +0100117struct BiasAndWeightsTypesMatch : public Rule
118{
119 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
120 {
121 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
122 }
123};
124
125struct BiasAndWeightsTypesCompatible : public Rule
126{
127 template<typename Container>
128 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
129 {
130 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
131 {
132 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
133 });
134 }
135};
136
137struct ShapesAreSameRank : public Rule
138{
139 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
140 {
141 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
142 }
143};
144
145struct ShapesAreSameTotalSize : public Rule
146{
147 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
148 {
149 m_Res = info0.GetNumElements() == info1.GetNumElements();
150 }
151};
152
153struct ShapesAreBroadcastCompatible : public Rule
154{
155 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
156 {
157 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
158 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
159 return sizeIn;
160 }
161
162 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
163 {
164 const TensorShape& shape0 = in0.GetShape();
165 const TensorShape& shape1 = in1.GetShape();
166 const TensorShape& outShape = out.GetShape();
167
168 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
169 {
170 unsigned int sizeOut = outShape[i];
171 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
172 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
173
174 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
175 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
176 }
177 }
178};
179
180struct TensorNumDimensionsAreCorrect : public Rule
181{
182 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
183 {
184 m_Res = info.GetNumDimensions() == expectedNumDimensions;
185 }
186};
187
Samuel Yap6b478092022-07-06 15:36:03 +0100188struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
189{
190 TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
191 {
192 m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
193 }
194};
195
Derek Lambertif674aa02019-08-01 15:56:25 +0100196} //namespace armnn