blob: 606821b5e50ccc0552117aa8c6f8c071533677cf [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010012#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000013
telsoa014fcda012018-03-09 14:13:49 +000014#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000016#include <string>
17#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000018
James Ward47fce872020-09-10 11:57:28 +010019#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000020
Matteo Martincigh21350152018-11-28 16:22:22 +000021using namespace armnnUtils;
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
26//---------------------------------------------------------------
27DataType GetBiasDataType(DataType inputDataType)
28{
29 switch (inputDataType)
30 {
telsoa01c577f2c2018-08-31 09:22:23 +010031 case DataType::Float16:
32 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000033 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000034 case DataType::Float32:
35 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000036 case DataType::QAsymmS8:
37 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000038 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000039 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000040 case DataType::QSymmS8:
41 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000042 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010043 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000044 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010045 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000046 return DataType::Float32;
47 }
48}
49
50namespace
51{
52
53//---------------------------------------------------------------
// Stand-in for std::to_string, which the Android NDK does not provide.
// Formats any value that supports stream insertion (operator<<) via an ostringstream.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream buffer;
    buffer << value;
    std::string text = buffer.str();
    return text;
}
62
63//---------------------------------------------------------------
64void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
65{
66 if (!ptr)
67 {
68 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
69 paramName + " parameter must be set.");
70 }
71}
72
73//---------------------------------------------------------------
74void ValidateTensorShapesMatch(const TensorInfo& first,
75 const TensorInfo& second,
76 std::string const& descName,
77 std::string const& firstName,
78 std::string const& secondName)
79{
80 if (first.GetShape() != second.GetShape())
81 {
82 throw InvalidArgumentException(descName + ": "
83 + firstName + " & " + secondName + " must have identical shapes");
84 }
85}
86
87//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010088void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089{
Sadik Armaganeff363d2019-04-05 15:25:46 +010090 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000091 {
92 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010093 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000094 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
95 }
96}
97
98//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010099void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100101 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000102 {
103 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100104 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000105 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
106 }
107}
108
109//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000110
111//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100112void ValidateTensorNumElements(const TensorInfo& tensor,
113 std::string const& descName,
114 unsigned int numElements,
115 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100116{
117 if (tensor.GetNumElements() != numElements)
118 {
119 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100120 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100121 tensorName + " tensor.");
122 }
123}
124
125//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000126void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
127 const std::string& descName, std::string const& tensorName)
128{
129 if (tensor.GetDataType() != dataType)
130 {
131 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
132 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
133 }
134}
135
Derek Lambertid466a542020-01-22 15:37:29 +0000136void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
137{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100138 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000139 {
140 throw InvalidArgumentException(descName +
141 ": Expected data type which supports per-axis quantization scheme but got " +
142 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
143 }
Derek Lambertid466a542020-01-22 15:37:29 +0000144}
145
telsoa014fcda012018-03-09 14:13:49 +0000146//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100147void ValidateTensorQuantizationSpace(const TensorInfo& first,
148 const TensorInfo& second,
149 const std::string& descName,
150 std::string const& firstName,
151 std::string const& secondName)
152{
153 if (!first.IsQuantized() ||
154 !second.IsQuantized())
155 {
156 // Not a quantized type, ignore the validation
157 return;
158 }
159
160 DataType firstDataType = first.GetDataType();
161 DataType secondDataType = second.GetDataType();
162
163 if (firstDataType != secondDataType)
164 {
165 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
166 " must be of the same quantized type, " +
167 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
168 secondName + " is " + GetDataTypeName(secondDataType));
169 }
170
171 if (!first.IsTypeSpaceMatch(second))
172 {
173 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
174 " must have the same quantization space, " +
175 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
176 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
177 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
178 " and scale " + to_string(second.GetQuantizationScale()));
179 }
180}
181
182//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100183void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
184 const TensorInfo& inputTensorInfo,
185 const TensorInfo& weightsTensorInfo,
186 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000187{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000188 // Helper lambda function to validate a single bias quantization scale value
189 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
190 {
mathad01df9a3222021-04-28 11:42:57 +0100191 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000192 if (std::abs(biasScale - expectedScale) > tolerance)
193 {
194 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100195 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
196 " for bias quantization scale (product of input and weight scales), but got " <<
197 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000198 }
199 };
200
telsoa014fcda012018-03-09 14:13:49 +0000201 if (biasTensor.GetQuantizationOffset() != 0)
202 {
203 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
204 to_string(biasTensor.GetQuantizationOffset()));
205 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000206
James Conroy8502ade2020-11-12 19:26:29 +0000207 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000208 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000209 // Validate per-axis quantization scales
210 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
211 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
212
213 if (weightScales.size() != biasScales.size())
214 {
215 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000216 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
217 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
218 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000219 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
220 }
221
222 for (size_t i = 0ul; i < biasScales.size(); ++i)
223 {
224 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
225 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
226 }
227 }
228 else
229 {
230 // Validate per-tensor quantization scale
231 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
232 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000233 }
234}
235
236//---------------------------------------------------------------
237void ValidateTensors(const std::vector<ITensorHandle*>& vec,
238 unsigned int numExpected,
239 const std::string& descName,
240 const std::string& varName)
241{
242 if (vec.empty() && numExpected > 0)
243 {
244 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
245 }
246
247 for (unsigned int i = 0; i < numExpected; ++i)
248 {
249 if (!vec[i])
250 {
251 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
252 }
253 }
254}
255
256//---------------------------------------------------------------
257void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
258 const TensorInfo& second,
259 const TensorInfo& output,
260 std::string const& descName,
261 std::string const& firstName,
262 std::string const& secondName)
263{
264 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
265 // broadcasted.
266 if (first.GetNumDimensions() != second.GetNumDimensions())
267 {
268 throw InvalidArgumentException(descName + ": Tensors "
269 + firstName + " & " + secondName
270 + " must have the same number of dimensions in order to be broadcasted");
271 }
272 uint32_t numDims = first.GetNumDimensions();
273 std::vector<uint32_t> outputDims(numDims, 0u);
274 for (uint32_t i = 0; i < numDims; i++)
275 {
276 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
277 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
278 if (dimsNotEqual && dimsNotOne)
279 {
280 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
281 }
282 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
283 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100284 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000285 if (broadcastShape != output.GetShape())
286 {
287 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
288 + firstName + " & " + secondName
289 + " does not match the output shape");
290 }
291}
292
293//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100294void ValidateDataTypes(const TensorInfo& info,
295 const std::vector<armnn::DataType>& supportedTypes,
296 std::string const& descName)
297{
298 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
299 if (iterator == supportedTypes.end())
300 {
301 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
302 }
303}
304
James Conroy4d1ff582019-06-10 17:06:39 +0100305//---------------------------------------------------------------
306void ValidateTensorDataTypesMatch(const TensorInfo& first,
307 const TensorInfo& second,
308 std::string const& descName,
309 std::string const& firstName,
310 std::string const& secondName)
311{
312 if (first.GetDataType() != second.GetDataType())
313 {
314 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
315 " must have identical data types.");
316 }
317}
318
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100319//---------------------------------------------------------------
320void ValidateTensorNumElementsMatch(const TensorInfo& first,
321 const TensorInfo& second,
322 std::string const& descName,
323 std::string const& firstName,
324 std::string const& secondName)
325{
326 if (first.GetNumElements() != second.GetNumElements())
327 {
328 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
329 " must have the same number of elements.");
330 }
331}
332
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000333void ValidateWeightDataType(const TensorInfo& inputInfo,
334 const TensorInfo& weightInfo,
335 const std::string& descName)
336{
337 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000338 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000339 {
340 const std::vector<DataType> validTypes =
341 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000342 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100343 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100344 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000345 };
346
347 ValidateDataTypes(weightInfo, validTypes, descName);
348 }
349 else
350 {
351 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
352 }
353}
354
355void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
356 const std::string& descName,
357 const std::string& tensorName)
358{
359 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
360 if (!quantizationDim.has_value())
361 {
James Ward47fce872020-09-10 11:57:28 +0100362 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
363 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365}
366
367void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
368 const std::string& descName,
369 const std::string& tensorName)
370{
371 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
372 if (quantizationOffset != 0)
373 {
James Ward47fce872020-09-10 11:57:28 +0100374 throw InvalidArgumentException(fmt::format(
375 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
376 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000377 }
378}
379
380void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
381 const TensorInfo& outputInfo,
382 const TensorInfo& weightInfo,
383 const Optional<TensorInfo>& optionalBiasInfo,
384 const std::string& descName)
385{
386 if (weightInfo.HasPerAxisQuantization())
387 {
388 const DataType inputDataType = inputInfo.GetDataType();
389 const DataType outputDataType = outputInfo.GetDataType();
390
Keith Davis0c2eeac2020-02-11 16:51:50 +0000391 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000392
393 if (!canHavePerAxisQuantization)
394 {
James Ward47fce872020-09-10 11:57:28 +0100395 throw InvalidArgumentException(fmt::format(
396 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
397 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000398 }
399
Derek Lambertid466a542020-01-22 15:37:29 +0000400
401 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000402 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
403 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
404
405 if (optionalBiasInfo.has_value())
406 {
407 const TensorInfo& biasInfo = optionalBiasInfo.value();
408 if (!biasInfo.HasPerAxisQuantization())
409 {
James Ward47fce872020-09-10 11:57:28 +0100410 throw InvalidArgumentException(fmt::format(
411 "{}: Per-axis quantization parameters not set on bias tensor, "
412 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000413 }
414
415 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
416 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
417 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
418 }
419 }
420}
421
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100422} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000423
Mike Kelly80512b02022-05-16 23:10:42 +0100424//---------------------------------------------------------------
425void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
426 std::string const& descName,
427 unsigned int numDimensions,
428 std::string const& tensorName) const
429{
430 // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
431 // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
432 // numDimensions.
433 if (m_AllowExpandedDims)
434 {
435 unsigned int squeezedDims = 0;
436
437 for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
438 {
439 if (tensor.GetShape()[i] != 1)
440 {
441 ++squeezedDims;
442 }
443 }
444 if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
445 {
446 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
447 to_string(tensor.GetNumDimensions()) + " dimensions for " +
448 tensorName + " tensor.");
449 }
450 }
451 else
452 {
453 if (tensor.GetNumDimensions() != numDimensions)
454 {
455 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
456 to_string(tensor.GetNumDimensions()) + " dimensions for " +
457 tensorName + " tensor.");
458 }
459 }
460}
461
462//---------------------------------------------------------------
463void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
464 unsigned int numDimension,
465 unsigned int numElements,
466 std::string const& tensorName) const
467{
468 const std::string functionName{"ValidateTensorNumDimNumElem"};
469 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
470 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
471}
472
473//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000474void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
475 unsigned int numExpectedIn, unsigned int numExpectedOut) const
476{
477 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
478 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
479}
480
481//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100482void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
483{
484 const std::string descriptorName{"MapQueueDescriptor"};
485
486 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100487 ValidateNumOutputs(workloadInfo, descriptorName, 0);
488
489 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
490 {
491 if (!m_Inputs[i])
492 {
493 throw InvalidArgumentException(
494 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
495 }
496 }
497}
498
499//---------------------------------------------------------------
500void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
501{
502 const std::string descriptorName{"UnmapQueueDescriptor"};
503
504 ValidateNumInputs(workloadInfo, descriptorName, 1);
505 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100506
507 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
508 {
509 if (!m_Inputs[i])
510 {
511 throw InvalidArgumentException(
512 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
513 }
514 }
515}
516
517//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000518void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
519{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100520 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100522 ValidateNumInputs(workloadInfo, descriptorName, 1);
523 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100525 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
526 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
527
528 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
529 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000530
531 if (m_Inputs.size() != m_Outputs.size())
532 {
James Ward47fce872020-09-10 11:57:28 +0100533 throw InvalidArgumentException(fmt::format(
534 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
535 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000536 }
537
538 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
539 {
540 if (!m_Inputs[i])
541 {
James Ward47fce872020-09-10 11:57:28 +0100542 throw InvalidArgumentException(fmt::format(
543 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000544 }
545
546 if (!m_Outputs[i])
547 {
James Ward47fce872020-09-10 11:57:28 +0100548 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000549 }
550 }
551}
552
Derek Lambertif674aa02019-08-01 15:56:25 +0100553//---------------------------------------------------------------
554void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
555{
556 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
557 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
558
559 if (workloadInfo.m_InputTensorInfos.size() != 1)
560 {
James Ward47fce872020-09-10 11:57:28 +0100561 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
562 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100563
564 }
565
566 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
567 {
James Ward47fce872020-09-10 11:57:28 +0100568 throw InvalidArgumentException(fmt::format(
569 "Number of input infos ({0}) does not match the number of output infos ({1})",
570 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 }
572
573 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
574 {
575 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
576 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
577 {
James Ward47fce872020-09-10 11:57:28 +0100578 throw InvalidArgumentException(fmt::format(
579 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100580 }
581 }
582
583 if (m_Inputs.size() != 1)
584 {
James Ward47fce872020-09-10 11:57:28 +0100585 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100586 }
587
588 if (m_Inputs.size() != m_Outputs.size())
589 {
James Ward47fce872020-09-10 11:57:28 +0100590 throw InvalidArgumentException(fmt::format(
591 "Number of inputs ({0}) does not match the number of outputs ({1})",
592 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100593 }
594
595 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
596 {
597 if (!m_Inputs[i])
598 {
James Ward47fce872020-09-10 11:57:28 +0100599 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100600 }
601
602 if (!m_Outputs[i])
603 {
James Ward47fce872020-09-10 11:57:28 +0100604 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100605 }
606 }
607}
608
609//---------------------------------------------------------------
610void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
611{
612 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100613
Derek Lambertif674aa02019-08-01 15:56:25 +0100614 if (m_Inputs.size() != 1)
615 {
James Ward47fce872020-09-10 11:57:28 +0100616 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100617 }
618
619 if (m_Outputs.size() != 0)
620 {
James Ward47fce872020-09-10 11:57:28 +0100621 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100622 }
623
624 if (!m_Inputs[0])
625 {
James Ward47fce872020-09-10 11:57:28 +0100626 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100627 }
628}
629
630//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000631void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
632{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100633 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100634
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100635 ValidateNumInputs(workloadInfo, descriptorName, 1);
636 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100637
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100638 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
639 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100640
641 std::vector<DataType> supportedTypes =
642 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000643 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100644 DataType::Float16,
645 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000646 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000647 DataType::QAsymmU8,
648 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100649 };
650
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100651 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
652 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
653 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000654}
655
Nikhil Rajee391d52019-09-05 17:50:44 +0100656void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
657{
658 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
659
660 ValidateNumInputs(workloadInfo, descriptorName, 1);
661 ValidateNumOutputs(workloadInfo, descriptorName, 1);
662
663 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
664 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
665
Inki Daed4619e22020-09-10 15:33:54 +0900666 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
667 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100668 {
Inki Daed4619e22020-09-10 15:33:54 +0900669 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100670 }
671
James Conroyd47a0642019-09-17 14:22:06 +0100672 std::vector<DataType> supportedInputTypes =
673 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000674 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100675 DataType::Float16,
676 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100677 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000678 DataType::QAsymmU8,
679 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900680 DataType::Signed32,
681 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100682 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100683
James Conroyd47a0642019-09-17 14:22:06 +0100684 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100685
686 auto inputShape = inputTensorInfo.GetShape();
687 auto outputShape = outputTensorInfo.GetShape();
688
689 auto inputNumDimensions = inputShape.GetNumDimensions();
690 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
691
692 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
693
694 // 1D input shape results in scalar output shape
695 if (inputShape.GetNumDimensions() == 1)
696 {
697 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
698 {
699 throw InvalidArgumentException(descriptorName + outputShapeError);
700 }
701 }
702 else
703 {
704 for (unsigned int i = 0; i < unsignedAxis; ++i)
705 {
706 if (outputShape[i] != inputShape[i])
707 {
708 throw InvalidArgumentException(descriptorName + outputShapeError);
709 }
710 }
711
712 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
713 {
714 if (outputShape[i - 1] != inputShape[i])
715 {
716 throw InvalidArgumentException(descriptorName + outputShapeError);
717 }
718 }
719 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100720}
721
mathad01b392e982021-04-07 12:07:30 +0100722void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
723{
724 const std::string descriptorName{"CastQueueDescriptor"};
725
726 ValidateNumInputs(workloadInfo, descriptorName, 1);
727 ValidateNumOutputs(workloadInfo, descriptorName, 1);
728
729 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
730 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
731
732 std::vector<DataType> supportedTypes =
733 {
734 DataType::BFloat16,
735 DataType::Float16,
736 DataType::Float32,
737 DataType::QAsymmS8,
738 DataType::QAsymmU8,
739 DataType::QSymmS8,
740 DataType::QSymmS16,
741 DataType::Signed32,
742 DataType::Signed64
743 };
744
745 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
746 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
747}
748
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100749void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
750{
751 const std::string descriptorName{"SoftmaxQueueDescriptor"};
752
753 ValidateNumInputs(workloadInfo, descriptorName, 1);
754 ValidateNumOutputs(workloadInfo, descriptorName, 1);
755
756 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
757 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
758
759 std::vector<DataType> supportedTypes =
760 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000761 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100762 DataType::Float16,
763 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000764 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000765 DataType::QAsymmU8,
766 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100767 };
768
769 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
770 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
771 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
772}
773
telsoa014fcda012018-03-09 14:13:49 +0000774void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
775{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100776 const std::string descriptorName{"SplitterQueueDescriptor"};
777
778 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000779
Ruomei Yan25339c32019-05-28 16:48:20 +0100780 // Check the supported data types
781 std::vector<DataType> supportedTypes =
782 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000783 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100784 DataType::Float32,
785 DataType::Float16,
786 DataType::Boolean,
787 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100788 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000789 DataType::QAsymmU8,
790 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100791 };
792
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100793 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
794 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100795 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100796 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
797 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
798
799 const std::string outputName = "output_" + std::to_string(i);
800 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100801 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100802
telsoa014fcda012018-03-09 14:13:49 +0000803 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
804 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100805 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000806 }
807
808 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
809 {
810 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100811 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000812 "has to match number of workloadInfo.m_OutputTensorInfos. "
813 "Number of windows: " +
814 to_string(m_ViewOrigins.size()) +
815 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
816 }
817
telsoa01c577f2c2018-08-31 09:22:23 +0100818 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000819 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
820 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
821 {
telsoa01c577f2c2018-08-31 09:22:23 +0100822 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000823 ViewOrigin const& e = m_ViewOrigins[w];
824 if (e.m_Origin.size() != inputDims)
825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100826 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000827 "have the same dimensionality as the input tensor. "
828 "Window origin (index: " +
829 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
830 " dimensions, the input "
831 "tensor has " +
832 to_string(inputDims) + " dimensions.");
833 }
834 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
835 {
836 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
837 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
838 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100839 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000840 "be smaller or equal than the size of the input in that coord.");
841 }
842 }
843 }
844}
845
Jim Flynne242f2d2019-05-22 14:24:13 +0100846void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000847{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100848 const std::string descriptorName{"ConcatQueueDescriptor"};
849
850 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000851
852 if (m_Inputs.size() <= 0)
853 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100854 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000855 }
856 if (m_Outputs.size() <= 0)
857 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100858 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000859 }
860
861 if (workloadInfo.m_InputTensorInfos.size() <= 0)
862 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100863 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000864 }
865 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
866 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100867 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000868 }
869
Nikhil Raj8599a412018-11-19 14:51:07 +0000870 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
871 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000873 }
874
875 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
876 {
877 return;
878 }
879
telsoa014fcda012018-03-09 14:13:49 +0000880 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
881 {
882 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100883 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000884 "has to match number of workloadInfo.m_InputTensorInfos. "
885 "Number of windows: " +
886 to_string(m_ViewOrigins.size()) +
887 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
888 }
889
telsoa01c577f2c2018-08-31 09:22:23 +0100890 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000891 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
892 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
893 {
telsoa01c577f2c2018-08-31 09:22:23 +0100894 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000895 ViewOrigin const& e = m_ViewOrigins[w];
896 if (e.m_Origin.size() != outputDims)
897 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100898 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000899 "have the same dimensionality as the output tensor. "
900 "Window origin (index: " +
901 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
902 " dimensions, the output "
903 "tensor has " +
904 to_string(outputDims) + " dimensions.");
905 }
telsoa01c577f2c2018-08-31 09:22:23 +0100906 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000907 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
908 {
909 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
910 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
911 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100912 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000913 "be smaller or equal than the size of the output in that coord.");
914 }
915 }
916 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100917
918 // Check the supported data types
919 std::vector<DataType> supportedTypes =
920 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000921 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100922 DataType::Float32,
923 DataType::Float16,
924 DataType::Boolean,
925 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100926 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000927 DataType::QAsymmU8,
928 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100929 };
930
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100931 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
932 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100933 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100934 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
935 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
936
937 const std::string inputName = "input_" + std::to_string(i);
938 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100939 }
telsoa014fcda012018-03-09 14:13:49 +0000940}
941
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100942void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
943{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100944 const std::string descriptorName{"StackQueueDescriptor"};
945
946 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100947
948 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
949 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100951 }
952
953 // All inputs must have the same shape, which is defined in parameters
954 const TensorShape& inputShape = m_Parameters.m_InputShape;
955 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
956 {
957 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
958 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100960 }
961 }
962
Matthew Jacksondba634f2019-08-15 15:14:18 +0100963 if (inputShape.GetNumDimensions() > 4)
964 {
965 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
966 }
967
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100968 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
969 // since the output tensor has an additional dimension.
970 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
971 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100972 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100973 "than the number of input dimensions.");
974 }
975
976 // Output shape must be as inferred from the input shape
977 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
978 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
979 {
980 if (outputShape[i] != inputShape[i])
981 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100982 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100983 "match shape inferred from input tensor.");
984 }
985 }
986
987 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
988 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100989 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100990 "match shape inferred from input tensor.");
991 }
992
993 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
994 {
995 if (outputShape[i] != inputShape[i-1])
996 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100997 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100998 "match shape inferred from input tensor.");
999 }
1000 }
1001
Matthew Jacksondba634f2019-08-15 15:14:18 +01001002 if (outputShape.GetNumDimensions() > 5)
1003 {
1004 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
1005 }
1006
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001007 // Check the supported data types
1008 std::vector<DataType> supportedTypes =
1009 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001010 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001011 DataType::Float32,
1012 DataType::Float16,
1013 DataType::Boolean,
1014 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001015 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001016 DataType::QAsymmU8,
1017 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001018 };
1019
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001020 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001021
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001022 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001023 {
1024 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1025 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001026 descriptorName,
1027 "input_0",
1028 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001029 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001030
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001031 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1032 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001033 descriptorName,
1034 "input_0",
1035 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001036}
1037
Ryan OSheaec6c6802020-06-05 17:17:06 +01001038void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1039{
1040 const std::string descriptorName{"FillQueueDescriptor"};
1041
1042 ValidateNumInputs(workloadInfo, descriptorName, 1);
1043 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1044
1045 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1046 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1047
1048 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1049
1050 std::vector<DataType> supportedTypes =
1051 {
1052 DataType::BFloat16,
1053 DataType::Float32,
1054 DataType::Float16,
1055 DataType::Signed32
1056 };
1057
1058 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1059}
1060
telsoa014fcda012018-03-09 14:13:49 +00001061void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1062{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001063 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001064
Matthew Sloyan81beae32021-07-13 19:46:11 +01001065 uint32_t numInputs = 2;
1066 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001067 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001068 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001069 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001070
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001071 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001072 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1073
1074 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1075 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1076
1077 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1078
1079 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001080 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001081 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001082 }
1083
Matthew Sloyan81beae32021-07-13 19:46:11 +01001084 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001085 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001086
1087 if (m_Parameters.m_BiasEnabled)
1088 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001089 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001090 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001091 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001092 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1093 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001094 }
1095
Francis Murtagh46c09d02019-05-28 08:15:28 +01001096 // Check the supported data types
1097 std::vector<DataType> supportedTypes =
1098 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001099 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001100 DataType::Float32,
1101 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001102 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001103 DataType::QAsymmU8,
1104 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001105 };
1106
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001107 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001108
1109 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1110 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1111 {
1112 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1113 {
1114 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1115 "for BFloat16 input.");
1116 }
1117 }
1118 else
1119 {
1120 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1121 }
telsoa014fcda012018-03-09 14:13:49 +00001122}
1123
telsoa014fcda012018-03-09 14:13:49 +00001124void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1125{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 const std::string descriptorName{"NormalizationQueueDescriptor"};
1127
1128 ValidateNumInputs(workloadInfo, descriptorName, 1);
1129 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1130
1131 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1132 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001133
1134 // Check the supported data types
1135 std::vector<DataType> supportedTypes =
1136 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001137 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001138 DataType::Float16,
1139 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001140 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001141 DataType::QAsymmU8,
1142 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001143 };
1144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001145 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001147 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001148
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001150}
1151
1152void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1153{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001155
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001156 ValidateNumInputs(workloadInfo, descriptorName, 2);
1157 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1158
1159 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1160 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1161 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1162
1163 std::vector<DataType> supportedTypes =
1164 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001165 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001166 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001167 DataType::Float16,
1168 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001169 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001170 DataType::QSymmS16,
1171 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001172 };
1173
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001174 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1175 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1176 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001177
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001178 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1179 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001180
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001181 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1182 inputTensorInfo1,
1183 outputTensorInfo,
1184 descriptorName,
1185 "input_0",
1186 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001187}
1188
telsoa014fcda012018-03-09 14:13:49 +00001189void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1190{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001191 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001193 ValidateNumInputs(workloadInfo, descriptorName, 2);
1194 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1195
1196 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1197 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1198 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1199
1200 std::vector<DataType> supportedTypes =
1201 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001202 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001203 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001204 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001205 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001206 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001207 DataType::QSymmS16,
1208 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001209 };
1210
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001211 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1212 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1213 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001214
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001215 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1216 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001217
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001218 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1219 inputTensorInfo1,
1220 outputTensorInfo,
1221 descriptorName,
1222 "input_0",
1223 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001224}
1225
1226void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1227{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001228 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001229
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001230 ValidateNumInputs(workloadInfo, descriptorName, 1);
1231 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1232
1233 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1234 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001235
1236 std::vector<DataType> supportedTypes =
1237 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001238 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001239 DataType::Float16,
1240 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001241 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001242 DataType::QAsymmU8,
1243 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001244 };
1245
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001246 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1247 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001248
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001249 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001251
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001252 ValidatePointer(m_Mean, descriptorName, "mean");
1253 ValidatePointer(m_Variance, descriptorName, "variance");
1254 ValidatePointer(m_Beta, descriptorName, "beta");
1255 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001256
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001257 const TensorInfo& mean = m_Mean->GetTensorInfo();
1258 const TensorInfo& variance = m_Variance->GetTensorInfo();
1259 const TensorInfo& beta = m_Beta->GetTensorInfo();
1260 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001261
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001262 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1263 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1264 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1265 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001266
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001267 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1268 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1269 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001270}
1271
1272void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1273{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001275
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001276 uint32_t numInputs = 2;
1277 if (m_Parameters.m_BiasEnabled)
1278 {
1279 numInputs = 3;
1280 }
1281
1282 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001283 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001284
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001285 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1286 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001287
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001288 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1289 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001290
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001291 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001292
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001293 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001294
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001295 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001296
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001297 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001298 if (m_Parameters.m_BiasEnabled)
1299 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001300 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001301 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302
1303 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1304 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001305 }
1306
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001307 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1308 {
1309 throw InvalidArgumentException(
1310 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1311 "cannot be either negative or 0.",
1312 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1313 }
1314
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001315 ValidatePerAxisQuantization(inputTensorInfo,
1316 outputTensorInfo,
1317 weightTensorInfo,
1318 optionalBiasTensorInfo,
1319 descriptorName);
1320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001321 std::vector<DataType> supportedTypes =
1322 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001323 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001324 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001325 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001326 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001327 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001328 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001329 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001330 };
1331
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001332 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001333
1334 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1335 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1336 {
1337 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1338 {
1339 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1340 "for BFloat16 input.");
1341 }
1342 }
1343 else
1344 {
1345 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1346 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001347}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001348
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001349void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1350{
1351 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1352
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001353 uint32_t numInputs = 2;
1354 if (m_Parameters.m_BiasEnabled)
1355 {
1356 numInputs = 3;
1357 }
1358 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001359 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1360
1361 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1362 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1363
1364 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1365 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1366
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001367 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001368 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1369
1370 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1371
1372 Optional<TensorInfo> optionalBiasTensorInfo;
1373 if (m_Parameters.m_BiasEnabled)
1374 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001375 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001376 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1377
1378 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1379 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1380 }
1381
1382 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1383 {
1384 throw InvalidArgumentException(
1385 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1386 "cannot be either negative or 0.",
1387 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1388 }
1389
1390 ValidatePerAxisQuantization(inputTensorInfo,
1391 outputTensorInfo,
1392 weightTensorInfo,
1393 optionalBiasTensorInfo,
1394 descriptorName);
1395
1396 std::vector<DataType> supportedTypes =
1397 {
1398 DataType::BFloat16,
1399 DataType::Float16,
1400 DataType::Float32,
1401 DataType::QAsymmS8,
1402 DataType::QAsymmU8,
1403 DataType::QSymmS16,
1404 DataType::QSymmS8
1405 };
1406
1407 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1408 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1409}
1410
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001411void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1412{
1413 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1414
Cathal Corbett06902652022-04-14 17:55:11 +01001415 uint32_t numInputs = 2;
1416 if (m_Parameters.m_BiasEnabled)
1417 {
1418 numInputs = 3;
1419 }
1420
1421 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001422 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1423
1424 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1425 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1426
1427 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1428 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1429
Cathal Corbett06902652022-04-14 17:55:11 +01001430 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001431 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1432
1433 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1434 {
1435 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001436 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1437 "cannot be smaller than 1.",
1438 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001439 }
1440
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001441 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1442 {
1443 throw InvalidArgumentException(
1444 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1445 "cannot be either negative or 0.",
1446 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1447 }
1448
Jan Eilers53ef7952021-06-02 12:01:25 +01001449 if (weightTensorInfo.GetShape()[0] != 1)
1450 {
1451 throw InvalidArgumentException(fmt::format(
1452 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1453 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1454 descriptorName,
1455 weightTensorInfo.GetShape()[0],
1456 weightTensorInfo.GetShape()[1],
1457 weightTensorInfo.GetShape()[2],
1458 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001459 }
1460
Cathal Corbett4b19d222022-05-11 20:12:17 +01001461 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1462 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1463 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1464 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1465
1466 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1467 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1468 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1469
1470 if (!(validRefFormat || validAclFormat))
1471 {
1472 throw InvalidArgumentException(fmt::format(
1473 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1474 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1475 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1476 descriptorName,
1477 numOutputChannels,
1478 weightTensorInfo.GetShape()[0],
1479 weightTensorInfo.GetShape()[1],
1480 weightTensorInfo.GetShape()[2],
1481 weightTensorInfo.GetShape()[3]));
1482 }
1483
Teresa Charlind8df0262019-11-11 12:28:15 +00001484 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001485
Teresa Charlind8df0262019-11-11 12:28:15 +00001486 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001487 if (m_Parameters.m_BiasEnabled)
1488 {
Cathal Corbett06902652022-04-14 17:55:11 +01001489 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001490 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001491
1492 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1493 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1494 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001495 ValidatePerAxisQuantization(inputTensorInfo,
1496 outputTensorInfo,
1497 weightTensorInfo,
1498 optionalBiasTensorInfo,
1499 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001500
1501 std::vector<DataType> supportedTypes =
1502 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001503 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001504 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001505 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001506 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001507 DataType::QAsymmU8,
1508 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001509 };
1510
1511 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1512 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001513}
1514
1515void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1516{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001517 const std::string descriptorName{"PermuteQueueDescriptor"};
1518
1519 ValidateNumInputs(workloadInfo, descriptorName, 1);
1520 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001521
1522 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1523
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001524 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1525 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001526
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001527 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1528 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001529
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001530 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001531 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001532 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001533 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001534 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1535 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1536 "must match dst dimension " + to_string(mapping[i]) +
1537 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001538 }
1539 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001540
1541 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001542}
1543
1544void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1545{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001546 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001547
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001548 ValidateNumInputs(workloadInfo, descriptorName, 1);
1549 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1550
1551 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1552 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1553
1554 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1555 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001556
1557 std::vector<DataType> supportedTypes =
1558 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001559 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001560 DataType::Float32,
1561 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001562 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001563 DataType::QAsymmU8,
1564 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001565 };
1566
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001567 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1568 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001569}
1570
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001571void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1572{
1573 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1574
1575 ValidateNumInputs(workloadInfo, descriptorName, 1);
1576 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1577
1578 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1579 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1580
1581 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1582 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1583
1584 std::vector<DataType> supportedTypes =
1585 {
1586 DataType::BFloat16,
1587 DataType::Float32,
1588 DataType::Float16,
1589 DataType::QAsymmS8,
1590 DataType::QAsymmU8,
1591 DataType::QSymmS16
1592 };
1593
1594 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1595 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1596}
1597
1598
telsoa014fcda012018-03-09 14:13:49 +00001599void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1600{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001603 ValidateNumInputs(workloadInfo, descriptorName, 1);
1604 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1605
1606 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1607 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1608
1609 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1610 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001611
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001612 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001613 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001614 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001615 DataType::Float16,
1616 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001617 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001618 DataType::QAsymmU8,
1619 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001620 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001622 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1623 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001625 // ResizeBilinear only changes width and height: batch and channel count must match.
1626 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1627 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001628 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001629 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001630 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001631 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1632 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001633 }
1634
Teresa Charlin970f43b2019-07-01 13:51:07 +01001635 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001636 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1637 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001638 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001639 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001640 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001641 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1642 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001643 }
1644}
1645
1646void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1647{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001648 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001650 ValidateNumInputs(workloadInfo, descriptorName, 1);
1651 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1652
1653 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1654 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1655
1656 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1657 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001658
1659 std::vector<DataType> supportedTypes =
1660 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001661 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001662 DataType::Float16,
1663 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001664 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001665 DataType::QAsymmU8,
1666 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001667 };
1668
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001669 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1670 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001671
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001672 // Resize only changes width and height: batch and channel count must match.
1673 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1674 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001675 if (inputBatchSize != outputBatchSize)
1676 {
1677 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001678 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1679 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001680 }
1681
1682 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001683 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1684 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001685 if (inputChannelCount != outputChannelCount)
1686 {
1687 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001688 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1689 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001690 }
1691}
1692
1693void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1694{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001695 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001697 ValidateNumInputs(workloadInfo, descriptorName, 1);
1698 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1699
1700 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1701 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1702
1703 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1704 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1705
1706 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1707
telsoa014fcda012018-03-09 14:13:49 +00001708 if (m_Parameters.m_Min > m_Parameters.m_Max)
1709 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001710 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001711 }
telsoa014fcda012018-03-09 14:13:49 +00001712}
1713
Kevin Mayce5045a2019-10-02 14:07:47 +01001714void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1715{
1716 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1717
1718 ValidateNumInputs(workloadInfo, descriptorName, 1);
1719 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1720
1721 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1722 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1723
1724 if (inputTensorInfo.GetNumDimensions() > 4)
1725 {
1726 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1727 }
1728
1729 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1730
1731 // Check the supported data types
1732 std::vector<DataType> supportedTypes =
1733 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001734 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001735 DataType::Float32,
1736 DataType::Float16
1737 };
1738
1739 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001740 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001741}
1742
telsoa014fcda012018-03-09 14:13:49 +00001743void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1744{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001745 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001746
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001747 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001748 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001750 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1751 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1752
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001753 if (inputTensorInfo.GetNumDimensions() > 4)
1754 {
1755 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1756 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001757
1758 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001759
1760 // Check the supported data types
1761 std::vector<DataType> supportedTypes =
1762 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001763 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001764 DataType::Float32,
1765 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001766 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001767 DataType::QAsymmU8,
1768 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001769 };
1770
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001771 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001772 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1773}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001775void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1776{
1777 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1778
1779 ValidateNumInputs(workloadInfo, descriptorName, 1);
1780 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1781
1782 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1783 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1784
1785 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1786
1787 std::vector<DataType> supportedTypes =
1788 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001789 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001790 DataType::Float32,
1791 DataType::Float16,
1792 };
1793
1794 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001795 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001796}
1797
1798void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1799{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 const std::string descriptorName{"ConstantQueueDescriptor"};
1801
1802 ValidateNumInputs(workloadInfo, descriptorName, 0);
1803 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001804
1805 if (!m_LayerOutput)
1806 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001807 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001808 }
1809
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001810 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1811 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001812
1813 // Check the supported data types
1814 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001815 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001816 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001817 DataType::Float32,
1818 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001819 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001820 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001821 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001822 DataType::QSymmS16,
1823 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001824 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001825
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001826 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001827}
1828
1829void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1830{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001831 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001832
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001833 ValidateNumInputs(workloadInfo, descriptorName, 1);
1834 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1835
1836 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1837 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1838
1839 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001840
1841 // Check the supported data types
1842 std::vector<DataType> supportedTypes =
1843 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001844 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001845 DataType::Float32,
1846 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001847 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001848 DataType::QAsymmU8,
1849 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001850 DataType::Signed32,
1851 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001852 };
1853
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001854 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1855 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001856}
1857
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001858void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1859{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001861
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001862 ValidateNumInputs(workloadInfo, descriptorName, 1);
1863 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1864
1865 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1866 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1867
1868 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1869 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001870
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001871 if (m_Parameters.m_BlockShape.size() != 2)
1872 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001873 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001874 }
1875
1876 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1877 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001878 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1879 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001880 }
1881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001882 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001883
1884 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001885 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001886
Matthew Bentham8800c002018-11-19 13:19:28 +00001887 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001888
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001889 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1890 widthPad.first + widthPad.second;
1891 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1892 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001893
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001894 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1895 inputShape[dimensionIndices.GetChannelsIndex()];
1896 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001899 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001900 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001901 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001902 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001903 }
1904
1905 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001906 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001907 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1908 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001909 }
nikraj01120522a2019-05-31 11:33:07 +01001910
1911 std::vector<DataType> supportedTypes =
1912 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001913 DataType::BFloat16,
1914 DataType::Float16,
1915 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001916 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001917 DataType::QAsymmU8,
1918 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001919 };
1920
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001921 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1922 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001923}
1924
Keith Davisa57eccb2019-06-14 17:33:22 +01001925void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1926{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001927 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001928
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001929 ValidateNumInputs(workloadInfo, descriptorName, 1);
1930 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001932 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1933 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1934
1935 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1936 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001937
1938 std::vector<DataType> supportedTypes =
1939 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001940 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001941 DataType::Float32,
1942 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001943 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001944 DataType::QAsymmU8,
1945 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001946 };
1947
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001948 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1949 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001950
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001951 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1952
1953 if (m_Parameters.m_BlockSize == 0)
1954 {
1955 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1956 }
1957
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001958 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1959 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1960 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1961 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001962
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001963 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001964 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001965 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001966 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1967 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001968 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001969
1970 const TensorShape& outputShape = outputTensorInfo.GetShape();
1971 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1972 {
1973 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1974 "must be divisible by the square of block size." );
1975 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001976}
1977
telsoa014fcda012018-03-09 14:13:49 +00001978void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1979{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001980 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001981
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001982 ValidateNumInputs(workloadInfo, descriptorName, 1);
1983 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1984
1985 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1986 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001987
1988 std::vector<DataType> supportedTypes =
1989 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001990 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001991 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001992 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001993 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001994 };
1995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001997 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1998 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1999 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00002000}
2001
telsoa01c577f2c2018-08-31 09:22:23 +01002002void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2003{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
2005
2006 const std::string descriptorName{"LstmQueueDescriptor"};
2007
2008 // check dimensions of all inputs and outputs
2009 if (workloadInfo.m_InputTensorInfos.size() != 3)
2010 {
2011 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
2012 }
2013 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2014 {
2015 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
2016 }
2017
2018 std::vector<DataType> supportedTypes =
2019 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002020 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01002021 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002022 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002023 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002024 };
2025
Jan Eilers38e05bd2019-06-26 13:10:09 +01002026 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002027 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
2028
Jan Eilers38e05bd2019-06-26 13:10:09 +01002029 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002030 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002031 {
2032 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2033 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002034 descriptorName,
2035 "input_0",
2036 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002037 }
2038 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002039 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002040 {
2041 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2042 workloadInfo.m_OutputTensorInfos[i],
2043 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002044 "input_0",
2045 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002046 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002047
janeil0117d8d852019-11-15 15:00:16 +00002048 // Making sure clipping parameters have valid values.
2049 // == 0 means no clipping
2050 // > 0 means clipping
2051 if (m_Parameters.m_ClippingThresCell < 0.0f)
2052 {
2053 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2054 }
2055 if (m_Parameters.m_ClippingThresProj < 0.0f)
2056 {
2057 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2058 }
2059
Jan Eilers38e05bd2019-06-26 13:10:09 +01002060 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002061 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2062 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2063 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2064 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2065 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2066 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2067
Jan Eilers38e05bd2019-06-26 13:10:09 +01002068 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002069 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2070 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002071 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002072 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2073 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002074 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002075 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2076 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002077 // scratchBufferTensor
2078 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002079 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2080 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002081 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002082 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2083 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002084 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002085 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2086 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002087 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002088 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2089 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002090
Jan Eilers38e05bd2019-06-26 13:10:09 +01002091 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2092 if ( m_InputToInputWeights )
2093 {
2094 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2095 (n_cell * n_input), "InputLayerNormWeights");
2096 }
2097
2098 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2099 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2100 (n_cell * n_input), "InputToForgetWeights");
2101
2102 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2103 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2104 (n_cell * n_input), "InputToCellWeights");
2105
2106 if ( m_RecurrentToInputWeights )
2107 {
2108 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2109 (n_cell * n_output), "RecurrentToInputWeights");
2110 }
2111
2112 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2113 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2114 (n_cell * n_output), "RecurrentToForgetWeights");
2115
2116 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2117 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2118 (n_cell * n_output), "RecurrentToCellWeights");
2119
2120 // Make sure the input-gate's parameters are either both present (regular
2121 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2122 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2123 !m_Parameters.m_CifgEnabled) ||
2124 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2125 m_Parameters.m_CifgEnabled));
2126 if (!cifg_weights_all_or_none)
2127 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2129 "RecurrentToInputWeights must either both be present (regular LSTM) "
2130 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2131 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002132 }
2133
2134 if ( m_CellToInputWeights )
2135 {
2136 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2137 n_cell, "CellToInputWeights");
2138 }
2139 if ( m_CellToForgetWeights )
2140 {
2141 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2142 n_cell, "CellToForgetWeights");
2143 }
2144 if ( m_CellToOutputWeights )
2145 {
2146 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2147 n_cell, "CellToOutputWeights");
2148 }
2149
2150 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2151 bool peephole_weights_all_or_none =
2152 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2153 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2154 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2155 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2156 if (!peephole_weights_all_or_none)
2157 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002158 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002159 }
2160
2161 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2162 if (m_Parameters.m_CifgEnabled)
2163 {
2164 if (m_InputGateBias)
2165 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002166 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002167 }
2168 }
2169 else
2170 {
2171 if (!m_InputGateBias)
2172 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002173 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2174 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002175 }
2176 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2177 n_cell, "InputGateBias");
2178 }
2179
2180 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2181 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2182
2183 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2184 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2185
2186 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2187 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2188
2189 if (m_ProjectionWeights)
2190 {
2191 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2192 (n_cell * n_output), "ProjectionWeights");
2193 }
2194 if (m_ProjectionBias)
2195 {
2196 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2197 }
2198
2199 // Making sure the projection tensors are consistent:
2200 // 1) If projection weight is not present, then projection bias should not be
2201 // present.
2202 // 2) If projection weight is present, then projection bias is optional.
2203 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2204 !m_Parameters.m_ProjectionEnabled)
2205 || (m_ProjectionWeights && !m_ProjectionBias &&
2206 m_Parameters.m_ProjectionEnabled)
2207 || (m_ProjectionWeights && m_ProjectionBias &&
2208 m_Parameters.m_ProjectionEnabled));
2209 if (!projecton_tensors_consistent)
2210 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002211 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002212 }
2213
2214 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2215 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2216 // either all have values or none of them have values. Layer normalization is used when the values of all the
2217 // layer normalization weights are present
2218 if (m_InputLayerNormWeights)
2219 {
2220 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2221 }
2222 if (m_ForgetLayerNormWeights)
2223 {
2224 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2225 }
2226 if (m_CellLayerNormWeights)
2227 {
2228 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2229 }
2230 if (m_OutputLayerNormWeights)
2231 {
2232 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2233 }
2234
Jan Eilers38e05bd2019-06-26 13:10:09 +01002235 if (m_Parameters.m_LayerNormEnabled)
2236 {
2237 if (!m_Parameters.m_CifgEnabled)
2238 {
2239 if (!m_InputLayerNormWeights)
2240 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002241 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2242 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002243 }
2244 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2245 1, n_cell, "InputLayerNormWeights");
2246 }
2247 else if (m_InputLayerNormWeights)
2248 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002249 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2250 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002251 }
2252
2253 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2254 "ForgetLayerNormWeights");
2255 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2256
2257 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2258 "OutputLayerNormWeights");
2259 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2260
2261 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2262 "CellLayerNormWeights");
2263 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2264 }
2265 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2266 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002267 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2268 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002269 }
telsoa01c577f2c2018-08-31 09:22:23 +01002270}
2271
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002272void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2273{
2274 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2275
2276 ValidateNumInputs(workloadInfo, descriptorName, 1);
2277 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2278
2279 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2280 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2281
2282 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2283 {
2284 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2285 }
2286
2287 if (outputTensorInfo.GetDataType() != DataType::Float32)
2288 {
2289 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2290 }
2291
2292 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2293}
2294
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002295void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2296{
2297 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2298
2299 ValidateNumInputs(workloadInfo, descriptorName, 1);
2300 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2301
2302 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2303 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2304
2305 if (inputTensorInfo.GetDataType() != DataType::Float32)
2306 {
2307 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2308 }
2309
2310 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2311 {
2312 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2313 }
2314
2315 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2316}
2317
telsoa01c577f2c2018-08-31 09:22:23 +01002318void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2319{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002320 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002322 ValidateNumInputs(workloadInfo, descriptorName, 1);
2323 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2324
2325 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2326 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2327
2328 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002329 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002331 }
2332
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002334 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002335 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002336 }
2337
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002338 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002339}
2340
2341void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2342{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002343 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002345 ValidateNumInputs(workloadInfo, descriptorName, 1);
2346 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2347
2348 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2349 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2350
2351 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002352 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002353 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002354 }
2355
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002356 if (outputTensorInfo.GetDataType() != DataType::Float32)
2357 {
2358 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2359 }
2360
2361 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002362}
2363
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002364void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2365{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002366 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002367
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002368 ValidateNumInputs(workloadInfo, descriptorName, 2);
2369 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2370
2371 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2372 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2373 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2374
2375 std::vector<DataType> supportedTypes =
2376 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002377 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002378 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002379 DataType::Float32,
2380 DataType::QAsymmS8,
2381 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002382 DataType::QSymmS16,
2383 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002384 };
2385
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002386 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2387 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2388 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002390 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2391 inputTensorInfo1,
2392 outputTensorInfo,
2393 descriptorName,
2394 "input_0",
2395 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002396}
2397
David Beckc2044fe2018-09-05 15:00:38 +01002398void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2399{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002400 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002401
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002402 ValidateNumInputs(workloadInfo, descriptorName, 2);
2403 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2404
2405 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2406 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2407 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2408
2409 std::vector<DataType> supportedTypes =
2410 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002411 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002412 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002413 DataType::Float32,
2414 DataType::QAsymmS8,
2415 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002416 DataType::QSymmS16,
2417 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002418 };
2419
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2421 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2422 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002423
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002424 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2425 inputTensorInfo1,
2426 outputTensorInfo,
2427 descriptorName,
2428 "input_0",
2429 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002430}
2431
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002432void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2433{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002434 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002435
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002436 ValidateNumInputs(workloadInfo, descriptorName, 2);
2437 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2438
2439 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2440 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2441 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2442
2443 std::vector<DataType> supportedTypes =
2444 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002445 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002446 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002447 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002448 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002449 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002450 DataType::QSymmS16,
2451 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002452 };
2453
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002454 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2455 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2456 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002458 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2459 inputTensorInfo1,
2460 outputTensorInfo,
2461 descriptorName,
2462 "input_0",
2463 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002464}
2465
narpra01a6bf9122018-09-10 09:50:09 +01002466void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2467{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002468 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002469
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002470 ValidateNumInputs(workloadInfo, descriptorName, 1);
2471 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2472
2473 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2474 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002475
2476 std::vector<DataType> supportedTypes =
2477 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002478 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002479 DataType::Float32,
2480 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002481 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002482 DataType::QAsymmU8,
2483 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002484 };
narpra01eb061912018-09-10 17:35:27 +01002485
James Conroy4d1ff582019-06-10 17:06:39 +01002486 // First check if input tensor data type is supported, then
2487 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002488 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2489 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002490
narpra0132b90462018-09-13 11:07:48 +01002491 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002492 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002493 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002494 }
narpra0132b90462018-09-13 11:07:48 +01002495 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002496 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002497 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002498 }
2499 else
2500 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002501 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002502 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002503 ValidateTensorNumDimensions(outputTensorInfo,
2504 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002505 outputDim > 0 ? outputDim : 1,
2506 "output");
2507 }
narpra01a6bf9122018-09-10 09:50:09 +01002508}
2509
jimfly012c9322a2018-09-19 10:59:49 +01002510void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2511{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002512 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002513
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002514 ValidateNumInputs(workloadInfo, descriptorName, 1);
2515 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2516
2517 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2518 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002519
jimfly012c9322a2018-09-19 10:59:49 +01002520 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002521 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2522
jimfly012c9322a2018-09-19 10:59:49 +01002523 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002524 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2525 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2526 "as there are dimensions in the input tensor that is " +
2527 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2528 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002529 }
2530}
2531
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002532void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2533{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002534 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002535
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002536 ValidateNumInputs(workloadInfo, descriptorName, 1);
2537 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002538
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002539 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2540 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2541
Sadik Armagan2208b602019-07-31 16:36:27 +01002542 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002543 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002544 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002545 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002546 DataType::Float16,
2547 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002548 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002549 DataType::QAsymmU8,
2550 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002551 };
2552
2553 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002554
Keith Davis0c2eeac2020-02-11 16:51:50 +00002555 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002556 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002557 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002558 }
2559}
2560
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002561void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2562{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002563 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002564
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002565 ValidateNumInputs(workloadInfo, descriptorName, 1);
2566 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002567
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002568 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2569 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002570
2571 std::vector<DataType> supportedTypes =
2572 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002573 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002574 DataType::Float32,
2575 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002576 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002577 DataType::QAsymmU8,
2578 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002579 };
2580
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002581 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2582 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002583}
2584
Conor Kennedy430b5d82018-11-14 15:28:28 +00002585void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2586{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002587 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002588
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002589 ValidateNumInputs(workloadInfo, descriptorName, 1);
2590 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2591
2592 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2593 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002594
2595 std::vector<DataType> supportedTypes =
2596 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002597 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002598 DataType::Float16,
2599 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002600 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002601 DataType::QAsymmU8,
2602 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002603 };
2604
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002605 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2606 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002608 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002609
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002610 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002611 if (rank > 4)
2612 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002613 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002614 }
2615
Conor Kennedy430b5d82018-11-14 15:28:28 +00002616 // Begin, End & Stride length must be of rank(input0)
2617 if (m_Parameters.m_Begin.size() != rank)
2618 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002619 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002620 }
2621
2622 if (m_Parameters.m_End.size() != rank)
2623 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002624 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002625 }
2626
2627 if (m_Parameters.m_Stride.size() != rank)
2628 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002629 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002630 }
2631
2632 // Stride entries must be non-zero
2633 for (auto& stride : m_Parameters.m_Stride)
2634 {
2635 if (stride == 0)
2636 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002637 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002638 }
2639 }
2640}
2641
kevmay0190539692018-11-29 08:40:19 +00002642void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2643{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002644 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002645
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002646 ValidateNumInputs(workloadInfo, descriptorName, 2);
2647 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2648
2649 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2650 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2651 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2652
2653 std::vector<DataType> supportedTypes =
2654 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002655 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002656 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002657 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002658 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002659 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002660 DataType::QSymmS16,
2661 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002662 };
2663
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002664 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2665 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2666 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002668 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2669 inputTensorInfo1,
2670 outputTensorInfo,
2671 descriptorName,
2672 "input_0",
2673 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002674}
2675
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002676void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2677{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002678 const std::string descriptorName{"DebugQueueDescriptor"};
2679
2680 ValidateNumInputs(workloadInfo, descriptorName, 1);
2681 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002682}
2683
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002684void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2685{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002686 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002687
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002688 ValidateNumInputs(workloadInfo, descriptorName, 2);
2689 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002690
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002691 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2692 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2693 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2694
2695 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2696 inputTensorInfo1,
2697 outputTensorInfo,
2698 descriptorName,
2699 "input_0",
2700 "input_1");
2701
2702 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002703 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002704 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002705 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002706}
2707
FrancisMurtagh878f0232018-12-19 10:56:15 +00002708void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2709{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002710 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002711
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002712 ValidateNumInputs(workloadInfo, descriptorName, 2);
2713 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002714
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002715 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2716 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2717 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2718
2719 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2720 inputTensorInfo1,
2721 outputTensorInfo,
2722 descriptorName,
2723 "input_0",
2724 "input_1");
2725
2726 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002727 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002728 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002729 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002730}
2731
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002732void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2733{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002734 const std::string descriptorName{"RsqrtQueueDescriptor"};
2735
2736 ValidateNumInputs(workloadInfo, descriptorName, 1);
2737 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2738
2739 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2740 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2741
2742 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002743
2744 std::vector<DataType> supportedTypes =
2745 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002746 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002747 DataType::Float16,
2748 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002749 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002750 DataType::QAsymmU8,
2751 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002752 };
2753
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002754 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2755 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002756}
2757
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002758void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2759{
2760 const std::string descriptorName{"GatherNdQueueDescriptor"};
2761
2762 ValidateNumInputs(workloadInfo, descriptorName, 2);
2763 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2764
2765 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2766 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2767 {
2768 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2769 }
2770
2771 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2772 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2773
2774 std::vector<DataType> supportedTypes =
2775 {
2776 DataType::BFloat16,
2777 DataType::Float16,
2778 DataType::Float32,
2779 DataType::QAsymmS8,
2780 DataType::QAsymmU8,
2781 DataType::QSymmS16,
2782 DataType::Signed32,
2783 };
2784
2785 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2786
2787 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2788
2789 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2790 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2791}
2792
narpra01b89b05f2019-01-16 09:53:09 +00002793void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2794{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002795 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002796
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002797 ValidateNumInputs(workloadInfo, descriptorName, 2);
2798 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002799
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002800 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2801 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002802 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002803 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002804 }
2805
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002806 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2807 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2808
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002809 std::vector<DataType> supportedTypes =
2810 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002811 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002812 DataType::Float16,
2813 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002814 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002815 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002816 DataType::QSymmS16,
2817 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002818 };
2819
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002820 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002822 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002824 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2825 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002826}
2827
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002828void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2829{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002830 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2831
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002832 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002833
2834 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2835 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002836 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002837 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2838 }
2839
2840 if (m_Anchors == nullptr)
2841 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002842 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002843 }
2844
2845 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002846 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2847 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2848
2849 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002850 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002851 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2852 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002853
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002854 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2855 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2856 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002857
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002858 const std::vector<DataType> supportedInputTypes =
2859 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002860 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002861 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002862 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002863 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002864 DataType::QAsymmU8,
2865 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002866 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002867
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002868 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2869 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2870 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2871
2872 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2873 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2874 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2875 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2876
2877 // NOTE: Output is always Float32 regardless of input type
2878 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2879 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2880 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2881 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002882
2883 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2884 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002885 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002886 "must be positive and less than or equal to 1.");
2887 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002888
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002889 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2890 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002891 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002892 "should be equal to number of classes + 1.");
2893 }
2894}
2895
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002896void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2897{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002898 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002899
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002900 ValidateNumInputs(workloadInfo, descriptorName, 1);
2901 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2902
2903 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2904 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2905
Teresa Charlin07307f32022-05-15 14:07:05 +01002906 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002907 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002908 DataType::QAsymmS8,
2909 DataType::QAsymmU8,
2910 DataType::QSymmS8,
2911 DataType::QSymmS16,
2912 DataType::Float16
2913 };
2914 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002915
Teresa Charlin07307f32022-05-15 14:07:05 +01002916 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002917 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002918 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002919 DataType::Float32,
2920 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002921 };
2922
Teresa Charlin07307f32022-05-15 14:07:05 +01002923 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002924}
2925
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002926void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2927{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002928 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002929
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002930 ValidateNumInputs(workloadInfo, descriptorName, 2);
2931 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002932
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002933 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2934 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2935 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002936
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002937 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2938 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2939
2940 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2941 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002942}
2943
Keith Davis3ae3f972021-05-21 16:33:48 +01002944void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2945{
2946 const std::string& descriptorName{"ShapeQueueDescriptor"};
2947
2948 ValidateNumInputs(workloadInfo, descriptorName, 1);
2949 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2950
2951 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2952 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2953
2954 std::vector<DataType> supportedTypes =
2955 {
2956 DataType::BFloat16,
2957 DataType::Float16,
2958 DataType::Float32,
2959 DataType::QAsymmS8,
2960 DataType::QAsymmU8,
2961 DataType::QAsymmS8,
2962 DataType::QSymmS8,
2963 DataType::QSymmS16,
2964 DataType::Signed32
2965 };
2966
2967 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2968 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2969}
2970
Sadik Armaganeff363d2019-04-05 15:25:46 +01002971void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2972{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002973 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002974
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002975 ValidateNumInputs(workloadInfo, descriptorName, 2);
2976 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2977
2978 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2979 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2980
2981 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2982 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2983
2984 std::vector<DataType> supportedTypes =
2985 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002986 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002987 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002988 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002989 DataType::QAsymmU8,
2990 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002991 };
2992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002993 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2994 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002996 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2997 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002998
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002999 ValidateTensorShapesMatch(inputTensorInfo0,
3000 outputTensorInfo0,
3001 descriptorName,
3002 "input_0",
3003 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003004
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003005 ValidateTensorShapesMatch(inputTensorInfo0,
3006 outputTensorInfo1,
3007 descriptorName,
3008 "input_0",
3009 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003010}
3011
Derek Lamberti901ea112019-12-10 22:07:09 +00003012void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00003013{
3014 // This is internally generated so it should not need validation.
3015}
3016
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003017void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3018{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003019 const std::string& descriptorName{"PreluQueueDescriptor"};
3020
3021 ValidateNumInputs(workloadInfo, descriptorName, 2);
3022 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3023
3024 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3025 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
3026 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003027
3028 std::vector<DataType> supportedTypes
3029 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003030 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003031 DataType::Float16,
3032 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003033 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003034 DataType::QAsymmU8,
3035 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003036 };
3037
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003038 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3039 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003040
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003041 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003042
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003043 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
3044 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003046 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3047 alphaTensorInfo,
3048 outputTensorInfo,
3049 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003050 "input",
3051 "alpha");
3052}
3053
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003054void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3055{
3056 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3057
3058 ValidateNumInputs(workloadInfo, descriptorName, 1);
3059 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3060
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003061 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3062 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3063
3064 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3065 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003066
3067 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003068
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003069 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3070 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003071
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003072 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3073
3074 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003075 if (m_Parameters.m_BiasEnabled)
3076 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003077 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003078
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003079 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3080 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003081
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003082 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003083 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003084 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003085
3086 ValidatePerAxisQuantization(inputTensorInfo,
3087 outputTensorInfo,
3088 weightTensorInfo,
3089 optionalBiasTensorInfo,
3090 descriptorName);
3091
3092 std::vector<DataType> supportedTypes =
3093 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003094 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003095 DataType::Float32,
3096 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003097 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003098 DataType::QAsymmU8,
3099 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003100 };
3101
3102 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3103 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003104}
3105
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003106void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3107{
3108 const std::string descriptorName{"TransposeQueueDescriptor"};
3109
3110 ValidateNumInputs(workloadInfo, descriptorName, 1);
3111 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3112
3113 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3114
3115 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3116 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3117
3118 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3119 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3120
3121 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3122 {
3123 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3124 {
3125 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3126 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3127 "must match dst dimension " + to_string(i) +
3128 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3129 }
3130 }
3131
3132 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3133}
3134
Simon Obute51f67772021-09-03 15:50:13 +01003135void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3136{
3137 const std::string descriptorName{"TransposeQueueDescriptor"};
3138
3139 ValidateNumInputs(workloadInfo, descriptorName, 1);
3140 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3141
3142 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3143 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3144
3145 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3146}
3147
James Conroy4f1f8992020-04-29 20:01:10 +01003148void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3149{
3150 const std::string descriptorName{"QLstmQueueDescriptor"};
3151
3152 // Validate number of inputs/outputs
3153 ValidateNumInputs(workloadInfo, descriptorName, 3);
3154 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3155
3156 // Input/output tensor info
3157 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3158 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3159 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3160
3161 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3162 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3163 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3164
3165 // Supported types for various tensors in QLSTM
3166 std::vector<DataType> inputOutputSupportedTypes =
3167 {
3168 DataType::QAsymmS8
3169 };
3170
3171 std::vector<DataType> cellStateSupportedTypes =
3172 {
3173 DataType::QSymmS16
3174 };
3175
3176 std::vector<DataType> weightsSupportedTypes =
3177 {
3178 DataType::QSymmS8
3179 };
3180
3181 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3182 {
3183 DataType::QSymmS16
3184 };
3185
3186 std::vector<DataType> biasSupportedTypes =
3187 {
3188 DataType::Signed32
3189 };
3190
3191 // Validate types of input/output tensors
3192 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3193 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3194 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3195
3196 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3197 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3198 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3199
3200 // Validate matching types of input/output tensors
3201 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3202 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3203 "outputStateIn", "outputStateOut");
3204 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3205
3206 // Infer number of batches, number of units, input size and output size from tensor dimensions
3207 const uint32_t numBatches = inputInfo.GetShape()[0];
3208 const uint32_t inputSize = inputInfo.GetShape()[1];
3209 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3210 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3211
3212 // Validate number of dimensions and number of elements for input/output tensors
3213 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3214 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3215 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3216
3217 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3218 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3219 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3220
3221 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3222 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3223 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3224 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3225
3226 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3227 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3228 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3229
3230 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3231 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3232 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3233
3234 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3235 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3236 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3237 " RecurrentToForgetWeights");
3238
3239 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3240 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3241 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3242
3243 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3244 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3245 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3246
3247 // Validate data types for MANDATORY weights tensors (all should match each other)
3248 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3249
3250 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3251 "inputToForgetWeights", "inputToCellWeights");
3252 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3253 "inputToForgetWeights", "inputToOutputWeights");
3254
3255 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3256 "inputToForgetWeights", "recurrentToForgeteights");
3257 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3258 "inputToForgetWeights", "recurrentToCellWeights");
3259 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3260 "inputToForgetWeights", "recurrentToOutputWeights");
3261
3262 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3263 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3264 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3265 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3266
3267 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3268 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3269 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3270
3271 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3272 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3273 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3274
3275 // Validate data types for MANDATORY bias tensors
3276 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3277
3278 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3279 "forgetGateBias", "cellBias");
3280 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3281 "forgetGateBias", "outputGateBias");
3282
3283 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3284 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3285 !m_Parameters.m_CifgEnabled) ||
3286 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3287 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3288
3289 if (!allCifgParamsPresentOrNot)
3290 {
3291 throw InvalidArgumentException(descriptorName +
3292 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3293 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3294 "set appropriately.");
3295 }
3296
3297 if (!m_Parameters.m_CifgEnabled)
3298 {
3299 // Validate number of dimensions and number of elements
3300 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3301 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3302
3303 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3304 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3305 " RecurrentToInputWeights");
3306
3307 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3308 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3309
3310 // Validate data types
3311 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3312 "inputToForgetWeights", "inputToInputWeights");
3313 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3314 "inputToForgetWeights", "recurrentToInputWeights");
3315 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3316 "forgetGateBias", "inputGateBias");
3317 }
3318
3319 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3320 bool allPeepholeWeightsPresentOrNot =
3321 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3322 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3323 || (!m_CellToInputWeights && !m_CellToForgetWeights
3324 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3325
3326 if (!allPeepholeWeightsPresentOrNot)
3327 {
3328 throw InvalidArgumentException(descriptorName +
3329 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3330 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3331 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3332 "appropriately.");
3333 }
3334
3335 if (m_Parameters.m_PeepholeEnabled)
3336 {
3337 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3338 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3339 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3340
3341 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3342 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3343 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3344 "cellToForgetWeight", "cellToOutputWeights");
3345
3346 if (!m_Parameters.m_CifgEnabled)
3347 {
3348 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3349 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3350 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3351 "cellToForgetWeights", "cellToInputWeights");
3352 }
3353 }
3354
3355 // Validate OPTIONAL params: Layer Norm Weights
3356 bool allLayerNormWeightsPresentOrNot =
3357 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3358 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3359 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3360 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3361
3362 if (!allLayerNormWeightsPresentOrNot)
3363 {
3364 throw InvalidArgumentException(descriptorName +
3365 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3366 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3367 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3368 "only be present when Layer Norm is enabled and CIFG is disabled. "
3369 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3370 }
3371
3372 if (m_Parameters.m_LayerNormEnabled)
3373 {
3374 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3375 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3376 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3377
3378 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3379 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3380 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3381 "forgetLayerNormWeights", "cellLayerNormWeights");
3382
3383 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3384 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3385 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3386 "forgetLayerNormWeights", "outputLayerNormWeights");
3387
3388 if (!m_Parameters.m_CifgEnabled)
3389 {
3390 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3391 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3392 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3393 "forgetLayerNormWeights", "inputLayerNormWeights");
3394 }
3395 }
3396
3397 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3398 bool correctProjectionTensorsPresent =
3399 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3400 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3401 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3402
3403 if (!correctProjectionTensorsPresent)
3404 {
3405 throw InvalidArgumentException(descriptorName +
3406 ": If projection is enabled, ProjectionWeights should be present and "
3407 "ProjectionBias is optional. If projection is disabled, neither "
3408 "ProjectionWeights nor ProjectionBias should be present.");
3409 }
3410
3411 if (m_Parameters.m_ProjectionEnabled)
3412 {
3413 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3414 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3415 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3416
3417 if (m_ProjectionBias)
3418 {
3419 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003420 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003421 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3422 }
3423
3424 }
3425 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3426 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3427 throw InvalidArgumentException(descriptorName +
3428 ": If projection is disabled, output quantization info (scale, offset) "
3429 "should match HiddenStateScale and HiddenStateZeroPoint.");
3430 }
3431
3432}
3433
James Conroy9c3cae82019-08-01 16:01:48 +01003434void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3435{
3436 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3437
3438 // Validate number of inputs/outputs
3439 ValidateNumInputs(workloadInfo, descriptorName, 3);
3440 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3441
3442 // Input/output tensor infos
3443 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3444 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3445 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3446
3447 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3448 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3449
3450 std::vector<DataType> inputOutputSupportedTypes =
3451 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003452 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003453 };
3454
3455 std::vector<DataType> cellStateSupportedTypes =
3456 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003457 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003458 };
3459
3460 std::vector<DataType> weightsSupportedTypes =
3461 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003462 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003463 };
3464
3465 std::vector<DataType> biasSupportedTypes =
3466 {
3467 DataType::Signed32
3468 };
3469
3470 // Validate types of input/output tensors
3471 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3472 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3473 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3474
3475 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3476 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3477
3478 // Validate matching types of input/output tensors
3479 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3480 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3481 "outputStateIn", "outputStateOut");
3482 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3483
3484 // Validate matching quantization info for input/output tensors
3485 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3486 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3487 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003488
James Conroy9c3cae82019-08-01 16:01:48 +01003489 // Infer number of batches, input size and output size from tensor dimensions
3490 const uint32_t numBatches = inputInfo.GetShape()[0];
3491 const uint32_t inputSize = inputInfo.GetShape()[1];
3492 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3493
3494 // Validate number of dimensions and number of elements for input/output tensors
3495 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3496 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3497 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3498 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3499 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3500
3501 // Validate number of dimensions and number of elements for weights tensors
3502 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3503 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3504 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3505
3506 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3507 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3508 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3509
3510 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3511 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3512 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3513
3514 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3515 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3516 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3517
3518 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3519 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3520 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3521
3522 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3523 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3524 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3525 " RecurrentToForgetWeights");
3526
3527 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3528 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3529 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3530
3531 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3532 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3533 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3534
3535 // Validate data types for weights tensors (all should match each other)
3536 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3537
3538 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3539 "inputToInputWeights", "inputToForgetWeights");
3540 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3541 "inputToInputWeights", "inputToCellWeights");
3542 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3543 "inputToInputWeights", "inputToOutputWeights");
3544
3545 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3546 "inputToInputWeights", "recurrentToInputWeights");
3547 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3548 "inputToInputWeights", "recurrentToForgeteights");
3549 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3550 "inputToInputWeights", "recurrentToCellWeights");
3551 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3552 "inputToInputWeights", "recurrentToOutputWeights");
3553
3554 // Validate matching quantization info for weight tensors (all should match each other)
3555 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3556 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3557 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3558 descriptorName, "inputToInputWeights", "inputToCellWeights");
3559 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3560 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3561
3562 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3563 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3564 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3565 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3566 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3567 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3568 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3569 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3570
3571 // Validate number of dimensions and number of elements in bias tensors
3572 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3573 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3574 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3575
3576 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3577 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3578 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3579
3580 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3581 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3582 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3583
3584 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3585 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3586 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3587
3588 // Validate data types for bias tensors (all should match each other)
3589 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3590
3591 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3592 "inputGateBias", "forgetGateBias");
3593 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3594 "inputGateBias", "cellBias");
3595 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3596 "inputGateBias", "outputGateBias");
3597
3598 // Validate bias tensor quantization info
3599 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3600 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3601 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3602 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3603}
3604
Kevin May868eb142019-09-04 17:29:31 +01003605void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3606{
3607 const std::string descriptorName{"AbsQueueDescriptor"};
3608
3609 ValidateNumInputs(workloadInfo, descriptorName, 1);
3610 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3611
3612 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3613 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3614
3615 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3616
3617 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003618 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003619 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003620 DataType::Float16,
3621 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003623 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003624 DataType::QSymmS16,
3625 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003626 };
Kevin May868eb142019-09-04 17:29:31 +01003627
3628 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3629 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3630}
3631
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003632void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3633{
3634 const std::string descriptorName{"SliceQueueDescriptor"};
3635
3636 ValidateNumInputs(workloadInfo, descriptorName, 1);
3637 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3638
3639 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3640 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3641
3642 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3643
3644 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3645 if (rank > 4)
3646 {
3647 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3648 }
3649
3650 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3651
3652 // Check if m_Begin and m_Size have the expected length
3653 if (m_Parameters.m_Begin.size() != rank)
3654 {
3655 throw InvalidArgumentException(descriptorName +
3656 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3657 }
3658 if (m_Parameters.m_Size.size() != rank)
3659 {
3660 throw InvalidArgumentException(descriptorName +
3661 ": Length of size descriptor must equal rank " + std::to_string(rank));
3662 }
3663
3664 // Check if the shape of the output tensor matches m_Size
3665 const TensorShape& outputShape = outputTensorInfo.GetShape();
3666 for (unsigned int i = 0u; i < rank; ++i)
3667 {
3668 if (m_Parameters.m_Size[i] != outputShape[i])
3669 {
3670 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3671 }
3672 }
3673
3674 // Check if the sum of begin offset and size in a given dimension
3675 // does not exceed the size of corresponding input
3676 const TensorShape& inputShape = inputTensorInfo.GetShape();
3677 for(unsigned int i = 0u; i < rank; ++i)
3678 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003679 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003680 {
3681 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3682 std::to_string(i) + " exceeds input size.");
3683 }
3684 }
3685}
3686
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003687void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3688{
3689 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3690
3691 ValidateNumInputs(workloadInfo, descriptorName, 1);
3692 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3693
3694 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3695 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3696
3697 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3698 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3699
3700 std::vector<DataType> supportedTypes =
3701 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003702 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003703 DataType::Float32,
3704 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003705 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003706 DataType::QAsymmU8,
3707 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003708 };
3709
3710 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3711 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3712
3713 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3714
3715 if (m_Parameters.m_BlockSize == 0)
3716 {
3717 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3718 }
3719
3720 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3721 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3722 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3723 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3724
3725 const TensorShape& outputShape = outputInfo.GetShape();
3726 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3727 {
3728 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3729 "must be divisible by block size.");
3730 }
3731
3732 const TensorShape& inputShape = inputInfo.GetShape();
3733 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3734 {
3735 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3736 "must be divisible by the square of block size." );
3737 }
3738}
3739
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003740void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3741{
3742 const std::string descriptorName{"ComparisonQueueDescriptor"};
3743
3744 ValidateNumInputs(workloadInfo, descriptorName, 2);
3745 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3746
3747 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3748 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3749 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3750
3751 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3752 inputTensorInfo1,
3753 outputTensorInfo,
3754 descriptorName,
3755 "input_0",
3756 "input_1");
3757
3758 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3759 {
3760 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3761 }
3762}
3763
josh minor4a3c6102020-01-06 16:40:46 -06003764void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3765{
3766 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3767
3768 ValidateNumInputs(workloadInfo, descriptorName, 1);
3769 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3770
3771 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3772 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3773
3774 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3775
3776 std::vector<DataType> supportedTypes =
3777 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003778 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003779 DataType::Float16,
3780 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003781 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003782 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003783 DataType::QSymmS16,
3784 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003785 };
3786
James Conroyaba90cd2020-11-06 16:28:18 +00003787 std::vector<DataType> logicalSupportedTypes =
3788 {
3789 DataType::Boolean
3790 };
3791
3792 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3793 {
3794 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3795 }
3796 else
3797 {
3798 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3799 }
3800
3801
josh minor4a3c6102020-01-06 16:40:46 -06003802 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3803}
3804
Finn Williams2605b232020-06-10 15:53:46 +01003805void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3806{
3807 const std::string descriptorName{"RankQueueDescriptor"};
3808
3809 ValidateNumInputs(workloadInfo, descriptorName, 1);
3810 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3811
3812 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3813 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3814
3815 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3816 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3817
3818 std::vector<DataType> supportedTypes =
3819 {
3820 DataType::BFloat16,
3821 DataType::Float16,
3822 DataType::Float32,
3823 DataType::QAsymmS8,
3824 DataType::QAsymmU8,
3825 DataType::QSymmS8,
3826 DataType::QSymmS16,
3827 DataType::Signed32
3828 };
3829
3830 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3831 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3832}
3833
James Conroyaba90cd2020-11-06 16:28:18 +00003834void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3835{
3836 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3837
3838 ValidateNumInputs(workloadInfo, descriptorName, 2);
3839 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3840
3841 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3842 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3843 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3844
3845 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3846 inputTensorInfo1,
3847 outputTensorInfo,
3848 descriptorName,
3849 "input_0",
3850 "input_1");
3851
3852 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3853 {
3854 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3855 }
3856
3857 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3858 {
3859 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3860 }
3861
3862 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3863 {
3864 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3865 }
3866}
3867
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003868void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3869{
3870 const std::string descriptorName{"ReduceQueueDescriptor"};
3871
3872 ValidateNumInputs(workloadInfo, descriptorName, 1);
3873 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3874
3875 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3876 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3877
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003878 std::vector<DataType> supportedTypes =
3879 {
3880 DataType::BFloat16,
3881 DataType::Float16,
3882 DataType::Float32,
3883 DataType::QAsymmS8,
3884 DataType::QAsymmU8,
3885 DataType::QSymmS16,
3886 DataType::Signed32
3887 };
3888
3889 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3890 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3891}
3892
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003893void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3894{
3895 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3896
3897 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3898
3899 // check dimensions of all inputs and outputs
3900 if (workloadInfo.m_InputTensorInfos.size() != 3)
3901 {
3902 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3903 }
Mike Kelly12994962022-04-21 11:57:09 +01003904 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003905 {
3906 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3907 }
3908
3909 std::vector<DataType> supportedTypes =
3910 {
Mike Kelly12994962022-04-21 11:57:09 +01003911 DataType::Float32,
3912 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003913 };
3914
3915 // check for supported type of one input and match them with all the other input and output
3916 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3917
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003918 // Making sure clipping parameters have valid values.
3919 // == 0 means no clipping
3920 // > 0 means clipping
3921 if (m_Parameters.m_ClippingThresCell < 0.0f)
3922 {
3923 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3924 }
3925 if (m_Parameters.m_ClippingThresProj < 0.0f)
3926 {
3927 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3928 }
3929
3930 unsigned int batchIndx = 0;
3931 unsigned int inputIndx = 1;
3932 uint32_t timeStep = 1;
3933 unsigned int timeIndx = 1;
3934 inputIndx = 2;
3935 if (m_Parameters.m_TimeMajor)
3936 {
3937 batchIndx = 1;
3938 timeIndx = 0;
3939
3940 }
3941 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3942
3943 // Inferring batch size, number of outputs and number of cells from the inputs.
3944 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3945 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3946 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3947 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3948 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3949 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3950
3951 // input tensor
3952 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3953 descriptorName + " input_0");
3954 // outputStateInTensor
3955 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3956 descriptorName + " input_1");
3957 // outputStateInTensor
3958 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3959 descriptorName + " input_2");
3960
3961 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003962 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003963 descriptorName + " output_0");
3964
3965 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3966 if ( m_InputToInputWeights )
3967 {
3968 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3969 (n_cell * n_input), "InputLayerNormWeights");
3970 }
3971
3972 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3973 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3974 (n_cell * n_input), "InputToForgetWeights");
3975
3976 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3977 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3978 (n_cell * n_input), "InputToCellWeights");
3979
3980 if ( m_RecurrentToInputWeights )
3981 {
3982 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3983 (n_cell * n_output), "RecurrentToInputWeights");
3984 }
3985
3986 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3987 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3988 (n_cell * n_output), "RecurrentToForgetWeights");
3989
3990 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3991 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3992 (n_cell * n_output), "RecurrentToCellWeights");
3993
3994 // Make sure the input-gate's parameters are either both present (regular
3995 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3996 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3997 !m_Parameters.m_CifgEnabled) ||
3998 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3999 m_Parameters.m_CifgEnabled));
4000 if (!cifg_weights_all_or_none)
4001 {
4002 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
4003 "RecurrentToInputWeights must either both be present (regular LSTM) "
4004 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
4005 "accordingly.");
4006 }
4007
4008 if ( m_CellToInputWeights )
4009 {
4010 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
4011 n_cell, "CellToInputWeights");
4012 }
4013 if ( m_CellToForgetWeights )
4014 {
4015 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
4016 n_cell, "CellToForgetWeights");
4017 }
4018 if ( m_CellToOutputWeights )
4019 {
4020 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
4021 n_cell, "CellToOutputWeights");
4022 }
4023
4024 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
4025 bool peephole_weights_all_or_none =
4026 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
4027 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
4028 || ( !m_CellToInputWeights && !m_CellToForgetWeights
4029 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
4030 if (!peephole_weights_all_or_none)
4031 {
4032 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
4033 }
4034
4035 // Make sure the input gate bias is present only when not a CIFG-LSTM.
4036 if (m_Parameters.m_CifgEnabled)
4037 {
4038 if (m_InputGateBias)
4039 {
4040 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
4041 }
4042 }
4043 else
4044 {
4045 if (!m_InputGateBias)
4046 {
4047 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4048 "must be present.");
4049 }
4050 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4051 n_cell, "InputGateBias");
4052 }
4053
4054 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4055 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4056
4057 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4058 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4059
4060 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4061 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4062
4063 if (m_ProjectionWeights)
4064 {
4065 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4066 (n_cell * n_output), "ProjectionWeights");
4067 }
4068 if (m_ProjectionBias)
4069 {
4070 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4071 }
4072
4073 // Making sure the projection tensors are consistent:
4074 // 1) If projection weight is not present, then projection bias should not be
4075 // present.
4076 // 2) If projection weight is present, then projection bias is optional.
4077 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4078 !m_Parameters.m_ProjectionEnabled)
4079 || (m_ProjectionWeights && !m_ProjectionBias &&
4080 m_Parameters.m_ProjectionEnabled)
4081 || (m_ProjectionWeights && m_ProjectionBias &&
4082 m_Parameters.m_ProjectionEnabled));
4083 if (!projecton_tensors_consistent)
4084 {
4085 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4086 }
4087
4088 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4089 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4090 // either all have values or none of them have values. Layer normalization is used when the values of all the
4091 // layer normalization weights are present
4092 if (m_InputLayerNormWeights)
4093 {
4094 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4095 }
4096 if (m_ForgetLayerNormWeights)
4097 {
4098 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4099 }
4100 if (m_CellLayerNormWeights)
4101 {
4102 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4103 }
4104 if (m_OutputLayerNormWeights)
4105 {
4106 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4107 }
4108
4109 if (m_Parameters.m_LayerNormEnabled)
4110 {
4111 if (!m_Parameters.m_CifgEnabled)
4112 {
4113 if (!m_InputLayerNormWeights)
4114 {
4115 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4116 "disabled but InputLayerNormWeights are not present");
4117 }
4118 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4119 1, n_cell, "InputLayerNormWeights");
4120 }
4121 else if (m_InputLayerNormWeights)
4122 {
4123 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4124 "enabled");
4125 }
4126
4127 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4128 "ForgetLayerNormWeights");
4129 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4130
4131 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4132 "OutputLayerNormWeights");
4133 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4134
4135 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4136 "CellLayerNormWeights");
4137 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4138 }
4139 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4140 {
4141 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4142 "normalisation weights are present.");
4143 }
4144}
4145
4146
mathad01df9a3222021-04-28 11:42:57 +01004147} // namespace armnn