blob: fa5c6fe38e82f5f58d88278b5c9955b9182193bf [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000033 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000034 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000035 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010036 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000037 default:
38 BOOST_ASSERT_MSG(false, "Invalid input data type");
39 return DataType::Float32;
40 }
41}
42
43namespace
44{
45
46//---------------------------------------------------------------
// File-local replacement for std::to_string, which the Android NDK does not provide.
// Formats any value that has a stream insertion operator by writing it into a std::ostringstream.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatter;
    formatter << value;
    std::string text = formatter.str();
    return text;
}
55
56//---------------------------------------------------------------
57void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
58{
59 if (!ptr)
60 {
61 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
62 paramName + " parameter must be set.");
63 }
64}
65
66//---------------------------------------------------------------
67void ValidateTensorShapesMatch(const TensorInfo& first,
68 const TensorInfo& second,
69 std::string const& descName,
70 std::string const& firstName,
71 std::string const& secondName)
72{
73 if (first.GetShape() != second.GetShape())
74 {
75 throw InvalidArgumentException(descName + ": "
76 + firstName + " & " + secondName + " must have identical shapes");
77 }
78}
79
80//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010081void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000082{
Sadik Armaganeff363d2019-04-05 15:25:46 +010083 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000084 {
85 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010086 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000087 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
88 }
89}
90
91//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010092void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000093{
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000095 {
96 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010097 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000098 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
99 }
100}
101
102//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100103void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000104 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100105 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000106 std::string const& tensorName)
107{
108 if (tensor.GetNumDimensions() != numDimensions)
109 {
110 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
111 to_string(tensor.GetNumDimensions()) + " dimensions for " +
112 tensorName + " tensor.");
113 }
114}
115
116//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100117void ValidateTensorNumElements(const TensorInfo& tensor,
118 std::string const& descName,
119 unsigned int numElements,
120 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100121{
122 if (tensor.GetNumElements() != numElements)
123 {
124 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100125 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126 tensorName + " tensor.");
127 }
128}
129
130//---------------------------------------------------------------
131void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100132 unsigned int numDimension,
133 unsigned int numElements,
134 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100135{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100136 const std::string functionName{"ValidateTensorNumDimNumElem"};
137 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
138 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100139}
140
141//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000142void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
143 const std::string& descName, std::string const& tensorName)
144{
145 if (tensor.GetDataType() != dataType)
146 {
147 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
148 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
149 }
150}
151
152//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100153void ValidateTensorQuantizationSpace(const TensorInfo& first,
154 const TensorInfo& second,
155 const std::string& descName,
156 std::string const& firstName,
157 std::string const& secondName)
158{
159 if (!first.IsQuantized() ||
160 !second.IsQuantized())
161 {
162 // Not a quantized type, ignore the validation
163 return;
164 }
165
166 DataType firstDataType = first.GetDataType();
167 DataType secondDataType = second.GetDataType();
168
169 if (firstDataType != secondDataType)
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must be of the same quantized type, " +
173 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
174 secondName + " is " + GetDataTypeName(secondDataType));
175 }
176
177 if (!first.IsTypeSpaceMatch(second))
178 {
179 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
180 " must have the same quantization space, " +
181 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
182 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
183 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
184 " and scale " + to_string(second.GetQuantizationScale()));
185 }
186}
187
188//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100189void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
190 const TensorInfo& inputTensorInfo,
191 const TensorInfo& weightsTensorInfo,
192 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000193{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000194 // Helper lambda function to validate a single bias quantization scale value
195 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
196 {
ricbur013f4d7102019-10-31 16:22:18 +0000197 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000198 if (std::abs(biasScale - expectedScale) > tolerance)
199 {
200 // Print the float values with extra precision to see very small differences
201 std::stringstream msg;
202 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
203 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
204 biasScale;
205 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
206 }
207 };
208
telsoa014fcda012018-03-09 14:13:49 +0000209 if (biasTensor.GetQuantizationOffset() != 0)
210 {
211 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
212 to_string(biasTensor.GetQuantizationOffset()));
213 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000214
215 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000216 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000217 // Validate per-axis quantization scales
218 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
219 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
220
221 if (weightScales.size() != biasScales.size())
222 {
223 std::stringstream msg;
224 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
225 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
226 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
227 }
228
229 for (size_t i = 0ul; i < biasScales.size(); ++i)
230 {
231 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
232 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
233 }
234 }
235 else
236 {
237 // Validate per-tensor quantization scale
238 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
239 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000240 }
241}
242
243//---------------------------------------------------------------
244void ValidateTensors(const std::vector<ITensorHandle*>& vec,
245 unsigned int numExpected,
246 const std::string& descName,
247 const std::string& varName)
248{
249 if (vec.empty() && numExpected > 0)
250 {
251 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
252 }
253
254 for (unsigned int i = 0; i < numExpected; ++i)
255 {
256 if (!vec[i])
257 {
258 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
259 }
260 }
261}
262
263//---------------------------------------------------------------
264void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
265 const TensorInfo& second,
266 const TensorInfo& output,
267 std::string const& descName,
268 std::string const& firstName,
269 std::string const& secondName)
270{
271 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
272 // broadcasted.
273 if (first.GetNumDimensions() != second.GetNumDimensions())
274 {
275 throw InvalidArgumentException(descName + ": Tensors "
276 + firstName + " & " + secondName
277 + " must have the same number of dimensions in order to be broadcasted");
278 }
279 uint32_t numDims = first.GetNumDimensions();
280 std::vector<uint32_t> outputDims(numDims, 0u);
281 for (uint32_t i = 0; i < numDims; i++)
282 {
283 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
284 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
285 if (dimsNotEqual && dimsNotOne)
286 {
287 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
288 }
289 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
290 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100291 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000292 if (broadcastShape != output.GetShape())
293 {
294 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
295 + firstName + " & " + secondName
296 + " does not match the output shape");
297 }
298}
299
300//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100301void ValidateDataTypes(const TensorInfo& info,
302 const std::vector<armnn::DataType>& supportedTypes,
303 std::string const& descName)
304{
305 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
306 if (iterator == supportedTypes.end())
307 {
308 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
309 }
310}
311
James Conroy4d1ff582019-06-10 17:06:39 +0100312//---------------------------------------------------------------
313void ValidateTensorDataTypesMatch(const TensorInfo& first,
314 const TensorInfo& second,
315 std::string const& descName,
316 std::string const& firstName,
317 std::string const& secondName)
318{
319 if (first.GetDataType() != second.GetDataType())
320 {
321 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
322 " must have identical data types.");
323 }
324}
325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100326//---------------------------------------------------------------
327void ValidateTensorNumElementsMatch(const TensorInfo& first,
328 const TensorInfo& second,
329 std::string const& descName,
330 std::string const& firstName,
331 std::string const& secondName)
332{
333 if (first.GetNumElements() != second.GetNumElements())
334 {
335 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
336 " must have the same number of elements.");
337 }
338}
339
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000340void ValidateWeightDataType(const TensorInfo& inputInfo,
341 const TensorInfo& weightInfo,
342 const std::string& descName)
343{
344 const DataType inputType = inputInfo.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +0000345 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000346 {
347 const std::vector<DataType> validTypes =
348 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000349 DataType::QAsymmU8,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000350 DataType::QuantizedSymm8PerAxis
351 };
352
353 ValidateDataTypes(weightInfo, validTypes, descName);
354 }
355 else
356 {
357 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
358 }
359}
360
361void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
362 const std::string& descName,
363 const std::string& tensorName)
364{
365 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
366 if (!quantizationDim.has_value())
367 {
368 throw InvalidArgumentException(boost::str(
369 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
370 % descName % tensorName));
371 }
372
373 if (quantizationDim.value() != 0)
374 {
375 throw InvalidArgumentException(boost::str(
376 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
377 "but got: %3%") % descName % tensorName % quantizationDim.value()));
378 }
379}
380
381void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
382 const std::string& descName,
383 const std::string& tensorName)
384{
385 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
386 if (quantizationOffset != 0)
387 {
388 throw InvalidArgumentException(boost::str(
389 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
390 "but got: %3%") % descName % tensorName % quantizationOffset));
391 }
392}
393
394void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
395 const TensorInfo& outputInfo,
396 const TensorInfo& weightInfo,
397 const Optional<TensorInfo>& optionalBiasInfo,
398 const std::string& descName)
399{
400 if (weightInfo.HasPerAxisQuantization())
401 {
402 const DataType inputDataType = inputInfo.GetDataType();
403 const DataType outputDataType = outputInfo.GetDataType();
404
405 const bool canHavePerAxisQuantization =
Derek Lambertif90c56d2020-01-10 17:14:08 +0000406 inputDataType == DataType::QAsymmU8 && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000407
408 if (!canHavePerAxisQuantization)
409 {
410 throw InvalidArgumentException(boost::str(
411 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
412 "but data type does not support per-axis quantization.") % descName % "weight"));
413 }
414
415 ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight");
416 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
417 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
418
419 if (optionalBiasInfo.has_value())
420 {
421 const TensorInfo& biasInfo = optionalBiasInfo.value();
422 if (!biasInfo.HasPerAxisQuantization())
423 {
424 throw InvalidArgumentException(boost::str(
425 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
426 "weight tensor.") % descName));
427 }
428
429 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
430 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
431 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
432 }
433 }
434}
435
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100436} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000437
438void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
439 unsigned int numExpectedIn, unsigned int numExpectedOut) const
440{
441 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
442 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
443}
444
445//---------------------------------------------------------------
446void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
447{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100448 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000449
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100450 ValidateNumInputs(workloadInfo, descriptorName, 1);
451 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000452
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100453 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
454 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
455
456 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
457 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000458
459 if (m_Inputs.size() != m_Outputs.size())
460 {
461 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100462 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
463 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000464 }
465
466 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
467 {
468 if (!m_Inputs[i])
469 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100470 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
471 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000472 }
473
474 if (!m_Outputs[i])
475 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100476 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
477 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000478 }
479 }
480}
481
Derek Lambertif674aa02019-08-01 15:56:25 +0100482//---------------------------------------------------------------
483void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
484{
485 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
486 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
487
488 if (workloadInfo.m_InputTensorInfos.size() != 1)
489 {
490 throw InvalidArgumentException(boost::str(
491 boost::format("Number of input infos (%1%) is not 1.")
492 % workloadInfo.m_InputTensorInfos.size()));
493
494 }
495
496 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
497 {
498 throw InvalidArgumentException(boost::str(
499 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
500 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
501 }
502
503 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
504 {
505 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
506 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
507 {
508 throw InvalidArgumentException(boost::str(
509 boost::format("Number of elements for tensor input and output %1% does not match")
510 % i ));
511 }
512 }
513
514 if (m_Inputs.size() != 1)
515 {
516 throw InvalidArgumentException(boost::str(
517 boost::format("Number of inputs (%1%) is not 1.")
518 % m_Inputs.size()));
519 }
520
521 if (m_Inputs.size() != m_Outputs.size())
522 {
523 throw InvalidArgumentException(boost::str(
524 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
525 % m_Inputs.size() % m_Outputs.size()));
526 }
527
528 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
529 {
530 if (!m_Inputs[i])
531 {
532 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
533 }
534
535 if (!m_Outputs[i])
536 {
537 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
538 }
539 }
540}
541
542//---------------------------------------------------------------
543void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
544{
545 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
546 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
547
Derek Lambertif674aa02019-08-01 15:56:25 +0100548 if (m_Inputs.size() != 1)
549 {
550 throw InvalidArgumentException(boost::str(
551 boost::format("Number of inputs (%1%) is not 1.")
552 % m_Inputs.size()));
553 }
554
555 if (m_Outputs.size() != 0)
556 {
557 throw InvalidArgumentException(boost::str(
558 boost::format("Number of outputs (%1%) is not 0.")
559 % m_Inputs.size() % m_Outputs.size()));
560 }
561
562 if (!m_Inputs[0])
563 {
564 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
565 }
566}
567
568//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000569void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
570{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100571 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100572
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100573 ValidateNumInputs(workloadInfo, descriptorName, 1);
574 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100575
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100576 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
577 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100578
579 std::vector<DataType> supportedTypes =
580 {
James Conroyd47a0642019-09-17 14:22:06 +0100581 DataType::Float16,
582 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000583 DataType::QAsymmU8,
584 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100585 };
586
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100587 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
588 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
589 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000590}
591
Nikhil Rajee391d52019-09-05 17:50:44 +0100592void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
593{
594 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
595
596 ValidateNumInputs(workloadInfo, descriptorName, 1);
597 ValidateNumOutputs(workloadInfo, descriptorName, 1);
598
599 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
600 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
601
Nikhil Raj68c2c902019-09-19 11:21:11 +0100602 if (outputTensorInfo.GetDataType() != DataType::Signed32)
603 {
604 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
605 }
606
James Conroyd47a0642019-09-17 14:22:06 +0100607 std::vector<DataType> supportedInputTypes =
608 {
609 DataType::Float16,
610 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000611 DataType::QAsymmU8,
612 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000613 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100614 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100615
James Conroyd47a0642019-09-17 14:22:06 +0100616 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100617
618 auto inputShape = inputTensorInfo.GetShape();
619 auto outputShape = outputTensorInfo.GetShape();
620
621 auto inputNumDimensions = inputShape.GetNumDimensions();
622 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
623
624 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
625
626 // 1D input shape results in scalar output shape
627 if (inputShape.GetNumDimensions() == 1)
628 {
629 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
630 {
631 throw InvalidArgumentException(descriptorName + outputShapeError);
632 }
633 }
634 else
635 {
636 for (unsigned int i = 0; i < unsignedAxis; ++i)
637 {
638 if (outputShape[i] != inputShape[i])
639 {
640 throw InvalidArgumentException(descriptorName + outputShapeError);
641 }
642 }
643
644 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
645 {
646 if (outputShape[i - 1] != inputShape[i])
647 {
648 throw InvalidArgumentException(descriptorName + outputShapeError);
649 }
650 }
651 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100652}
653
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100654void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
655{
656 const std::string descriptorName{"SoftmaxQueueDescriptor"};
657
658 ValidateNumInputs(workloadInfo, descriptorName, 1);
659 ValidateNumOutputs(workloadInfo, descriptorName, 1);
660
661 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
662 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
663
664 std::vector<DataType> supportedTypes =
665 {
James Conroyd47a0642019-09-17 14:22:06 +0100666 DataType::Float16,
667 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000668 DataType::QAsymmU8,
669 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100670 };
671
672 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
673 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
674 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
675}
676
telsoa014fcda012018-03-09 14:13:49 +0000677void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
678{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100679 const std::string descriptorName{"SplitterQueueDescriptor"};
680
681 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000682
Ruomei Yan25339c32019-05-28 16:48:20 +0100683 // Check the supported data types
684 std::vector<DataType> supportedTypes =
685 {
James Conroyd47a0642019-09-17 14:22:06 +0100686 DataType::Float32,
687 DataType::Float16,
688 DataType::Boolean,
689 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000690 DataType::QAsymmU8,
691 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100692 };
693
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100694 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
695 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100696 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100697 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
698 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
699
700 const std::string outputName = "output_" + std::to_string(i);
701 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100702 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100703
telsoa014fcda012018-03-09 14:13:49 +0000704 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
705 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100706 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000707 }
708
709 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
710 {
711 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100712 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000713 "has to match number of workloadInfo.m_OutputTensorInfos. "
714 "Number of windows: " +
715 to_string(m_ViewOrigins.size()) +
716 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
717 }
718
telsoa01c577f2c2018-08-31 09:22:23 +0100719 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000720 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
721 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
722 {
telsoa01c577f2c2018-08-31 09:22:23 +0100723 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000724 ViewOrigin const& e = m_ViewOrigins[w];
725 if (e.m_Origin.size() != inputDims)
726 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100727 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000728 "have the same dimensionality as the input tensor. "
729 "Window origin (index: " +
730 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
731 " dimensions, the input "
732 "tensor has " +
733 to_string(inputDims) + " dimensions.");
734 }
735 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
736 {
737 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
738 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
739 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100740 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000741 "be smaller or equal than the size of the input in that coord.");
742 }
743 }
744 }
745}
746
Jim Flynne242f2d2019-05-22 14:24:13 +0100747void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000748{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100749 const std::string descriptorName{"ConcatQueueDescriptor"};
750
751 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000752
753 if (m_Inputs.size() <= 0)
754 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000756 }
757 if (m_Outputs.size() <= 0)
758 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100759 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000760 }
761
762 if (workloadInfo.m_InputTensorInfos.size() <= 0)
763 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100764 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000765 }
766 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000769 }
770
Nikhil Raj8599a412018-11-19 14:51:07 +0000771 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
772 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100773 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000774 }
775
776 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
777 {
778 return;
779 }
780
telsoa014fcda012018-03-09 14:13:49 +0000781 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
782 {
783 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000785 "has to match number of workloadInfo.m_InputTensorInfos. "
786 "Number of windows: " +
787 to_string(m_ViewOrigins.size()) +
788 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
789 }
790
telsoa01c577f2c2018-08-31 09:22:23 +0100791 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000792 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
793 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
794 {
telsoa01c577f2c2018-08-31 09:22:23 +0100795 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000796 ViewOrigin const& e = m_ViewOrigins[w];
797 if (e.m_Origin.size() != outputDims)
798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100799 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000800 "have the same dimensionality as the output tensor. "
801 "Window origin (index: " +
802 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
803 " dimensions, the output "
804 "tensor has " +
805 to_string(outputDims) + " dimensions.");
806 }
telsoa01c577f2c2018-08-31 09:22:23 +0100807 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000808 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
809 {
810 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
811 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000814 "be smaller or equal than the size of the output in that coord.");
815 }
816 }
817 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100818
819 // Check the supported data types
820 std::vector<DataType> supportedTypes =
821 {
James Conroyd47a0642019-09-17 14:22:06 +0100822 DataType::Float32,
823 DataType::Float16,
824 DataType::Boolean,
825 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000826 DataType::QAsymmU8,
827 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100828 };
829
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100830 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
831 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100832 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100833 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
834 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
835
836 const std::string inputName = "input_" + std::to_string(i);
837 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100838 }
telsoa014fcda012018-03-09 14:13:49 +0000839}
840
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100841void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
842{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 const std::string descriptorName{"StackQueueDescriptor"};
844
845 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100846
847 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
848 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100849 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100850 }
851
852 // All inputs must have the same shape, which is defined in parameters
853 const TensorShape& inputShape = m_Parameters.m_InputShape;
854 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
855 {
856 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
857 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100858 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100859 }
860 }
861
Matthew Jacksondba634f2019-08-15 15:14:18 +0100862 if (inputShape.GetNumDimensions() > 4)
863 {
864 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
865 }
866
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100867 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
868 // since the output tensor has an additional dimension.
869 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
870 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100871 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100872 "than the number of input dimensions.");
873 }
874
875 // Output shape must be as inferred from the input shape
876 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
877 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
878 {
879 if (outputShape[i] != inputShape[i])
880 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100881 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100882 "match shape inferred from input tensor.");
883 }
884 }
885
886 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
887 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100888 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100889 "match shape inferred from input tensor.");
890 }
891
892 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
893 {
894 if (outputShape[i] != inputShape[i-1])
895 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100896 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100897 "match shape inferred from input tensor.");
898 }
899 }
900
Matthew Jacksondba634f2019-08-15 15:14:18 +0100901 if (outputShape.GetNumDimensions() > 5)
902 {
903 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
904 }
905
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100906 // Check the supported data types
907 std::vector<DataType> supportedTypes =
908 {
James Conroyd47a0642019-09-17 14:22:06 +0100909 DataType::Float32,
910 DataType::Float16,
911 DataType::Boolean,
912 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000913 DataType::QAsymmU8,
914 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100915 };
916
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100917 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100918
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100919 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100920 {
921 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
922 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100923 descriptorName,
924 "input_0",
925 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100926 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100927
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100928 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
929 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 descriptorName,
931 "input_0",
932 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100933}
934
telsoa014fcda012018-03-09 14:13:49 +0000935void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
936{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100937 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000938
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100939 ValidateNumInputs(workloadInfo, descriptorName, 1);
940 ValidateNumOutputs(workloadInfo, descriptorName, 1);
941
942 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
943 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
944
945 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
946
947 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000948 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100949 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000950 }
951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100952 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000953
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100954 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
955 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000956
957 if (m_Parameters.m_BiasEnabled)
958 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000960
telsoa01c577f2c2018-08-31 09:22:23 +0100961 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100962 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
963 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000964
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100965 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
966 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000967 }
968
Francis Murtagh46c09d02019-05-28 08:15:28 +0100969 // Check the supported data types
970 std::vector<DataType> supportedTypes =
971 {
James Conroyd47a0642019-09-17 14:22:06 +0100972 DataType::Float32,
973 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000974 DataType::QAsymmU8,
975 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100976 };
977
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100978 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
979 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000980}
981
telsoa014fcda012018-03-09 14:13:49 +0000982void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
983{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100984 const std::string descriptorName{"NormalizationQueueDescriptor"};
985
986 ValidateNumInputs(workloadInfo, descriptorName, 1);
987 ValidateNumOutputs(workloadInfo, descriptorName, 1);
988
989 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
990 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100991
992 // Check the supported data types
993 std::vector<DataType> supportedTypes =
994 {
995 DataType::Float16,
996 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000997 DataType::QAsymmU8,
998 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100999 };
1000
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001001 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001002
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001003 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001004
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001005 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001006}
1007
1008void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1009{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001010 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001011
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001012 ValidateNumInputs(workloadInfo, descriptorName, 2);
1013 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1014
1015 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1016 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1017 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1018
1019 std::vector<DataType> supportedTypes =
1020 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001021 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001022 DataType::QAsymmU8,
1023 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001024 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001025 };
1026
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001027 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1028 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1029 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001030
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001031 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1032 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001033
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1035 inputTensorInfo1,
1036 outputTensorInfo,
1037 descriptorName,
1038 "input_0",
1039 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001040}
1041
telsoa014fcda012018-03-09 14:13:49 +00001042void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1043{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001046 ValidateNumInputs(workloadInfo, descriptorName, 2);
1047 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1048
1049 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1050 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1051 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1052
1053 std::vector<DataType> supportedTypes =
1054 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001055 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001056 DataType::QAsymmU8,
1057 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001058 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001059 };
1060
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001061 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1062 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1063 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001064
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001065 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1066 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001067
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001068 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1069 inputTensorInfo1,
1070 outputTensorInfo,
1071 descriptorName,
1072 "input_0",
1073 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001074}
1075
1076void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1077{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001078 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001080 ValidateNumInputs(workloadInfo, descriptorName, 1);
1081 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1082
1083 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1084 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001085
1086 std::vector<DataType> supportedTypes =
1087 {
1088 DataType::Float16,
1089 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001090 DataType::QAsymmU8,
1091 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001092 };
1093
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001094 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1095 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001096
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001097 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1098 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1099 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001100
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001101 ValidatePointer(m_Mean, descriptorName, "mean");
1102 ValidatePointer(m_Variance, descriptorName, "variance");
1103 ValidatePointer(m_Beta, descriptorName, "beta");
1104 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001105
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001106 const TensorInfo& mean = m_Mean->GetTensorInfo();
1107 const TensorInfo& variance = m_Variance->GetTensorInfo();
1108 const TensorInfo& beta = m_Beta->GetTensorInfo();
1109 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001110
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001111 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1112 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1113 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1114 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001115
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001116 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1117 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1118 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001119}
1120
1121void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1122{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001125 ValidateNumInputs(workloadInfo, descriptorName, 1);
1126 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001127
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1129 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001130
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001131 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1132 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001133
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001134 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1137 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001138
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001139 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001140
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001141 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001142 if (m_Parameters.m_BiasEnabled)
1143 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001144 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001145
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001146 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1147 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148
1149 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1150 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001151 }
1152
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001153 ValidatePerAxisQuantization(inputTensorInfo,
1154 outputTensorInfo,
1155 weightTensorInfo,
1156 optionalBiasTensorInfo,
1157 descriptorName);
1158
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001159 std::vector<DataType> supportedTypes =
1160 {
Ruomei Yan88d44b82019-05-23 14:29:06 +01001161 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001162 DataType::QAsymmU8,
1163 DataType::QSymmS16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001164 DataType::Float16
1165 };
1166
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001167 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1168 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1169}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001170
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001171void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1172{
1173 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1174
1175 ValidateNumInputs(workloadInfo, descriptorName, 1);
1176 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1177
1178 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1179 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1180
1181 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1182 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1183
1184 ValidatePointer(m_Weight, descriptorName, "weight");
1185
1186 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1187 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1188
1189 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1190 {
1191 throw InvalidArgumentException(
1192 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1193 "cannot be smaller than 1.") % descriptorName %
1194 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1195 }
1196
1197 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1198
1199 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1200 // inputChannels * channelMultiplier should be equal to outputChannels.
1201 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1202 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1203 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1204 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1205 {
1206 throw InvalidArgumentException(
1207 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1208 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1209 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1210 numWeightInputChannels % numWeightChannelMultiplier));
1211 }
1212
Teresa Charlind8df0262019-11-11 12:28:15 +00001213 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001214
Teresa Charlind8df0262019-11-11 12:28:15 +00001215 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001216 if (m_Parameters.m_BiasEnabled)
1217 {
1218 ValidatePointer(m_Bias, descriptorName, "bias");
1219
Teresa Charlind8df0262019-11-11 12:28:15 +00001220 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1221 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222
1223 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1224 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1225 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001226 ValidatePerAxisQuantization(inputTensorInfo,
1227 outputTensorInfo,
1228 weightTensorInfo,
1229 optionalBiasTensorInfo,
1230 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001231
1232 std::vector<DataType> supportedTypes =
1233 {
1234 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001235 DataType::QAsymmU8,
1236 DataType::QSymmS16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001237 DataType::Float16
1238 };
1239
1240 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1241 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001242}
1243
1244void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1245{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001246 const std::string descriptorName{"PermuteQueueDescriptor"};
1247
1248 ValidateNumInputs(workloadInfo, descriptorName, 1);
1249 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001250
1251 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1254 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001255
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001256 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1257 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001258
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001259 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001260 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001261 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001262 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001263 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1264 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1265 "must match dst dimension " + to_string(mapping[i]) +
1266 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001267 }
1268 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001269
1270 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001271}
1272
1273void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1274{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001275 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001276
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001277 ValidateNumInputs(workloadInfo, descriptorName, 1);
1278 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1279
1280 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1281 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1282
1283 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1284 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001285
1286 std::vector<DataType> supportedTypes =
1287 {
1288 DataType::Float32,
1289 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001290 DataType::QAsymmU8,
1291 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001292 };
1293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001294 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1295 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001296}
1297
1298void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1299{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001300 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001301
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 ValidateNumInputs(workloadInfo, descriptorName, 1);
1303 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1304
1305 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1306 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1307
1308 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1309 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001310
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001311 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001312 {
1313 DataType::Float16,
1314 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001315 DataType::QAsymmU8,
1316 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001317 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001318
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001319 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1320 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001322 // ResizeBilinear only changes width and height: batch and channel count must match.
1323 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1324 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001325 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001326 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001327 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001328 boost::str(boost::format("%1%: Input batch size (%2%) "
1329 "does not match output batch size (%3%)") %
1330 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001331 }
1332
Teresa Charlin970f43b2019-07-01 13:51:07 +01001333 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001334 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1335 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001336 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001337 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001338 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001339 boost::str(boost::format("%1%: Input channel count (%2%) "
1340 "does not match output channel count (%3%)") %
1341 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001342 }
1343}
1344
1345void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1346{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001347 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001348
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001349 ValidateNumInputs(workloadInfo, descriptorName, 1);
1350 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1351
1352 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1353 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1354
1355 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1356 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001357
1358 std::vector<DataType> supportedTypes =
1359 {
1360 DataType::Float16,
1361 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001362 DataType::QAsymmU8,
1363 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001364 };
1365
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001366 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1367 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001368
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001369 // Resize only changes width and height: batch and channel count must match.
1370 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1371 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001372 if (inputBatchSize != outputBatchSize)
1373 {
1374 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001375 boost::str(boost::format("%1%: Input batch size (%2%) "
1376 "does not match output batch size (%3%)") %
1377 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001378 }
1379
1380 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001381 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1382 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001383 if (inputChannelCount != outputChannelCount)
1384 {
1385 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001386 boost::str(boost::format("%1%: Input channel count (%2%) "
1387 "does not match output channel count (%3%)") %
1388 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001389 }
1390}
1391
1392void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1393{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001394 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001395
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001396 ValidateNumInputs(workloadInfo, descriptorName, 1);
1397 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1398
1399 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1400 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1401
1402 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1403 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1404
1405 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1406
telsoa014fcda012018-03-09 14:13:49 +00001407 if (m_Parameters.m_Min > m_Parameters.m_Max)
1408 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001409 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001410 }
telsoa014fcda012018-03-09 14:13:49 +00001411}
1412
Kevin Mayce5045a2019-10-02 14:07:47 +01001413void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1414{
1415 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1416
1417 ValidateNumInputs(workloadInfo, descriptorName, 1);
1418 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1419
1420 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1421 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1422
1423 if (inputTensorInfo.GetNumDimensions() > 4)
1424 {
1425 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1426 }
1427
1428 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1429
1430 // Check the supported data types
1431 std::vector<DataType> supportedTypes =
1432 {
1433 DataType::Float32,
1434 DataType::Float16
1435 };
1436
1437 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001438 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001439}
1440
telsoa014fcda012018-03-09 14:13:49 +00001441void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1442{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001443 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001444
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001445 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001446 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1447
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001448 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1449 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1450
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001451 if (inputTensorInfo.GetNumDimensions() > 4)
1452 {
1453 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1454 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455
1456 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001457
1458 // Check the supported data types
1459 std::vector<DataType> supportedTypes =
1460 {
1461 DataType::Float32,
1462 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001463 DataType::QAsymmU8,
1464 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001465 };
1466
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001467 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001468 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1469}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001470
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001471void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1472{
1473 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1474
1475 ValidateNumInputs(workloadInfo, descriptorName, 1);
1476 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1477
1478 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1479 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1480
1481 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1482
1483 std::vector<DataType> supportedTypes =
1484 {
1485 DataType::Float32,
1486 DataType::Float16,
1487 };
1488
1489 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001490 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001491}
1492
1493void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1494{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001495 const std::string descriptorName{"ConstantQueueDescriptor"};
1496
1497 ValidateNumInputs(workloadInfo, descriptorName, 0);
1498 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001499
1500 if (!m_LayerOutput)
1501 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001502 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001503 }
1504
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001505 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1506 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001507
1508 // Check the supported data types
1509 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001510 {
1511 DataType::Float32,
1512 DataType::Float16,
1513 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001514 DataType::QAsymmU8,
1515 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001516 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001517
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001519}
1520
1521void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1522{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001523 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 ValidateNumInputs(workloadInfo, descriptorName, 1);
1526 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1527
1528 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1529 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1530
1531 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001532
1533 // Check the supported data types
1534 std::vector<DataType> supportedTypes =
1535 {
1536 DataType::Float32,
1537 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001538 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001539 DataType::QAsymmU8,
1540 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001541 };
1542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001543 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1544 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001545}
1546
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001547void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1548{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001549 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001550
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001551 ValidateNumInputs(workloadInfo, descriptorName, 1);
1552 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1553
1554 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1555 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1556
1557 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1558 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001559
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001560 if (m_Parameters.m_BlockShape.size() != 2)
1561 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001562 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001563 }
1564
1565 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1566 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001567 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1568 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001569 }
1570
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001571 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001572
1573 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001574 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001575
Matthew Bentham8800c002018-11-19 13:19:28 +00001576 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001577
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001578 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1579 widthPad.first + widthPad.second;
1580 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1581 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001582
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001583 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1584 inputShape[dimensionIndices.GetChannelsIndex()];
1585 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001586
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001587 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001588 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001589 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001590 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001591 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001592 }
1593
1594 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001595 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001596 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1597 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001598 }
nikraj01120522a2019-05-31 11:33:07 +01001599
1600 std::vector<DataType> supportedTypes =
1601 {
1602 DataType::Float16,
1603 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001604 DataType::QAsymmU8,
1605 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001606 };
1607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001608 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1609 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001610}
1611
Keith Davisa57eccb2019-06-14 17:33:22 +01001612void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1613{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001614 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001615
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001616 ValidateNumInputs(workloadInfo, descriptorName, 1);
1617 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001618
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001619 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1620 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1621
1622 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1623 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001624
1625 std::vector<DataType> supportedTypes =
1626 {
1627 DataType::Float32,
1628 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001629 DataType::QAsymmU8,
1630 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001631 };
1632
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001633 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1634 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001635
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001636 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1637
1638 if (m_Parameters.m_BlockSize == 0)
1639 {
1640 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1641 }
1642
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001643 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1644 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1645 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1646 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001647
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001648 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001649 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001650 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001651 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1652 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001653 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001654
1655 const TensorShape& outputShape = outputTensorInfo.GetShape();
1656 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1657 {
1658 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1659 "must be divisible by the square of block size." );
1660 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001661}
1662
telsoa014fcda012018-03-09 14:13:49 +00001663void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1664{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001665 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001666
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001667 ValidateNumInputs(workloadInfo, descriptorName, 1);
1668 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1669
1670 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1671 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001672
1673 std::vector<DataType> supportedTypes =
1674 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001675 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001676 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001677 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001678 };
1679
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001680 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001681
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001682 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001683 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001684 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001685 }
1686}
1687
telsoa01c577f2c2018-08-31 09:22:23 +01001688void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1689{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001690 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1691
1692 const std::string descriptorName{"LstmQueueDescriptor"};
1693
1694 // check dimensions of all inputs and outputs
1695 if (workloadInfo.m_InputTensorInfos.size() != 3)
1696 {
1697 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1698 }
1699 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1700 {
1701 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1702 }
1703
1704 std::vector<DataType> supportedTypes =
1705 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001706 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001707 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001708 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001709 };
1710
Jan Eilers38e05bd2019-06-26 13:10:09 +01001711 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001712 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1713
Jan Eilers38e05bd2019-06-26 13:10:09 +01001714 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001715 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001716 {
1717 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1718 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001719 descriptorName,
1720 "input_0",
1721 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001722 }
1723 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001724 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001725 {
1726 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1727 workloadInfo.m_OutputTensorInfos[i],
1728 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001729 "input_0",
1730 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001731 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001732
janeil0117d8d852019-11-15 15:00:16 +00001733 // Making sure clipping parameters have valid values.
1734 // == 0 means no clipping
1735 // > 0 means clipping
1736 if (m_Parameters.m_ClippingThresCell < 0.0f)
1737 {
1738 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1739 }
1740 if (m_Parameters.m_ClippingThresProj < 0.0f)
1741 {
1742 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1743 }
1744
Jan Eilers38e05bd2019-06-26 13:10:09 +01001745
1746 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001747 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1748 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1749 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1750 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1751 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1752 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1753
Jan Eilers38e05bd2019-06-26 13:10:09 +01001754 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001755 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1756 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001757 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001758 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1759 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001760 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001761 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1762 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001763 // scratchBufferTensor
1764 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001765 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1766 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001767 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001768 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1769 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001770 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001771 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1772 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001773 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1775 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001776
1777
1778 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1779 if ( m_InputToInputWeights )
1780 {
1781 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1782 (n_cell * n_input), "InputLayerNormWeights");
1783 }
1784
1785 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1786 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1787 (n_cell * n_input), "InputToForgetWeights");
1788
1789 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1790 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1791 (n_cell * n_input), "InputToCellWeights");
1792
1793 if ( m_RecurrentToInputWeights )
1794 {
1795 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1796 (n_cell * n_output), "RecurrentToInputWeights");
1797 }
1798
1799 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1800 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1801 (n_cell * n_output), "RecurrentToForgetWeights");
1802
1803 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1804 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1805 (n_cell * n_output), "RecurrentToCellWeights");
1806
1807 // Make sure the input-gate's parameters are either both present (regular
1808 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1809 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1810 !m_Parameters.m_CifgEnabled) ||
1811 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1812 m_Parameters.m_CifgEnabled));
1813 if (!cifg_weights_all_or_none)
1814 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1816 "RecurrentToInputWeights must either both be present (regular LSTM) "
1817 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1818 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001819 }
1820
1821 if ( m_CellToInputWeights )
1822 {
1823 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1824 n_cell, "CellToInputWeights");
1825 }
1826 if ( m_CellToForgetWeights )
1827 {
1828 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1829 n_cell, "CellToForgetWeights");
1830 }
1831 if ( m_CellToOutputWeights )
1832 {
1833 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1834 n_cell, "CellToOutputWeights");
1835 }
1836
1837 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1838 bool peephole_weights_all_or_none =
1839 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1840 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1841 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1842 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1843 if (!peephole_weights_all_or_none)
1844 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001845 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001846 }
1847
1848 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1849 if (m_Parameters.m_CifgEnabled)
1850 {
1851 if (m_InputGateBias)
1852 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001853 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001854 }
1855 }
1856 else
1857 {
1858 if (!m_InputGateBias)
1859 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1861 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001862 }
1863 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1864 n_cell, "InputGateBias");
1865 }
1866
1867 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1868 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1869
1870 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1871 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1872
1873 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1874 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1875
1876 if (m_ProjectionWeights)
1877 {
1878 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1879 (n_cell * n_output), "ProjectionWeights");
1880 }
1881 if (m_ProjectionBias)
1882 {
1883 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1884 }
1885
1886 // Making sure the projection tensors are consistent:
1887 // 1) If projection weight is not present, then projection bias should not be
1888 // present.
1889 // 2) If projection weight is present, then projection bias is optional.
1890 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1891 !m_Parameters.m_ProjectionEnabled)
1892 || (m_ProjectionWeights && !m_ProjectionBias &&
1893 m_Parameters.m_ProjectionEnabled)
1894 || (m_ProjectionWeights && m_ProjectionBias &&
1895 m_Parameters.m_ProjectionEnabled));
1896 if (!projecton_tensors_consistent)
1897 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001899 }
1900
1901 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1902 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1903 // either all have values or none of them have values. Layer normalization is used when the values of all the
1904 // layer normalization weights are present
1905 if (m_InputLayerNormWeights)
1906 {
1907 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1908 }
1909 if (m_ForgetLayerNormWeights)
1910 {
1911 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1912 }
1913 if (m_CellLayerNormWeights)
1914 {
1915 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1916 }
1917 if (m_OutputLayerNormWeights)
1918 {
1919 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1920 }
1921
Jan Eilers38e05bd2019-06-26 13:10:09 +01001922 if (m_Parameters.m_LayerNormEnabled)
1923 {
1924 if (!m_Parameters.m_CifgEnabled)
1925 {
1926 if (!m_InputLayerNormWeights)
1927 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001928 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1929 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001930 }
1931 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1932 1, n_cell, "InputLayerNormWeights");
1933 }
1934 else if (m_InputLayerNormWeights)
1935 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001936 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1937 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001938 }
1939
1940 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1941 "ForgetLayerNormWeights");
1942 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1943
1944 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1945 "OutputLayerNormWeights");
1946 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1947
1948 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1949 "CellLayerNormWeights");
1950 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1951 }
1952 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1953 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001954 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1955 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001956 }
telsoa01c577f2c2018-08-31 09:22:23 +01001957}
1958
1959void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1960{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001961 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001962
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001963 ValidateNumInputs(workloadInfo, descriptorName, 1);
1964 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1965
1966 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1967 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1968
1969 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001970 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001971 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001972 }
1973
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001975 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001976 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001977 }
1978
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001979 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001980}
1981
1982void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1983{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001984 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001985
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001986 ValidateNumInputs(workloadInfo, descriptorName, 1);
1987 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1988
1989 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1990 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1991
1992 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001993 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001994 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001995 }
1996
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001997 if (outputTensorInfo.GetDataType() != DataType::Float32)
1998 {
1999 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2000 }
2001
2002 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002003}
2004
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002005void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2006{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002008
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002009 ValidateNumInputs(workloadInfo, descriptorName, 2);
2010 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2011
2012 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2013 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2014 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2015
2016 std::vector<DataType> supportedTypes =
2017 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002018 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002019 DataType::QAsymmU8,
2020 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002021 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002022 };
2023
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002024 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2025 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2026 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002027
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002028 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2029 inputTensorInfo1,
2030 outputTensorInfo,
2031 descriptorName,
2032 "input_0",
2033 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002034}
2035
David Beckc2044fe2018-09-05 15:00:38 +01002036void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2037{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002038 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002039
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002040 ValidateNumInputs(workloadInfo, descriptorName, 2);
2041 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2042
2043 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2044 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2045 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2046
2047 std::vector<DataType> supportedTypes =
2048 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002049 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002050 DataType::QAsymmU8,
2051 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002052 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002053 };
2054
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002055 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2056 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2057 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002058
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002059 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2060 inputTensorInfo1,
2061 outputTensorInfo,
2062 descriptorName,
2063 "input_0",
2064 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002065}
2066
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002067void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2068{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002069 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002070
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002071 ValidateNumInputs(workloadInfo, descriptorName, 2);
2072 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2073
2074 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2075 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2076 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2077
2078 std::vector<DataType> supportedTypes =
2079 {
Mike Kelly1da02362019-08-01 08:43:57 +01002080 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002081 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002082 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002083 DataType::QAsymmU8,
2084 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002085 };
2086
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002087 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2088 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2089 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002090
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002091 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2092 inputTensorInfo1,
2093 outputTensorInfo,
2094 descriptorName,
2095 "input_0",
2096 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002097}
2098
narpra01a6bf9122018-09-10 09:50:09 +01002099void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2100{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002101 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002102
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002103 ValidateNumInputs(workloadInfo, descriptorName, 1);
2104 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2105
2106 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2107 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002108
2109 std::vector<DataType> supportedTypes =
2110 {
2111 DataType::Float32,
2112 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002113 DataType::QAsymmU8,
2114 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002115 };
narpra01eb061912018-09-10 17:35:27 +01002116
James Conroy4d1ff582019-06-10 17:06:39 +01002117 // First check if input tensor data type is supported, then
2118 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002119 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2120 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002121
narpra0132b90462018-09-13 11:07:48 +01002122 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002123 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002124 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002125 }
narpra0132b90462018-09-13 11:07:48 +01002126 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002127 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002129 }
2130 else
2131 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002132 unsigned int outputDim =
2133 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2134 ValidateTensorNumDimensions(outputTensorInfo,
2135 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002136 outputDim > 0 ? outputDim : 1,
2137 "output");
2138 }
narpra01a6bf9122018-09-10 09:50:09 +01002139}
2140
jimfly012c9322a2018-09-19 10:59:49 +01002141void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2142{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002143 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002145 ValidateNumInputs(workloadInfo, descriptorName, 1);
2146 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2147
2148 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2149 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002150
jimfly012c9322a2018-09-19 10:59:49 +01002151 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002152 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2153
jimfly012c9322a2018-09-19 10:59:49 +01002154 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002155 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2156 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2157 "as there are dimensions in the input tensor that is " +
2158 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2159 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002160 }
2161}
2162
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002163void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2164{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002165 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002166
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002167 ValidateNumInputs(workloadInfo, descriptorName, 1);
2168 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002169
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002170 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2171 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2172
Sadik Armagan2208b602019-07-31 16:36:27 +01002173 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002174 {
James Conroyd47a0642019-09-17 14:22:06 +01002175 DataType::Float32,
2176 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002177 };
2178
2179 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002180
Derek Lambertif90c56d2020-01-10 17:14:08 +00002181 if (outputTensorInfo.GetDataType() != DataType::QAsymmU8 &&
Finn Williamsfd271062019-12-04 14:27:27 +00002182 outputTensorInfo.GetDataType() != DataType::QSymmS8 &&
Derek Lambertif90c56d2020-01-10 17:14:08 +00002183 outputTensorInfo.GetDataType() != DataType::QSymmS16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002184 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002185 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002186 }
2187}
2188
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002189void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2190{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002191 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002193 ValidateNumInputs(workloadInfo, descriptorName, 1);
2194 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002195
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002196 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2197 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002198
2199 std::vector<DataType> supportedTypes =
2200 {
James Conroyd47a0642019-09-17 14:22:06 +01002201 DataType::Float32,
2202 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002203 DataType::QAsymmU8,
2204 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002205 };
2206
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002207 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2208 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002209}
2210
Conor Kennedy430b5d82018-11-14 15:28:28 +00002211void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2212{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002213 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002214
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002215 ValidateNumInputs(workloadInfo, descriptorName, 1);
2216 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2217
2218 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2219 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002220
2221 std::vector<DataType> supportedTypes =
2222 {
2223 DataType::Float16,
2224 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002225 DataType::QAsymmU8,
2226 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002227 };
2228
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002229 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2230 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002231
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002232 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002233
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002234 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002235 if (rank > 4)
2236 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002237 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002238 }
2239
Conor Kennedy430b5d82018-11-14 15:28:28 +00002240 // Begin, End & Stride length must be of rank(input0)
2241 if (m_Parameters.m_Begin.size() != rank)
2242 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002244 }
2245
2246 if (m_Parameters.m_End.size() != rank)
2247 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002248 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002249 }
2250
2251 if (m_Parameters.m_Stride.size() != rank)
2252 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002253 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002254 }
2255
2256 // Stride entries must be non-zero
2257 for (auto& stride : m_Parameters.m_Stride)
2258 {
2259 if (stride == 0)
2260 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002261 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002262 }
2263 }
2264}
2265
kevmay0190539692018-11-29 08:40:19 +00002266void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2267{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002268 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002269
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002270 ValidateNumInputs(workloadInfo, descriptorName, 2);
2271 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2272
2273 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2274 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2275 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2276
2277 std::vector<DataType> supportedTypes =
2278 {
Mike Kelly1da02362019-08-01 08:43:57 +01002279 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002280 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002281 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002282 DataType::QAsymmU8,
2283 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002284 };
2285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002286 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2287 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2288 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2291 inputTensorInfo1,
2292 outputTensorInfo,
2293 descriptorName,
2294 "input_0",
2295 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002296}
2297
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002298void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2299{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 const std::string descriptorName{"DebugQueueDescriptor"};
2301
2302 ValidateNumInputs(workloadInfo, descriptorName, 1);
2303 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002304}
2305
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002306void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2307{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002308 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002309
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002310 ValidateNumInputs(workloadInfo, descriptorName, 2);
2311 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002312
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002313 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2314 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2315 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2316
2317 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2318 inputTensorInfo1,
2319 outputTensorInfo,
2320 descriptorName,
2321 "input_0",
2322 "input_1");
2323
2324 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002325 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002326 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002327 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002328}
2329
FrancisMurtagh878f0232018-12-19 10:56:15 +00002330void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2331{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002332 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002333
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002334 ValidateNumInputs(workloadInfo, descriptorName, 2);
2335 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002336
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002337 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2338 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2339 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2340
2341 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2342 inputTensorInfo1,
2343 outputTensorInfo,
2344 descriptorName,
2345 "input_0",
2346 "input_1");
2347
2348 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002349 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002350 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002351 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002352}
2353
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002354void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2355{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002356 const std::string descriptorName{"RsqrtQueueDescriptor"};
2357
2358 ValidateNumInputs(workloadInfo, descriptorName, 1);
2359 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2360
2361 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2362 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2363
2364 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002365
2366 std::vector<DataType> supportedTypes =
2367 {
James Conroyd47a0642019-09-17 14:22:06 +01002368 DataType::Float16,
2369 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002370 DataType::QAsymmU8,
2371 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002372 };
2373
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002374 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2375 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002376}
2377
narpra01b89b05f2019-01-16 09:53:09 +00002378void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2379{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002380 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002381
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002382 ValidateNumInputs(workloadInfo, descriptorName, 2);
2383 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002384
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002385 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2386 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002387 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002388 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002389 }
2390
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002391 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2392 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2393
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002394 std::vector<DataType> supportedTypes =
2395 {
James Conroyd47a0642019-09-17 14:22:06 +01002396 DataType::Float16,
2397 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002398 DataType::QAsymmU8,
2399 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002400 };
2401
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002402 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002403
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002404 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002405
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002406 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2407 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002408}
2409
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002410void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2411{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002412 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2413
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002414 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002415
2416 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2417 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002418 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002419 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2420 }
2421
2422 if (m_Anchors == nullptr)
2423 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002424 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002425 }
2426
2427 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002428 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2429 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2430
2431 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002432 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002433 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2434 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002435
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002436 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2437 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2438 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002439
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002440 const std::vector<DataType> supportedInputTypes =
2441 {
2442 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002443 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002444 DataType::QAsymmU8,
2445 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002446 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002447
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002448 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2449 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2450 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2451
2452 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2453 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2454 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2455 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2456
2457 // NOTE: Output is always Float32 regardless of input type
2458 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2459 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2460 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2461 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002462
2463 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2464 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002465 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002466 "must be positive and less than or equal to 1.");
2467 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002468
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002469 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2470 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002471 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002472 "should be equal to number of classes + 1.");
2473 }
2474}
2475
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002476void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2477{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002478 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002479
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002480 ValidateNumInputs(workloadInfo, descriptorName, 1);
2481 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2482
2483 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2484 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2485
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002486 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002487 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002488 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002489 }
2490
Sadik Armagan2208b602019-07-31 16:36:27 +01002491 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002492 {
James Conroyd47a0642019-09-17 14:22:06 +01002493 DataType::Float32,
2494 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002495 };
2496
2497 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002498}
2499
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002500void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2501{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002502 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002503
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002504 ValidateNumInputs(workloadInfo, descriptorName, 2);
2505 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002506
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002507 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2508 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2509 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002510
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002511 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2512 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2513
2514 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2515 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002516}
2517
Sadik Armaganeff363d2019-04-05 15:25:46 +01002518void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2519{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002520 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002522 ValidateNumInputs(workloadInfo, descriptorName, 2);
2523 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2524
2525 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2526 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2527
2528 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2529 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2530
2531 std::vector<DataType> supportedTypes =
2532 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002533 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002534 DataType::QAsymmU8,
2535 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002536 };
2537
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002538 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2539 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002540
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002541 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2542 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002543
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002544 ValidateTensorShapesMatch(inputTensorInfo0,
2545 outputTensorInfo0,
2546 descriptorName,
2547 "input_0",
2548 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002549
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002550 ValidateTensorShapesMatch(inputTensorInfo0,
2551 outputTensorInfo1,
2552 descriptorName,
2553 "input_0",
2554 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002555}
2556
Derek Lamberti901ea112019-12-10 22:07:09 +00002557void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002558{
2559 // This is internally generated so it should not need validation.
2560}
2561
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002562void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2563{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002564 const std::string& descriptorName{"PreluQueueDescriptor"};
2565
2566 ValidateNumInputs(workloadInfo, descriptorName, 2);
2567 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2568
2569 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2570 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2571 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002572
2573 std::vector<DataType> supportedTypes
2574 {
2575 DataType::Float16,
2576 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002577 DataType::QAsymmU8,
2578 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002579 };
2580
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002581 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2582 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002583
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002584 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002585
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002586 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2587 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002588
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002589 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2590 alphaTensorInfo,
2591 outputTensorInfo,
2592 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002593 "input",
2594 "alpha");
2595}
2596
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002597void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2598{
2599 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2600
2601 ValidateNumInputs(workloadInfo, descriptorName, 1);
2602 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2603
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002604 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2605 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2606
2607 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2608 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002609
2610 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002611
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002612 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2613 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002614
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002615 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2616
2617 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002618 if (m_Parameters.m_BiasEnabled)
2619 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002620 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002621
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002622 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2623 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002624
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002625 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002626 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002627 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002628
2629 ValidatePerAxisQuantization(inputTensorInfo,
2630 outputTensorInfo,
2631 weightTensorInfo,
2632 optionalBiasTensorInfo,
2633 descriptorName);
2634
2635 std::vector<DataType> supportedTypes =
2636 {
2637 DataType::Float32,
2638 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002639 DataType::QAsymmU8,
2640 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002641 };
2642
2643 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2644 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002645}
2646
James Conroy9c3cae82019-08-01 16:01:48 +01002647void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2648{
2649 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2650
2651 // Validate number of inputs/outputs
2652 ValidateNumInputs(workloadInfo, descriptorName, 3);
2653 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2654
2655 // Input/output tensor infos
2656 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2657 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2658 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2659
2660 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2661 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2662
2663 std::vector<DataType> inputOutputSupportedTypes =
2664 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002665 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002666 };
2667
2668 std::vector<DataType> cellStateSupportedTypes =
2669 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002670 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01002671 };
2672
2673 std::vector<DataType> weightsSupportedTypes =
2674 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002675 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002676 };
2677
2678 std::vector<DataType> biasSupportedTypes =
2679 {
2680 DataType::Signed32
2681 };
2682
2683 // Validate types of input/output tensors
2684 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2685 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2686 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2687
2688 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2689 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2690
2691 // Validate matching types of input/output tensors
2692 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2693 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2694 "outputStateIn", "outputStateOut");
2695 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2696
2697 // Validate matching quantization info for input/output tensors
2698 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2699 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2700 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002701
James Conroy9c3cae82019-08-01 16:01:48 +01002702 // Infer number of batches, input size and output size from tensor dimensions
2703 const uint32_t numBatches = inputInfo.GetShape()[0];
2704 const uint32_t inputSize = inputInfo.GetShape()[1];
2705 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2706
2707 // Validate number of dimensions and number of elements for input/output tensors
2708 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2709 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2710 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2711 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2712 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2713
2714 // Validate number of dimensions and number of elements for weights tensors
2715 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2716 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2717 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2718
2719 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2720 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2721 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2722
2723 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2724 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2725 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2726
2727 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2728 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2729 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2730
2731 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2732 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2733 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2734
2735 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2736 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2737 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2738 " RecurrentToForgetWeights");
2739
2740 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2741 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2742 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2743
2744 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2745 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2746 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2747
2748 // Validate data types for weights tensors (all should match each other)
2749 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2750
2751 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2752 "inputToInputWeights", "inputToForgetWeights");
2753 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2754 "inputToInputWeights", "inputToCellWeights");
2755 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2756 "inputToInputWeights", "inputToOutputWeights");
2757
2758 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2759 "inputToInputWeights", "recurrentToInputWeights");
2760 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2761 "inputToInputWeights", "recurrentToForgeteights");
2762 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2763 "inputToInputWeights", "recurrentToCellWeights");
2764 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2765 "inputToInputWeights", "recurrentToOutputWeights");
2766
2767 // Validate matching quantization info for weight tensors (all should match each other)
2768 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2769 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2770 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2771 descriptorName, "inputToInputWeights", "inputToCellWeights");
2772 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2773 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2774
2775 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2776 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2777 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2778 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2779 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2780 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2781 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2782 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2783
2784 // Validate number of dimensions and number of elements in bias tensors
2785 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2786 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2787 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2788
2789 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2790 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2791 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2792
2793 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2794 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2795 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2796
2797 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2798 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2799 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2800
2801 // Validate data types for bias tensors (all should match each other)
2802 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2803
2804 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2805 "inputGateBias", "forgetGateBias");
2806 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2807 "inputGateBias", "cellBias");
2808 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2809 "inputGateBias", "outputGateBias");
2810
2811 // Validate bias tensor quantization info
2812 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2813 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2814 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2815 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2816}
2817
Kevin May868eb142019-09-04 17:29:31 +01002818void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2819{
2820 const std::string descriptorName{"AbsQueueDescriptor"};
2821
2822 ValidateNumInputs(workloadInfo, descriptorName, 1);
2823 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2824
2825 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2826 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2827
2828 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2829
2830 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002831 {
2832 DataType::Float16,
2833 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002834 DataType::QAsymmU8,
2835 DataType::QSymmS16
James Conroyd47a0642019-09-17 14:22:06 +01002836 };
Kevin May868eb142019-09-04 17:29:31 +01002837
2838 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2839 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2840}
2841
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002842void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2843{
2844 const std::string descriptorName{"SliceQueueDescriptor"};
2845
2846 ValidateNumInputs(workloadInfo, descriptorName, 1);
2847 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2848
2849 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2850 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2851
2852 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2853
2854 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2855 if (rank > 4)
2856 {
2857 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2858 }
2859
2860 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2861
2862 // Check if m_Begin and m_Size have the expected length
2863 if (m_Parameters.m_Begin.size() != rank)
2864 {
2865 throw InvalidArgumentException(descriptorName +
2866 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2867 }
2868 if (m_Parameters.m_Size.size() != rank)
2869 {
2870 throw InvalidArgumentException(descriptorName +
2871 ": Length of size descriptor must equal rank " + std::to_string(rank));
2872 }
2873
2874 // Check if the shape of the output tensor matches m_Size
2875 const TensorShape& outputShape = outputTensorInfo.GetShape();
2876 for (unsigned int i = 0u; i < rank; ++i)
2877 {
2878 if (m_Parameters.m_Size[i] != outputShape[i])
2879 {
2880 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2881 }
2882 }
2883
2884 // Check if the sum of begin offset and size in a given dimension
2885 // does not exceed the size of corresponding input
2886 const TensorShape& inputShape = inputTensorInfo.GetShape();
2887 for(unsigned int i = 0u; i < rank; ++i)
2888 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002889 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002890 {
2891 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2892 std::to_string(i) + " exceeds input size.");
2893 }
2894 }
2895}
2896
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002897void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2898{
2899 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2900
2901 ValidateNumInputs(workloadInfo, descriptorName, 1);
2902 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2903
2904 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2905 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2906
2907 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2908 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2909
2910 std::vector<DataType> supportedTypes =
2911 {
2912 DataType::Float32,
2913 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002914 DataType::QAsymmU8,
2915 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002916 };
2917
2918 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2919 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2920
2921 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2922
2923 if (m_Parameters.m_BlockSize == 0)
2924 {
2925 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2926 }
2927
2928 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2929 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2930 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2931 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2932
2933 const TensorShape& outputShape = outputInfo.GetShape();
2934 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
2935 {
2936 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
2937 "must be divisible by block size.");
2938 }
2939
2940 const TensorShape& inputShape = inputInfo.GetShape();
2941 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
2942 {
2943 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
2944 "must be divisible by the square of block size." );
2945 }
2946}
2947
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01002948void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2949{
2950 const std::string descriptorName{"ComparisonQueueDescriptor"};
2951
2952 ValidateNumInputs(workloadInfo, descriptorName, 2);
2953 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2954
2955 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2956 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2957 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2958
2959 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2960 inputTensorInfo1,
2961 outputTensorInfo,
2962 descriptorName,
2963 "input_0",
2964 "input_1");
2965
2966 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2967 {
2968 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2969 }
2970}
2971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002972} // namespace armnn