blob: a83fd628679736325587d81d078011772171a964 [file] [log] [blame]
Derek Lambertif674aa02019-08-01 15:56:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01008#include <armnn/utility/Assert.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +01009#include <algorithm>
10
11namespace armnn
12{
13
Derek Lambertif674aa02019-08-01 15:56:25 +010014inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15{
16 if (!weightsType)
17 {
18 return weightsType;
19 }
20
21 switch(weightsType.value())
22 {
23 case armnn::DataType::Float16:
24 case armnn::DataType::Float32:
25 return weightsType;
Sadik Armagandb73c982020-04-01 17:35:30 +010026 case armnn::DataType::QAsymmS8:
Teresa Charlin9d07e002020-11-14 13:43:46 +000027 case armnn::DataType::QAsymmU8:
28 case armnn::DataType::QSymmS8:
29 case armnn::DataType::QSymmS16:
Sadik Armagandb73c982020-04-01 17:35:30 +010030 return armnn::DataType::Signed32;
Derek Lambertif674aa02019-08-01 15:56:25 +010031 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010032 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
Derek Lambertif674aa02019-08-01 15:56:25 +010033 }
34 return armnn::EmptyOptional();
35}
36
Derek Lambertif674aa02019-08-01 15:56:25 +010037template<typename F>
38bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
39{
40 bool supported = rule();
41 if (!supported && reason)
42 {
43 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
44 }
45 return supported;
46}
47
/// Base for the layer-support rule functors below.
///
/// Derived rules compute their verdict in the constructor and store it in m_Res; invoking the
/// object returns that stored verdict.
struct Rule
{
    /// Stored verdict; a default-constructed rule reports true.
    bool m_Res = true;

    /// @return The verdict computed when the rule was constructed.
    bool operator()() const
    {
        const bool verdict = m_Res;
        return verdict;
    }
};
57
58template<typename T>
Derek Lamberti901ea112019-12-10 22:07:09 +000059bool AllTypesAreEqualImpl(T)
Derek Lambertif674aa02019-08-01 15:56:25 +010060{
61 return true;
62}
63
64template<typename T, typename... Rest>
65bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
66{
67 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
68
69 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
70}
71
72struct TypesAreEqual : public Rule
73{
74 template<typename ... Ts>
75 TypesAreEqual(const Ts&... ts)
76 {
77 m_Res = AllTypesAreEqualImpl(ts...);
78 }
79};
80
81struct QuantizationParametersAreEqual : public Rule
82{
83 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
84 {
85 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
86 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
87 }
88};
89
90struct TypeAnyOf : public Rule
91{
92 template<typename Container>
93 TypeAnyOf(const TensorInfo& info, const Container& c)
94 {
95 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Jim Flynn4b2f3472021-10-13 21:20:07 +010096 {
97 return dt == info.GetDataType();
98 });
Derek Lambertif674aa02019-08-01 15:56:25 +010099 }
100};
101
102struct TypeIs : public Rule
103{
104 TypeIs(const TensorInfo& info, DataType dt)
105 {
106 m_Res = dt == info.GetDataType();
107 }
108};
109
Derek Lambertid466a542020-01-22 15:37:29 +0000110struct TypeNotPerAxisQuantized : public Rule
111{
112 TypeNotPerAxisQuantized(const TensorInfo& info)
113 {
114 m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
115 }
116};
117
Derek Lambertif674aa02019-08-01 15:56:25 +0100118struct BiasAndWeightsTypesMatch : public Rule
119{
120 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121 {
122 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123 }
124};
125
126struct BiasAndWeightsTypesCompatible : public Rule
127{
128 template<typename Container>
129 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130 {
131 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132 {
133 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
134 });
135 }
136};
137
138struct ShapesAreSameRank : public Rule
139{
140 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141 {
142 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143 }
144};
145
146struct ShapesAreSameTotalSize : public Rule
147{
148 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149 {
150 m_Res = info0.GetNumElements() == info1.GetNumElements();
151 }
152};
153
154struct ShapesAreBroadcastCompatible : public Rule
155{
156 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157 {
158 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160 return sizeIn;
161 }
162
163 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164 {
165 const TensorShape& shape0 = in0.GetShape();
166 const TensorShape& shape1 = in1.GetShape();
167 const TensorShape& outShape = out.GetShape();
168
169 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170 {
171 unsigned int sizeOut = outShape[i];
172 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174
175 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177 }
178 }
179};
180
181struct TensorNumDimensionsAreCorrect : public Rule
182{
183 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
184 {
185 m_Res = info.GetNumDimensions() == expectedNumDimensions;
186 }
187};
188
Samuel Yap6b478092022-07-06 15:36:03 +0100189struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
190{
191 TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
192 {
193 m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
194 }
195};
196
Derek Lambertif674aa02019-08-01 15:56:25 +0100197} //namespace armnn