blob: bf997dbff71fd19c8380497d9df1acf7e68c7fe1 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <boost/assert.hpp>
#include <algorithm>
#include <string>
#include <type_traits>
10
11namespace armnn
12{
13
Derek Lambertif674aa02019-08-01 15:56:25 +010014inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15{
16 if (!weightsType)
17 {
18 return weightsType;
19 }
20
21 switch(weightsType.value())
22 {
23 case armnn::DataType::Float16:
24 case armnn::DataType::Float32:
25 return weightsType;
26 case armnn::DataType::QuantisedAsymm8:
27 return armnn::DataType::Signed32;
28 case armnn::DataType::QuantisedSymm16:
29 return armnn::DataType::Signed32;
30 default:
31 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
32 }
33 return armnn::EmptyOptional();
34}
35
Derek Lambertif674aa02019-08-01 15:56:25 +010036template<typename F>
37bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
38{
39 bool supported = rule();
40 if (!supported && reason)
41 {
42 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
43 }
44 return supported;
45}
46
/// Common base for the layer-support predicates in this file.
/// A derived rule computes its verdict (normally in its constructor) and stores it in m_Res;
/// invoking the rule object returns that verdict. The verdict defaults to true.
struct Rule
{
    bool m_Res{true};

    bool operator()() const
    {
        const bool verdict = m_Res;
        return verdict;
    }
};
56
/// Base case of the AllTypesAreEqualImpl recursion: a single value trivially agrees with itself.
/// The parameter is left unnamed because it is never read (avoids -Wunused-parameter), and is
/// taken by const reference so the terminating call does not copy its argument.
template<typename T>
bool AllTypesAreEqualImpl(const T&)
{
    return true;
}
62
63template<typename T, typename... Rest>
64bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
65{
66 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
67
68 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
69}
70
71struct TypesAreEqual : public Rule
72{
73 template<typename ... Ts>
74 TypesAreEqual(const Ts&... ts)
75 {
76 m_Res = AllTypesAreEqualImpl(ts...);
77 }
78};
79
80struct QuantizationParametersAreEqual : public Rule
81{
82 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
83 {
84 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
85 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
86 }
87};
88
89struct TypeAnyOf : public Rule
90{
91 template<typename Container>
92 TypeAnyOf(const TensorInfo& info, const Container& c)
93 {
94 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
95 {
96 return dt == info.GetDataType();
97 });
98 }
99};
100
101struct TypeIs : public Rule
102{
103 TypeIs(const TensorInfo& info, DataType dt)
104 {
105 m_Res = dt == info.GetDataType();
106 }
107};
108
109struct BiasAndWeightsTypesMatch : public Rule
110{
111 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
112 {
113 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
114 }
115};
116
117struct BiasAndWeightsTypesCompatible : public Rule
118{
119 template<typename Container>
120 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
121 {
122 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
123 {
124 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
125 });
126 }
127};
128
129struct ShapesAreSameRank : public Rule
130{
131 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
132 {
133 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
134 }
135};
136
137struct ShapesAreSameTotalSize : public Rule
138{
139 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
140 {
141 m_Res = info0.GetNumElements() == info1.GetNumElements();
142 }
143};
144
145struct ShapesAreBroadcastCompatible : public Rule
146{
147 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
148 {
149 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
150 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
151 return sizeIn;
152 }
153
154 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
155 {
156 const TensorShape& shape0 = in0.GetShape();
157 const TensorShape& shape1 = in1.GetShape();
158 const TensorShape& outShape = out.GetShape();
159
160 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
161 {
162 unsigned int sizeOut = outShape[i];
163 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
164 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
165
166 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
167 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
168 }
169 }
170};
171
172struct TensorNumDimensionsAreCorrect : public Rule
173{
174 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
175 {
176 m_Res = info.GetNumDimensions() == expectedNumDimensions;
177 }
178};
179
180} //namespace armnn