blob: ddecc821721ceff97dd15f39e9d883f50fb2b3ac [file] [log] [blame]
Derek Lambertif674aa02019-08-01 15:56:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01008#include <armnn/utility/Assert.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +01009#include <algorithm>
10
11namespace armnn
12{
13
Derek Lambertif674aa02019-08-01 15:56:25 +010014inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15{
16 if (!weightsType)
17 {
18 return weightsType;
19 }
20
21 switch(weightsType.value())
22 {
23 case armnn::DataType::Float16:
24 case armnn::DataType::Float32:
25 return weightsType;
Derek Lambertif90c56d2020-01-10 17:14:08 +000026 case armnn::DataType::QAsymmU8:
Derek Lambertif674aa02019-08-01 15:56:25 +010027 return armnn::DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000028 case armnn::DataType::QSymmS16:
Derek Lambertif674aa02019-08-01 15:56:25 +010029 return armnn::DataType::Signed32;
Sadik Armagandb73c982020-04-01 17:35:30 +010030 case armnn::DataType::QAsymmS8:
31 return armnn::DataType::Signed32;
Derek Lambertif674aa02019-08-01 15:56:25 +010032 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010033 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
Derek Lambertif674aa02019-08-01 15:56:25 +010034 }
35 return armnn::EmptyOptional();
36}
37
Derek Lambertif674aa02019-08-01 15:56:25 +010038template<typename F>
39bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
40{
41 bool supported = rule();
42 if (!supported && reason)
43 {
44 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
45 }
46 return supported;
47}
48
/// Base type for the layer-support rules in this header. Each derived rule stores its verdict
/// in m_Res (the derived rules here assign it in their constructors); invoking the rule
/// returns that verdict.
struct Rule
{
    /// Verdict of the rule; true (satisfied) unless something overwrites it.
    bool m_Res{true};

    /// @return The stored verdict.
    bool operator()() const { return m_Res; }
};
58
/// Recursion terminator for AllTypesAreEqualImpl: a single item always agrees with itself.
/// Takes its argument by const reference so the final recursion step does not copy it.
template<typename T>
bool AllTypesAreEqualImpl(const T&)
{
    return true;
}
64
65template<typename T, typename... Rest>
66bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
67{
68 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
69
70 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
71}
72
73struct TypesAreEqual : public Rule
74{
75 template<typename ... Ts>
76 TypesAreEqual(const Ts&... ts)
77 {
78 m_Res = AllTypesAreEqualImpl(ts...);
79 }
80};
81
82struct QuantizationParametersAreEqual : public Rule
83{
84 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
85 {
86 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
87 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
88 }
89};
90
91struct TypeAnyOf : public Rule
92{
93 template<typename Container>
94 TypeAnyOf(const TensorInfo& info, const Container& c)
95 {
96 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
97 {
98 return dt == info.GetDataType();
99 });
100 }
101};
102
103struct TypeIs : public Rule
104{
105 TypeIs(const TensorInfo& info, DataType dt)
106 {
107 m_Res = dt == info.GetDataType();
108 }
109};
110
Derek Lambertid466a542020-01-22 15:37:29 +0000111struct TypeNotPerAxisQuantized : public Rule
112{
113 TypeNotPerAxisQuantized(const TensorInfo& info)
114 {
115 m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
116 }
117};
118
Derek Lambertif674aa02019-08-01 15:56:25 +0100119struct BiasAndWeightsTypesMatch : public Rule
120{
121 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
122 {
123 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
124 }
125};
126
127struct BiasAndWeightsTypesCompatible : public Rule
128{
129 template<typename Container>
130 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
131 {
132 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
133 {
134 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
135 });
136 }
137};
138
139struct ShapesAreSameRank : public Rule
140{
141 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
142 {
143 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
144 }
145};
146
147struct ShapesAreSameTotalSize : public Rule
148{
149 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
150 {
151 m_Res = info0.GetNumElements() == info1.GetNumElements();
152 }
153};
154
155struct ShapesAreBroadcastCompatible : public Rule
156{
157 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
158 {
159 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
160 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
161 return sizeIn;
162 }
163
164 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
165 {
166 const TensorShape& shape0 = in0.GetShape();
167 const TensorShape& shape1 = in1.GetShape();
168 const TensorShape& outShape = out.GetShape();
169
170 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
171 {
172 unsigned int sizeOut = outShape[i];
173 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
174 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
175
176 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
177 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
178 }
179 }
180};
181
182struct TensorNumDimensionsAreCorrect : public Rule
183{
184 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
185 {
186 m_Res = info.GetNumDimensions() == expectedNumDimensions;
187 }
188};
189
190} //namespace armnn