blob: db3f38ccbb75a6e315017d7e3c81f71fae98d843 [file] [log] [blame]
Derek Lambertif674aa02019-08-01 15:56:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <boost/assert.hpp>
9#include <algorithm>
10
11namespace armnn
12{
13
14namespace
15{
16
17inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
18{
19 if (!weightsType)
20 {
21 return weightsType;
22 }
23
24 switch(weightsType.value())
25 {
26 case armnn::DataType::Float16:
27 case armnn::DataType::Float32:
28 return weightsType;
29 case armnn::DataType::QuantisedAsymm8:
30 return armnn::DataType::Signed32;
31 case armnn::DataType::QuantisedSymm16:
32 return armnn::DataType::Signed32;
33 default:
34 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
35 }
36 return armnn::EmptyOptional();
37}
38
39} //namespace
40
41template<typename F>
42bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
43{
44 bool supported = rule();
45 if (!supported && reason)
46 {
47 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
48 }
49 return supported;
50}
51
/// Base type for all support rules: a callable that reports a stored verdict.
/// Derived rules compute the verdict in their constructors and write it into m_Res.
struct Rule
{
    /// @return The stored verdict (true unless it has been overwritten).
    bool operator()() const
    {
        const bool verdict = m_Res;
        return verdict;
    }

    bool m_Res{true}; // Defaults to "supported"
};
61
/// Base case of the AllTypesAreEqualImpl recursion: a single value trivially agrees with itself.
///
/// The argument is taken by const reference to avoid copying it. It is left unnamed
/// because it is never inspected, which also avoids unused-parameter warnings.
/// @return Always true.
template<typename T>
bool AllTypesAreEqualImpl(const T&)
{
    return true;
}
67
68template<typename T, typename... Rest>
69bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
70{
71 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
72
73 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
74}
75
76struct TypesAreEqual : public Rule
77{
78 template<typename ... Ts>
79 TypesAreEqual(const Ts&... ts)
80 {
81 m_Res = AllTypesAreEqualImpl(ts...);
82 }
83};
84
85struct QuantizationParametersAreEqual : public Rule
86{
87 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
88 {
89 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
90 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
91 }
92};
93
94struct TypeAnyOf : public Rule
95{
96 template<typename Container>
97 TypeAnyOf(const TensorInfo& info, const Container& c)
98 {
99 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
100 {
101 return dt == info.GetDataType();
102 });
103 }
104};
105
106struct TypeIs : public Rule
107{
108 TypeIs(const TensorInfo& info, DataType dt)
109 {
110 m_Res = dt == info.GetDataType();
111 }
112};
113
114struct BiasAndWeightsTypesMatch : public Rule
115{
116 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
117 {
118 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
119 }
120};
121
122struct BiasAndWeightsTypesCompatible : public Rule
123{
124 template<typename Container>
125 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
126 {
127 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
128 {
129 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
130 });
131 }
132};
133
134struct ShapesAreSameRank : public Rule
135{
136 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
137 {
138 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
139 }
140};
141
142struct ShapesAreSameTotalSize : public Rule
143{
144 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
145 {
146 m_Res = info0.GetNumElements() == info1.GetNumElements();
147 }
148};
149
150struct ShapesAreBroadcastCompatible : public Rule
151{
152 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
153 {
154 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
155 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
156 return sizeIn;
157 }
158
159 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
160 {
161 const TensorShape& shape0 = in0.GetShape();
162 const TensorShape& shape1 = in1.GetShape();
163 const TensorShape& outShape = out.GetShape();
164
165 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
166 {
167 unsigned int sizeOut = outShape[i];
168 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
169 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
170
171 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
172 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
173 }
174 }
175};
176
177struct TensorNumDimensionsAreCorrect : public Rule
178{
179 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
180 {
181 m_Res = info.GetNumDimensions() == expectedNumDimensions;
182 }
183};
184
185} //namespace armnn