blob: 3a2ae06f5a5dfd2f1add455f11aff155c0111c3f [file] [log] [blame]
Derek Lambertif674aa02019-08-01 15:56:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <boost/assert.hpp>
9#include <algorithm>
10
11namespace armnn
12{
13
Derek Lambertif674aa02019-08-01 15:56:25 +010014inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15{
16 if (!weightsType)
17 {
18 return weightsType;
19 }
20
21 switch(weightsType.value())
22 {
23 case armnn::DataType::Float16:
24 case armnn::DataType::Float32:
25 return weightsType;
Derek Lambertif90c56d2020-01-10 17:14:08 +000026 case armnn::DataType::QAsymmU8:
Derek Lambertif674aa02019-08-01 15:56:25 +010027 return armnn::DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000028 case armnn::DataType::QSymmS16:
Derek Lambertif674aa02019-08-01 15:56:25 +010029 return armnn::DataType::Signed32;
30 default:
31 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
32 }
33 return armnn::EmptyOptional();
34}
35
Derek Lambertif674aa02019-08-01 15:56:25 +010036template<typename F>
37bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
38{
39 bool supported = rule();
40 if (!supported && reason)
41 {
42 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
43 }
44 return supported;
45}
46
/// Base type for the support-check predicates in this file.
/// A rule is evaluated by invoking it; derived rules store their verdict in m_Res.
struct Rule
{
    /// Verdict of the rule. A rule that never overwrites it is satisfied.
    bool m_Res = true;

    /// @return true if the rule holds.
    bool operator()() const
    {
        const bool holds = m_Res;
        return holds;
    }
};
56
57template<typename T>
Derek Lamberti901ea112019-12-10 22:07:09 +000058bool AllTypesAreEqualImpl(T)
Derek Lambertif674aa02019-08-01 15:56:25 +010059{
60 return true;
61}
62
63template<typename T, typename... Rest>
64bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
65{
66 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
67
68 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
69}
70
71struct TypesAreEqual : public Rule
72{
73 template<typename ... Ts>
74 TypesAreEqual(const Ts&... ts)
75 {
76 m_Res = AllTypesAreEqualImpl(ts...);
77 }
78};
79
80struct QuantizationParametersAreEqual : public Rule
81{
82 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
83 {
84 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
85 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
86 }
87};
88
89struct TypeAnyOf : public Rule
90{
91 template<typename Container>
92 TypeAnyOf(const TensorInfo& info, const Container& c)
93 {
94 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
95 {
96 return dt == info.GetDataType();
97 });
98 }
99};
100
101struct TypeIs : public Rule
102{
103 TypeIs(const TensorInfo& info, DataType dt)
104 {
105 m_Res = dt == info.GetDataType();
106 }
107};
108
Derek Lambertid466a542020-01-22 15:37:29 +0000109struct TypeNotPerAxisQuantized : public Rule
110{
111 TypeNotPerAxisQuantized(const TensorInfo& info)
112 {
113 m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
114 }
115};
116
Derek Lambertif674aa02019-08-01 15:56:25 +0100117struct BiasAndWeightsTypesMatch : public Rule
118{
119 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
120 {
121 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
122 }
123};
124
125struct BiasAndWeightsTypesCompatible : public Rule
126{
127 template<typename Container>
128 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
129 {
130 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
131 {
132 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
133 });
134 }
135};
136
137struct ShapesAreSameRank : public Rule
138{
139 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
140 {
141 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
142 }
143};
144
145struct ShapesAreSameTotalSize : public Rule
146{
147 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
148 {
149 m_Res = info0.GetNumElements() == info1.GetNumElements();
150 }
151};
152
153struct ShapesAreBroadcastCompatible : public Rule
154{
155 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
156 {
157 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
158 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
159 return sizeIn;
160 }
161
162 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
163 {
164 const TensorShape& shape0 = in0.GetShape();
165 const TensorShape& shape1 = in1.GetShape();
166 const TensorShape& outShape = out.GetShape();
167
168 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
169 {
170 unsigned int sizeOut = outShape[i];
171 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
172 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
173
174 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
175 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
176 }
177 }
178};
179
180struct TensorNumDimensionsAreCorrect : public Rule
181{
182 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
183 {
184 m_Res = info.GetNumDimensions() == expectedNumDimensions;
185 }
186};
187
188} //namespace armnn