blob: 2000ce4a57749c05385aea823bd9e91312facb1f [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
33 case DataType::QuantisedAsymm8:
34 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010035 case DataType::QuantisedSymm16:
36 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000037 default:
38 BOOST_ASSERT_MSG(false, "Invalid input data type");
39 return DataType::Float32;
40 }
41}
42
43namespace
44{
45
46//---------------------------------------------------------------
// Stream-based stand-in for std::to_string: the Android NDK does not provide std::to_string.
// Accepts any value with an operator<< overload for std::ostream and returns its default formatting.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    std::string result = formatted.str();
    return result;
}
55
56//---------------------------------------------------------------
57void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
58{
59 if (!ptr)
60 {
61 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
62 paramName + " parameter must be set.");
63 }
64}
65
66//---------------------------------------------------------------
67void ValidateTensorShapesMatch(const TensorInfo& first,
68 const TensorInfo& second,
69 std::string const& descName,
70 std::string const& firstName,
71 std::string const& secondName)
72{
73 if (first.GetShape() != second.GetShape())
74 {
75 throw InvalidArgumentException(descName + ": "
76 + firstName + " & " + secondName + " must have identical shapes");
77 }
78}
79
80//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010081void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000082{
Sadik Armaganeff363d2019-04-05 15:25:46 +010083 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000084 {
85 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010086 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000087 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
88 }
89}
90
91//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010092void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000093{
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000095 {
96 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010097 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000098 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
99 }
100}
101
102//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100103void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000104 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100105 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000106 std::string const& tensorName)
107{
108 if (tensor.GetNumDimensions() != numDimensions)
109 {
110 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
111 to_string(tensor.GetNumDimensions()) + " dimensions for " +
112 tensorName + " tensor.");
113 }
114}
115
116//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100117void ValidateTensorNumElements(const TensorInfo& tensor,
118 std::string const& descName,
119 unsigned int numElements,
120 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100121{
122 if (tensor.GetNumElements() != numElements)
123 {
124 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
125 to_string(tensor.GetNumDimensions()) + " elements for " +
126 tensorName + " tensor.");
127 }
128}
129
130//---------------------------------------------------------------
131void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100132 unsigned int numDimension,
133 unsigned int numElements,
134 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100135{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100136 const std::string functionName{"ValidateTensorNumDimNumElem"};
137 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
138 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100139}
140
141//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000142void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
143 const std::string& descName, std::string const& tensorName)
144{
145 if (tensor.GetDataType() != dataType)
146 {
147 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
148 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
149 }
150}
151
152//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100153void ValidateTensorQuantizationSpace(const TensorInfo& first,
154 const TensorInfo& second,
155 const std::string& descName,
156 std::string const& firstName,
157 std::string const& secondName)
158{
159 if (!first.IsQuantized() ||
160 !second.IsQuantized())
161 {
162 // Not a quantized type, ignore the validation
163 return;
164 }
165
166 DataType firstDataType = first.GetDataType();
167 DataType secondDataType = second.GetDataType();
168
169 if (firstDataType != secondDataType)
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must be of the same quantized type, " +
173 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
174 secondName + " is " + GetDataTypeName(secondDataType));
175 }
176
177 if (!first.IsTypeSpaceMatch(second))
178 {
179 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
180 " must have the same quantization space, " +
181 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
182 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
183 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
184 " and scale " + to_string(second.GetQuantizationScale()));
185 }
186}
187
188//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100189void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
190 const TensorInfo& inputTensorInfo,
191 const TensorInfo& weightsTensorInfo,
192 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000193{
194 if (biasTensor.GetQuantizationOffset() != 0)
195 {
196 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
197 to_string(biasTensor.GetQuantizationOffset()));
198 }
199 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
kevmay016c46dd32018-12-17 15:32:45 +0000200 if (std::abs(biasTensor.GetQuantizationScale() - expectedScale) > 0.00000001f)
telsoa014fcda012018-03-09 14:13:49 +0000201 {
202 // Print the float values with extra precision to see very small differences
203 std::stringstream msg;
204 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
205 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
206 biasTensor.GetQuantizationScale();
207 throw InvalidArgumentException(msg.str());
208 }
209}
210
211//---------------------------------------------------------------
212void ValidateTensors(const std::vector<ITensorHandle*>& vec,
213 unsigned int numExpected,
214 const std::string& descName,
215 const std::string& varName)
216{
217 if (vec.empty() && numExpected > 0)
218 {
219 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
220 }
221
222 for (unsigned int i = 0; i < numExpected; ++i)
223 {
224 if (!vec[i])
225 {
226 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
227 }
228 }
229}
230
231//---------------------------------------------------------------
232void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
233 const TensorInfo& second,
234 const TensorInfo& output,
235 std::string const& descName,
236 std::string const& firstName,
237 std::string const& secondName)
238{
239 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
240 // broadcasted.
241 if (first.GetNumDimensions() != second.GetNumDimensions())
242 {
243 throw InvalidArgumentException(descName + ": Tensors "
244 + firstName + " & " + secondName
245 + " must have the same number of dimensions in order to be broadcasted");
246 }
247 uint32_t numDims = first.GetNumDimensions();
248 std::vector<uint32_t> outputDims(numDims, 0u);
249 for (uint32_t i = 0; i < numDims; i++)
250 {
251 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
252 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
253 if (dimsNotEqual && dimsNotOne)
254 {
255 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
256 }
257 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
258 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100259 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000260 if (broadcastShape != output.GetShape())
261 {
262 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
263 + firstName + " & " + secondName
264 + " does not match the output shape");
265 }
266}
267
268//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100269void ValidateDataTypes(const TensorInfo& info,
270 const std::vector<armnn::DataType>& supportedTypes,
271 std::string const& descName)
272{
273 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
274 if (iterator == supportedTypes.end())
275 {
276 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
277 }
278}
279
James Conroy4d1ff582019-06-10 17:06:39 +0100280//---------------------------------------------------------------
281void ValidateTensorDataTypesMatch(const TensorInfo& first,
282 const TensorInfo& second,
283 std::string const& descName,
284 std::string const& firstName,
285 std::string const& secondName)
286{
287 if (first.GetDataType() != second.GetDataType())
288 {
289 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
290 " must have identical data types.");
291 }
292}
293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100294//---------------------------------------------------------------
295void ValidateTensorNumElementsMatch(const TensorInfo& first,
296 const TensorInfo& second,
297 std::string const& descName,
298 std::string const& firstName,
299 std::string const& secondName)
300{
301 if (first.GetNumElements() != second.GetNumElements())
302 {
303 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
304 " must have the same number of elements.");
305 }
306}
307
308} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000309
310void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
311 unsigned int numExpectedIn, unsigned int numExpectedOut) const
312{
313 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
314 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
315}
316
317//---------------------------------------------------------------
318void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
319{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100320 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100322 ValidateNumInputs(workloadInfo, descriptorName, 1);
323 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100325 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
326 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
327
328 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
329 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000330
331 if (m_Inputs.size() != m_Outputs.size())
332 {
333 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100334 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
335 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000336 }
337
338 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
339 {
340 if (!m_Inputs[i])
341 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100342 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
343 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000344 }
345
346 if (!m_Outputs[i])
347 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100348 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
349 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000350 }
351 }
352}
353
telsoa014fcda012018-03-09 14:13:49 +0000354void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
355{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100356 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100357
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100358 ValidateNumInputs(workloadInfo, descriptorName, 1);
359 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100360
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100361 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
362 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100363
364 std::vector<DataType> supportedTypes =
365 {
366 DataType::Float16,
367 DataType::Float32,
368 DataType::QuantisedAsymm8,
369 DataType::QuantisedSymm16
370 };
371
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100372 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
373 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
374 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000375}
376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100377void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
378{
379 const std::string descriptorName{"SoftmaxQueueDescriptor"};
380
381 ValidateNumInputs(workloadInfo, descriptorName, 1);
382 ValidateNumOutputs(workloadInfo, descriptorName, 1);
383
384 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
385 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
386
387 std::vector<DataType> supportedTypes =
388 {
389 DataType::Float16,
390 DataType::Float32,
391 DataType::QuantisedAsymm8,
392 DataType::QuantisedSymm16
393 };
394
395 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
396 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
397 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
398}
399
telsoa014fcda012018-03-09 14:13:49 +0000400void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
401{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100402 const std::string descriptorName{"SplitterQueueDescriptor"};
403
404 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000405
Ruomei Yan25339c32019-05-28 16:48:20 +0100406 // Check the supported data types
407 std::vector<DataType> supportedTypes =
408 {
409 DataType::Float32,
410 DataType::Float16,
411 DataType::Boolean,
412 DataType::Signed32,
413 DataType::QuantisedAsymm8,
414 DataType::QuantisedSymm16
415 };
416
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100417 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
418 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100419 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100420 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
421 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
422
423 const std::string outputName = "output_" + std::to_string(i);
424 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100425 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100426
telsoa014fcda012018-03-09 14:13:49 +0000427 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
428 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100429 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000430 }
431
432 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
433 {
434 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100435 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000436 "has to match number of workloadInfo.m_OutputTensorInfos. "
437 "Number of windows: " +
438 to_string(m_ViewOrigins.size()) +
439 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
440 }
441
telsoa01c577f2c2018-08-31 09:22:23 +0100442 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000443 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
444 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
445 {
telsoa01c577f2c2018-08-31 09:22:23 +0100446 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000447 ViewOrigin const& e = m_ViewOrigins[w];
448 if (e.m_Origin.size() != inputDims)
449 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100450 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000451 "have the same dimensionality as the input tensor. "
452 "Window origin (index: " +
453 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
454 " dimensions, the input "
455 "tensor has " +
456 to_string(inputDims) + " dimensions.");
457 }
458 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
459 {
460 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
461 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
462 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100463 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000464 "be smaller or equal than the size of the input in that coord.");
465 }
466 }
467 }
468}
469
Jim Flynne242f2d2019-05-22 14:24:13 +0100470void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000471{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100472 const std::string descriptorName{"ConcatQueueDescriptor"};
473
474 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000475
476 if (m_Inputs.size() <= 0)
477 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100478 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000479 }
480 if (m_Outputs.size() <= 0)
481 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100482 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000483 }
484
485 if (workloadInfo.m_InputTensorInfos.size() <= 0)
486 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100487 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000488 }
489 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
490 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100491 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000492 }
493
Nikhil Raj8599a412018-11-19 14:51:07 +0000494 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
495 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100496 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000497 }
498
499 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
500 {
501 return;
502 }
503
telsoa014fcda012018-03-09 14:13:49 +0000504 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
505 {
506 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100507 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000508 "has to match number of workloadInfo.m_InputTensorInfos. "
509 "Number of windows: " +
510 to_string(m_ViewOrigins.size()) +
511 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
512 }
513
telsoa01c577f2c2018-08-31 09:22:23 +0100514 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000515 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
516 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
517 {
telsoa01c577f2c2018-08-31 09:22:23 +0100518 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000519 ViewOrigin const& e = m_ViewOrigins[w];
520 if (e.m_Origin.size() != outputDims)
521 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100522 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000523 "have the same dimensionality as the output tensor. "
524 "Window origin (index: " +
525 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
526 " dimensions, the output "
527 "tensor has " +
528 to_string(outputDims) + " dimensions.");
529 }
telsoa01c577f2c2018-08-31 09:22:23 +0100530 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000531 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
532 {
533 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
534 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
535 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100536 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000537 "be smaller or equal than the size of the output in that coord.");
538 }
539 }
540 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100541
542 // Check the supported data types
543 std::vector<DataType> supportedTypes =
544 {
545 DataType::Float32,
546 DataType::Float16,
547 DataType::Boolean,
548 DataType::Signed32,
549 DataType::QuantisedAsymm8,
550 DataType::QuantisedSymm16
551 };
552
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100553 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
554 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100555 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100556 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
557 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
558
559 const std::string inputName = "input_" + std::to_string(i);
560 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100561 }
telsoa014fcda012018-03-09 14:13:49 +0000562}
563
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100564void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
565{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100566 const std::string descriptorName{"StackQueueDescriptor"};
567
568 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100569
570 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
571 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100572 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100573 }
574
575 // All inputs must have the same shape, which is defined in parameters
576 const TensorShape& inputShape = m_Parameters.m_InputShape;
577 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
578 {
579 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
580 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100581 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100582 }
583 }
584
585 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
586 // since the output tensor has an additional dimension.
587 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
588 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100589 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100590 "than the number of input dimensions.");
591 }
592
593 // Output shape must be as inferred from the input shape
594 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
595 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
596 {
597 if (outputShape[i] != inputShape[i])
598 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100599 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100600 "match shape inferred from input tensor.");
601 }
602 }
603
604 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
605 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100607 "match shape inferred from input tensor.");
608 }
609
610 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
611 {
612 if (outputShape[i] != inputShape[i-1])
613 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100614 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100615 "match shape inferred from input tensor.");
616 }
617 }
618
619 // Check the supported data types
620 std::vector<DataType> supportedTypes =
621 {
622 DataType::Float32,
623 DataType::Float16,
624 DataType::Boolean,
625 DataType::Signed32,
626 DataType::QuantisedAsymm8,
627 DataType::QuantisedSymm16
628 };
629
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100630 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100631
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100632 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100633 {
634 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
635 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100636 descriptorName,
637 "input_0",
638 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100639 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100640
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100641 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
642 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100643 descriptorName,
644 "input_0",
645 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100646}
647
telsoa014fcda012018-03-09 14:13:49 +0000648void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
649{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100650 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000651
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100652 ValidateNumInputs(workloadInfo, descriptorName, 1);
653 ValidateNumOutputs(workloadInfo, descriptorName, 1);
654
655 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
656 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
657
658 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
659
660 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000661 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100662 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000663 }
664
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100665 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000666
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100667 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
668 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000669
670 if (m_Parameters.m_BiasEnabled)
671 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100672 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000673
telsoa01c577f2c2018-08-31 09:22:23 +0100674 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100675 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
676 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000677
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100678 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
679 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000680 }
681
Francis Murtagh46c09d02019-05-28 08:15:28 +0100682 // Check the supported data types
683 std::vector<DataType> supportedTypes =
684 {
685 DataType::Float32,
686 DataType::Float16,
687 DataType::QuantisedAsymm8,
688 DataType::QuantisedSymm16
689 };
690
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100691 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
692 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000693}
694
telsoa014fcda012018-03-09 14:13:49 +0000695void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
696{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100697 const std::string descriptorName{"NormalizationQueueDescriptor"};
698
699 ValidateNumInputs(workloadInfo, descriptorName, 1);
700 ValidateNumOutputs(workloadInfo, descriptorName, 1);
701
702 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
703 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100704
705 // Check the supported data types
706 std::vector<DataType> supportedTypes =
707 {
708 DataType::Float16,
709 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100710 DataType::QuantisedAsymm8,
711 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100712 };
713
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100714 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100715
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100716 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100717
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100718 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000719}
720
721void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
722{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100723 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000724
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100725 ValidateNumInputs(workloadInfo, descriptorName, 2);
726 ValidateNumOutputs(workloadInfo, descriptorName, 1);
727
728 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
729 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
730 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
731
732 std::vector<DataType> supportedTypes =
733 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100734 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100735 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100736 DataType::QuantisedSymm16,
737 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100738 };
739
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100740 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
741 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
742 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100743
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100744 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
745 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100746
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100747 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
748 inputTensorInfo1,
749 outputTensorInfo,
750 descriptorName,
751 "input_0",
752 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000753}
754
telsoa014fcda012018-03-09 14:13:49 +0000755void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
756{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100757 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +0100758
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100759 ValidateNumInputs(workloadInfo, descriptorName, 2);
760 ValidateNumOutputs(workloadInfo, descriptorName, 1);
761
762 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
763 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
764 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
765
766 std::vector<DataType> supportedTypes =
767 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100768 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100769 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100770 DataType::QuantisedSymm16,
771 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100772 };
773
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100774 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
775 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
776 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100777
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
779 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100780
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100781 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
782 inputTensorInfo1,
783 outputTensorInfo,
784 descriptorName,
785 "input_0",
786 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000787}
788
789void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
790{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100791 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100792
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100793 ValidateNumInputs(workloadInfo, descriptorName, 1);
794 ValidateNumOutputs(workloadInfo, descriptorName, 1);
795
796 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
797 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100798
799 std::vector<DataType> supportedTypes =
800 {
801 DataType::Float16,
802 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100803 DataType::QuantisedAsymm8,
804 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100805 };
806
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100807 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
808 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100809
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100810 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
811 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
812 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100813
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100814 ValidatePointer(m_Mean, descriptorName, "mean");
815 ValidatePointer(m_Variance, descriptorName, "variance");
816 ValidatePointer(m_Beta, descriptorName, "beta");
817 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000818
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100819 const TensorInfo& mean = m_Mean->GetTensorInfo();
820 const TensorInfo& variance = m_Variance->GetTensorInfo();
821 const TensorInfo& beta = m_Beta->GetTensorInfo();
822 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100824 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
825 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
826 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
827 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000828
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100829 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
830 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
831 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000832}
833
834void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
835{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100836 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000837
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100838 ValidateNumInputs(workloadInfo, descriptorName, 1);
839 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000840
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
842 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +0000843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100844 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
845 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +0000846
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100847 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000848
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100849 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
850 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000851
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100852 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
telsoa014fcda012018-03-09 14:13:49 +0000853
854 if (m_Parameters.m_BiasEnabled)
855 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100856 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000857
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100858 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
859 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
860
861 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
862 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000863 }
864
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100865 std::vector<DataType> supportedTypes =
866 {
Ruomei Yan88d44b82019-05-23 14:29:06 +0100867 DataType::Float32,
868 DataType::QuantisedAsymm8,
869 DataType::QuantisedSymm16,
870 DataType::Float16
871 };
872
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100873 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
874 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
875}
Ruomei Yan88d44b82019-05-23 14:29:06 +0100876
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
878{
879 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
880
881 ValidateNumInputs(workloadInfo, descriptorName, 1);
882 ValidateNumOutputs(workloadInfo, descriptorName, 1);
883
884 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
885 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
886
887 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
888 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
889
890 ValidatePointer(m_Weight, descriptorName, "weight");
891
892 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
893 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
894
895 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
896 {
897 throw InvalidArgumentException(
898 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
899 "cannot be smaller than 1.") % descriptorName %
900 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
901 }
902
903 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
904
905 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
906 // inputChannels * channelMultiplier should be equal to outputChannels.
907 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
908 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
909 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
910 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
911 {
912 throw InvalidArgumentException(
913 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
914 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
915 "(provided %4%).") % descriptorName % numWeightOutputChannels %
916 numWeightInputChannels % numWeightChannelMultiplier));
917 }
918
919 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
920
921 if (m_Parameters.m_BiasEnabled)
922 {
923 ValidatePointer(m_Bias, descriptorName, "bias");
924
925 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
926 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
927
928 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
929 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
930 }
931
932 std::vector<DataType> supportedTypes =
933 {
934 DataType::Float32,
935 DataType::QuantisedAsymm8,
936 DataType::QuantisedSymm16,
937 DataType::Float16
938 };
939
940 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
941 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000942}
943
944void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
945{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100946 const std::string descriptorName{"PermuteQueueDescriptor"};
947
948 ValidateNumInputs(workloadInfo, descriptorName, 1);
949 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000950
951 const PermutationVector& mapping = m_Parameters.m_DimMappings;
952
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
954 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +0000955
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100956 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
957 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +0000958
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +0000960 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +0000962 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100963 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
964 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
965 "must match dst dimension " + to_string(mapping[i]) +
966 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +0000967 }
968 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969
970 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000971}
972
973void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
974{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100975 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000976
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100977 ValidateNumInputs(workloadInfo, descriptorName, 1);
978 ValidateNumOutputs(workloadInfo, descriptorName, 1);
979
980 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
981 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
982
983 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
984 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +0100985
986 std::vector<DataType> supportedTypes =
987 {
988 DataType::Float32,
989 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +0100990 DataType::QuantisedAsymm8,
991 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +0100992 };
993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100994 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
995 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000996}
997
998void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
999{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001000 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001001
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001002 ValidateNumInputs(workloadInfo, descriptorName, 1);
1003 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1004
1005 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1006 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1007
1008 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1009 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001010
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001011 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001012 {
1013 DataType::Float16,
1014 DataType::Float32,
1015 DataType::QuantisedAsymm8,
1016 DataType::QuantisedSymm16
1017 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001018
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001019 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1020 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001021
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001022 // ResizeBilinear only changes width and height: batch and channel count must match.
1023 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1024 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001025 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001026 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001027 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001028 boost::str(boost::format("%1%: Input batch size (%2%) "
1029 "does not match output batch size (%3%)") %
1030 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001031 }
1032
Teresa Charlin970f43b2019-07-01 13:51:07 +01001033 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1035 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001036 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001037 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001038 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001039 boost::str(boost::format("%1%: Input channel count (%2%) "
1040 "does not match output channel count (%3%)") %
1041 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001042 }
1043}
1044
1045void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1046{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001047 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001048
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001049 ValidateNumInputs(workloadInfo, descriptorName, 1);
1050 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1051
1052 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1053 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1054
1055 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1056 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001057
1058 std::vector<DataType> supportedTypes =
1059 {
1060 DataType::Float16,
1061 DataType::Float32,
1062 DataType::QuantisedAsymm8,
1063 DataType::QuantisedSymm16
1064 };
1065
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1067 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001068
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001069 // Resize only changes width and height: batch and channel count must match.
1070 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1071 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001072 if (inputBatchSize != outputBatchSize)
1073 {
1074 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001075 boost::str(boost::format("%1%: Input batch size (%2%) "
1076 "does not match output batch size (%3%)") %
1077 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001078 }
1079
1080 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001081 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1082 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001083 if (inputChannelCount != outputChannelCount)
1084 {
1085 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 boost::str(boost::format("%1%: Input channel count (%2%) "
1087 "does not match output channel count (%3%)") %
1088 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001089 }
1090}
1091
1092void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1093{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001094 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001095
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001096 ValidateNumInputs(workloadInfo, descriptorName, 1);
1097 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1098
1099 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1100 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1101
1102 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1103 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1104
1105 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1106
telsoa014fcda012018-03-09 14:13:49 +00001107 if (m_Parameters.m_Min > m_Parameters.m_Max)
1108 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001109 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001110 }
telsoa014fcda012018-03-09 14:13:49 +00001111}
1112
1113void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1114{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001115 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001116
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001117 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001118 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1119
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001120 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1121 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1122
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001123 if (inputTensorInfo.GetNumDimensions() > 4)
1124 {
1125 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1126 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001127
1128 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001129
1130 // Check the supported data types
1131 std::vector<DataType> supportedTypes =
1132 {
1133 DataType::Float32,
1134 DataType::Float16,
1135 DataType::QuantisedAsymm8,
1136 DataType::QuantisedSymm16
1137 };
1138
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001139 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1140 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1141
1142 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001143}
1144
1145void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1146{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001147 const std::string descriptorName{"ConstantQueueDescriptor"};
1148
1149 ValidateNumInputs(workloadInfo, descriptorName, 0);
1150 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001151
1152 if (!m_LayerOutput)
1153 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001155 }
1156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1158 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001159
1160 // Check the supported data types
1161 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001162 {
1163 DataType::Float32,
1164 DataType::Float16,
1165 DataType::Signed32,
1166 DataType::QuantisedAsymm8,
1167 DataType::QuantisedSymm16
1168 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001169
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001170 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001171}
1172
1173void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1174{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001175 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001176
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001177 ValidateNumInputs(workloadInfo, descriptorName, 1);
1178 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1179
1180 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1181 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1182
1183 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001184
1185 // Check the supported data types
1186 std::vector<DataType> supportedTypes =
1187 {
1188 DataType::Float32,
1189 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001190 DataType::QuantisedAsymm8,
1191 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001192 };
1193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1195 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001196}
1197
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001198void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1199{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001200 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001201
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001202 ValidateNumInputs(workloadInfo, descriptorName, 1);
1203 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1204
1205 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1206 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1207
1208 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1209 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001210
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001211 if (m_Parameters.m_BlockShape.size() != 2)
1212 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001213 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001214 }
1215
1216 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1217 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001218 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1219 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001220 }
1221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001223
1224 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001226
Matthew Bentham8800c002018-11-19 13:19:28 +00001227 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001228
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1230 widthPad.first + widthPad.second;
1231 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1232 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001233
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001234 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1235 inputShape[dimensionIndices.GetChannelsIndex()];
1236 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001237
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001238 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001239 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001241 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001242 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001243 }
1244
1245 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001246 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1248 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001249 }
nikraj01120522a2019-05-31 11:33:07 +01001250
1251 std::vector<DataType> supportedTypes =
1252 {
1253 DataType::Float16,
1254 DataType::Float32,
1255 DataType::QuantisedAsymm8,
1256 DataType::QuantisedSymm16
1257 };
1258
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001259 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1260 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001261}
1262
Keith Davisa57eccb2019-06-14 17:33:22 +01001263void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1264{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001265 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001266
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001267 ValidateNumInputs(workloadInfo, descriptorName, 1);
1268 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001269
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001270 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1271 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1272
1273 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1274 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001275
1276 std::vector<DataType> supportedTypes =
1277 {
1278 DataType::Float32,
1279 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001280 DataType::QuantisedAsymm8,
1281 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001282 };
1283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001284 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1285 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001287 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1288 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1289 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1290 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001291
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001292 const TensorShape& inputShape = inputTensorInfo.GetShape();
Keith Davisa57eccb2019-06-14 17:33:22 +01001293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001294 const unsigned int numInputElements =
1295 inputShape[0] * inputShape[wIndex] * inputShape[hIndex] * inputShape[cIndex];
1296 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
1297
1298 if (numOutputElements != numInputElements)
Keith Davisa57eccb2019-06-14 17:33:22 +01001299 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001300 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1301 std::to_string(numInputElements) + " but output tensor has " +
1302 std::to_string(numOutputElements) + " elements.");
Keith Davisa57eccb2019-06-14 17:33:22 +01001303 }
1304
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001305 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001306 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001307 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1308 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001309 }
1310}
1311
telsoa014fcda012018-03-09 14:13:49 +00001312void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1313{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001314 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001315
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001316 ValidateNumInputs(workloadInfo, descriptorName, 1);
1317 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1318
1319 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1320 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001321
1322 std::vector<DataType> supportedTypes =
1323 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001324 DataType::Float32,
1325 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001326 };
1327
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001328 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001329
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001330 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001331 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001332 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001333 }
1334}
1335
telsoa01c577f2c2018-08-31 09:22:23 +01001336void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1337{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001338 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1339
1340 const std::string descriptorName{"LstmQueueDescriptor"};
1341
1342 // check dimensions of all inputs and outputs
1343 if (workloadInfo.m_InputTensorInfos.size() != 3)
1344 {
1345 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1346 }
1347 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1348 {
1349 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1350 }
1351
1352 std::vector<DataType> supportedTypes =
1353 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001354 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001355 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001356 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001357 };
1358
Jan Eilers38e05bd2019-06-26 13:10:09 +01001359 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001360 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1361
Jan Eilers38e05bd2019-06-26 13:10:09 +01001362 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001363 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001364 {
1365 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1366 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001367 descriptorName,
1368 "input_0",
1369 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001370 }
1371 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001372 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001373 {
1374 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1375 workloadInfo.m_OutputTensorInfos[i],
1376 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001377 "input_0",
1378 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001379 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001380
Jan Eilers38e05bd2019-06-26 13:10:09 +01001381 // TODO: check clipping parameter is valid
1382
1383 // Inferring batch size, number of outputs and number of cells from the inputs.
1384 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1385 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1386 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1387 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1388 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1389 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1390 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1391
Jan Eilers38e05bd2019-06-26 13:10:09 +01001392 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001393 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1394 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001395 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001396 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1397 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001398 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001399 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1400 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001401 // scratchBufferTensor
1402 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1404 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001405 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001406 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1407 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001408 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001409 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1410 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001411 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001412 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1413 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001414
1415
1416 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1417 if ( m_InputToInputWeights )
1418 {
1419 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1420 (n_cell * n_input), "InputLayerNormWeights");
1421 }
1422
1423 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1424 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1425 (n_cell * n_input), "InputToForgetWeights");
1426
1427 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1428 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1429 (n_cell * n_input), "InputToCellWeights");
1430
1431 if ( m_RecurrentToInputWeights )
1432 {
1433 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1434 (n_cell * n_output), "RecurrentToInputWeights");
1435 }
1436
1437 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1438 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1439 (n_cell * n_output), "RecurrentToForgetWeights");
1440
1441 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1442 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1443 (n_cell * n_output), "RecurrentToCellWeights");
1444
1445 // Make sure the input-gate's parameters are either both present (regular
1446 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1447 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1448 !m_Parameters.m_CifgEnabled) ||
1449 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1450 m_Parameters.m_CifgEnabled));
1451 if (!cifg_weights_all_or_none)
1452 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001453 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1454 "RecurrentToInputWeights must either both be present (regular LSTM) "
1455 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1456 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001457 }
1458
1459 if ( m_CellToInputWeights )
1460 {
1461 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1462 n_cell, "CellToInputWeights");
1463 }
1464 if ( m_CellToForgetWeights )
1465 {
1466 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1467 n_cell, "CellToForgetWeights");
1468 }
1469 if ( m_CellToOutputWeights )
1470 {
1471 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1472 n_cell, "CellToOutputWeights");
1473 }
1474
1475 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1476 bool peephole_weights_all_or_none =
1477 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1478 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1479 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1480 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1481 if (!peephole_weights_all_or_none)
1482 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001483 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001484 }
1485
1486 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1487 if (m_Parameters.m_CifgEnabled)
1488 {
1489 if (m_InputGateBias)
1490 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001491 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001492 }
1493 }
1494 else
1495 {
1496 if (!m_InputGateBias)
1497 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001498 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1499 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001500 }
1501 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1502 n_cell, "InputGateBias");
1503 }
1504
1505 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1506 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1507
1508 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1509 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1510
1511 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1512 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1513
1514 if (m_ProjectionWeights)
1515 {
1516 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1517 (n_cell * n_output), "ProjectionWeights");
1518 }
1519 if (m_ProjectionBias)
1520 {
1521 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1522 }
1523
1524 // Making sure the projection tensors are consistent:
1525 // 1) If projection weight is not present, then projection bias should not be
1526 // present.
1527 // 2) If projection weight is present, then projection bias is optional.
1528 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1529 !m_Parameters.m_ProjectionEnabled)
1530 || (m_ProjectionWeights && !m_ProjectionBias &&
1531 m_Parameters.m_ProjectionEnabled)
1532 || (m_ProjectionWeights && m_ProjectionBias &&
1533 m_Parameters.m_ProjectionEnabled));
1534 if (!projecton_tensors_consistent)
1535 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001536 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001537 }
1538
1539 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1540 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1541 // either all have values or none of them have values. Layer normalization is used when the values of all the
1542 // layer normalization weights are present
1543 if (m_InputLayerNormWeights)
1544 {
1545 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1546 }
1547 if (m_ForgetLayerNormWeights)
1548 {
1549 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1550 }
1551 if (m_CellLayerNormWeights)
1552 {
1553 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1554 }
1555 if (m_OutputLayerNormWeights)
1556 {
1557 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1558 }
1559
Jan Eilers38e05bd2019-06-26 13:10:09 +01001560 if (m_Parameters.m_LayerNormEnabled)
1561 {
1562 if (!m_Parameters.m_CifgEnabled)
1563 {
1564 if (!m_InputLayerNormWeights)
1565 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001566 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1567 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001568 }
1569 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1570 1, n_cell, "InputLayerNormWeights");
1571 }
1572 else if (m_InputLayerNormWeights)
1573 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001574 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1575 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001576 }
1577
1578 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1579 "ForgetLayerNormWeights");
1580 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1581
1582 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1583 "OutputLayerNormWeights");
1584 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1585
1586 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1587 "CellLayerNormWeights");
1588 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1589 }
1590 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1591 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001592 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1593 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001594 }
telsoa01c577f2c2018-08-31 09:22:23 +01001595}
1596
1597void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1598{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001599 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001600
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 ValidateNumInputs(workloadInfo, descriptorName, 1);
1602 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1603
1604 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1605 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1606
1607 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001608 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001609 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001610 }
1611
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001612 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001613 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001614 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001615 }
1616
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001617 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001618}
1619
1620void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1621{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001622 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001624 ValidateNumInputs(workloadInfo, descriptorName, 1);
1625 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1626
1627 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1628 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1629
1630 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001631 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001632 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001633 }
1634
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001635 if (outputTensorInfo.GetDataType() != DataType::Float32)
1636 {
1637 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1638 }
1639
1640 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001641}
1642
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001643void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1644{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001645 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001646
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001647 ValidateNumInputs(workloadInfo, descriptorName, 2);
1648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1649
1650 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1651 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1652 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1653
1654 std::vector<DataType> supportedTypes =
1655 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001656 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001657 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001658 DataType::QuantisedSymm16,
1659 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001660 };
1661
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001662 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1663 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1664 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001665
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001666 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1667 inputTensorInfo1,
1668 outputTensorInfo,
1669 descriptorName,
1670 "input_0",
1671 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001672}
1673
David Beckc2044fe2018-09-05 15:00:38 +01001674void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1675{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01001677
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001678 ValidateNumInputs(workloadInfo, descriptorName, 2);
1679 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1680
1681 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1682 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1683 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1684
1685 std::vector<DataType> supportedTypes =
1686 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001687 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001688 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001689 DataType::QuantisedSymm16,
1690 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001691 };
1692
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001693 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1694 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1695 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001697 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1698 inputTensorInfo1,
1699 outputTensorInfo,
1700 descriptorName,
1701 "input_0",
1702 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01001703}
1704
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001705void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1706{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001707 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001708
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 ValidateNumInputs(workloadInfo, descriptorName, 2);
1710 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1711
1712 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1713 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1714 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1715
1716 std::vector<DataType> supportedTypes =
1717 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001718 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001719 DataType::QuantisedAsymm8,
1720 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001721 };
1722
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001723 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1724 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1725 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001726
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001727 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1728 inputTensorInfo1,
1729 outputTensorInfo,
1730 descriptorName,
1731 "input_0",
1732 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001733}
1734
narpra01a6bf9122018-09-10 09:50:09 +01001735void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1736{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001737 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01001738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 ValidateNumInputs(workloadInfo, descriptorName, 1);
1740 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1741
1742 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1743 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01001744
1745 std::vector<DataType> supportedTypes =
1746 {
1747 DataType::Float32,
1748 DataType::Float16,
1749 DataType::QuantisedAsymm8,
1750 DataType::QuantisedSymm16
1751 };
narpra01eb061912018-09-10 17:35:27 +01001752
James Conroy4d1ff582019-06-10 17:06:39 +01001753 // First check if input tensor data type is supported, then
1754 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001755 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1756 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01001757
narpra0132b90462018-09-13 11:07:48 +01001758 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01001759 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001760 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01001761 }
narpra0132b90462018-09-13 11:07:48 +01001762 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01001763 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001764 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01001765 }
1766 else
1767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001768 unsigned int outputDim =
1769 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
1770 ValidateTensorNumDimensions(outputTensorInfo,
1771 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01001772 outputDim > 0 ? outputDim : 1,
1773 "output");
1774 }
narpra01a6bf9122018-09-10 09:50:09 +01001775}
1776
jimfly012c9322a2018-09-19 10:59:49 +01001777void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1778{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01001780
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001781 ValidateNumInputs(workloadInfo, descriptorName, 1);
1782 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1783
1784 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1785 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01001786
jimfly012c9322a2018-09-19 10:59:49 +01001787 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001788 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
1789
jimfly012c9322a2018-09-19 10:59:49 +01001790 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001791 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
1792 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
1793 "as there are dimensions in the input tensor that is " +
1794 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
1795 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01001796 }
1797}
1798
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001799void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1800{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001801 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001802
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001803 ValidateNumInputs(workloadInfo, descriptorName, 1);
1804 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001805
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001806 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1807 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1808
1809 if (inputTensorInfo.GetDataType() != DataType::Float32)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001810 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001811 throw InvalidArgumentException(descriptorName + ": Quantize only accepts Float32 inputs.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001812 }
1813
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001814 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
1815 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001816 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001817 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001818 }
1819}
1820
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001821void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1822{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001823 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001824
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001825 ValidateNumInputs(workloadInfo, descriptorName, 1);
1826 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001827
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001828 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1829 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001830
1831 std::vector<DataType> supportedTypes =
1832 {
1833 DataType::Float32,
1834 DataType::QuantisedAsymm8,
1835 DataType::QuantisedSymm16
1836 };
1837
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001838 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1839 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001840}
1841
Conor Kennedy430b5d82018-11-14 15:28:28 +00001842void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1843{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00001845
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 ValidateNumInputs(workloadInfo, descriptorName, 1);
1847 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1848
1849 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1850 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001851
1852 std::vector<DataType> supportedTypes =
1853 {
1854 DataType::Float16,
1855 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01001856 DataType::QuantisedAsymm8,
1857 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001858 };
1859
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1861 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001862
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001863 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001864
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001865 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001866 if (rank > 4)
1867 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001868 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001869 }
1870
Conor Kennedy430b5d82018-11-14 15:28:28 +00001871 // Begin, End & Stride length must be of rank(input0)
1872 if (m_Parameters.m_Begin.size() != rank)
1873 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001874 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001875 }
1876
1877 if (m_Parameters.m_End.size() != rank)
1878 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001879 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001880 }
1881
1882 if (m_Parameters.m_Stride.size() != rank)
1883 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001884 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001885 }
1886
1887 // Stride entries must be non-zero
1888 for (auto& stride : m_Parameters.m_Stride)
1889 {
1890 if (stride == 0)
1891 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001892 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00001893 }
1894 }
1895}
1896
kevmay0190539692018-11-29 08:40:19 +00001897void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1898{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001899 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00001900
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001901 ValidateNumInputs(workloadInfo, descriptorName, 2);
1902 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1903
1904 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1905 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1906 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1907
1908 std::vector<DataType> supportedTypes =
1909 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001910 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001911 DataType::QuantisedAsymm8,
1912 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001913 };
1914
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001915 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1916 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1917 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001918
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001919 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1920 inputTensorInfo1,
1921 outputTensorInfo,
1922 descriptorName,
1923 "input_0",
1924 "input_1");
kevmay0190539692018-11-29 08:40:19 +00001925}
1926
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00001927void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1928{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001929 const std::string descriptorName{"DebugQueueDescriptor"};
1930
1931 ValidateNumInputs(workloadInfo, descriptorName, 1);
1932 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00001933}
1934
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00001935void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1936{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001937 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00001938
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001939 ValidateNumInputs(workloadInfo, descriptorName, 2);
1940 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00001941
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001942 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1943 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1944 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1945
1946 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1947 inputTensorInfo1,
1948 outputTensorInfo,
1949 descriptorName,
1950 "input_0",
1951 "input_1");
1952
1953 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00001954 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001955 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00001956 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00001957}
1958
FrancisMurtagh878f0232018-12-19 10:56:15 +00001959void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1960{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001961 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00001962
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001963 ValidateNumInputs(workloadInfo, descriptorName, 2);
1964 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00001965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001966 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1967 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1968 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1969
1970 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1971 inputTensorInfo1,
1972 outputTensorInfo,
1973 descriptorName,
1974 "input_0",
1975 "input_1");
1976
1977 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00001978 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001979 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00001980 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00001981}
1982
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001983void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1984{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001985 const std::string descriptorName{"RsqrtQueueDescriptor"};
1986
1987 ValidateNumInputs(workloadInfo, descriptorName, 1);
1988 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1989
1990 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1991 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1992
1993 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01001994
1995 std::vector<DataType> supportedTypes =
1996 {
1997 DataType::Float16,
1998 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001999 DataType::QuantisedAsymm8,
2000 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002001 };
2002
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002003 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2004 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002005}
2006
narpra01b89b05f2019-01-16 09:53:09 +00002007void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2008{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002009 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002011 ValidateNumInputs(workloadInfo, descriptorName, 2);
2012 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002013
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002014 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2015 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002016 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002017 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002018 }
2019
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002020 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2021 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2022
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002023 std::vector<DataType> supportedTypes =
2024 {
2025 DataType::Float16,
2026 DataType::Float32,
2027 DataType::QuantisedAsymm8,
2028 DataType::QuantisedSymm16
2029 };
2030
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002031 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002032
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002033 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002034
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002035 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2036 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002037}
2038
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002039void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2040{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002041 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2042
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002043 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002044
2045 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2046 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002047 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002048 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2049 }
2050
2051 if (m_Anchors == nullptr)
2052 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002053 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002054 }
2055
2056 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002057 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2058 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2059
2060 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002061 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002062 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2063 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002064
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002065 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2066 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2067 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002068
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002069 const std::vector<DataType> supportedInputTypes =
2070 {
2071 DataType::Float32,
2072 DataType::QuantisedAsymm8,
2073 DataType::QuantisedSymm16
2074 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002075
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002076 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2077 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2078 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2079
2080 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2081 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2082 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2083 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2084
2085 // NOTE: Output is always Float32 regardless of input type
2086 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2087 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2088 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2089 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002090
2091 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2092 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002093 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002094 "must be positive and less than or equal to 1.");
2095 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002097 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2098 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002099 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002100 "should be equal to number of classes + 1.");
2101 }
2102}
2103
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002104void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2105{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002106 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002107
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002108 ValidateNumInputs(workloadInfo, descriptorName, 1);
2109 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2110
2111 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2112 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2113
2114 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2115 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002116 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002117 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002118 }
2119
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002120 if (outputTensorInfo.GetDataType() != DataType::Float32)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002121 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002122 throw InvalidArgumentException(descriptorName + ": Output of dequantize layer must be Float32 type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002123 }
2124}
2125
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002126void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2127{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002129
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002130 ValidateNumInputs(workloadInfo, descriptorName, 2);
2131 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002132
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002133 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2134 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2135 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002136
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002137 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2138 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2139
2140 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2141 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002142}
2143
Sadik Armaganeff363d2019-04-05 15:25:46 +01002144void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2145{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002147
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002148 ValidateNumInputs(workloadInfo, descriptorName, 2);
2149 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2150
2151 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2152 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2153
2154 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2155 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2156
2157 std::vector<DataType> supportedTypes =
2158 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002159 DataType::Float32,
2160 DataType::QuantisedAsymm8,
2161 DataType::QuantisedSymm16
2162 };
2163
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002164 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2165 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002166
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002167 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2168 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002169
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002170 ValidateTensorShapesMatch(inputTensorInfo0,
2171 outputTensorInfo0,
2172 descriptorName,
2173 "input_0",
2174 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002175
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002176 ValidateTensorShapesMatch(inputTensorInfo0,
2177 outputTensorInfo1,
2178 descriptorName,
2179 "input_0",
2180 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002181}
2182
Matteo Martincigh49124022019-01-11 13:25:59 +00002183void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2184{
2185 // This is internally generated so it should not need validation.
2186}
2187
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002188void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2189{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002190 const std::string& descriptorName{"PreluQueueDescriptor"};
2191
2192 ValidateNumInputs(workloadInfo, descriptorName, 2);
2193 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2194
2195 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2196 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2197 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002198
2199 std::vector<DataType> supportedTypes
2200 {
2201 DataType::Float16,
2202 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002203 DataType::QuantisedAsymm8,
2204 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002205 };
2206
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002207 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2208 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002209
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002210 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002212 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2213 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002214
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002215 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2216 alphaTensorInfo,
2217 outputTensorInfo,
2218 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002219 "input",
2220 "alpha");
2221}
2222
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002223void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2224{
2225 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2226
2227 ValidateNumInputs(workloadInfo, descriptorName, 1);
2228 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2229
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002230 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2231 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2232
2233 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2234 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002235
2236 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002237
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002238 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2239 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2240 ValidateTensorDataType(weightTensorInfo, inputTensorInfo.GetDataType(), descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002241
2242 if (m_Parameters.m_BiasEnabled)
2243 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002244 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002245
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002246 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
2247 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002248
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002249 ValidateTensorDataType(biasTensorInfo,
2250 GetBiasDataType(inputTensorInfo.GetDataType()),
2251 descriptorName,
2252 "bias");
2253
2254 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002255 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002256}
2257
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002258} // namespace armnn