blob: 2d7a5fdffc3a6e7cff6465209ef6e8e169c08d15 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010011#include <armnnUtils/Permute.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010013#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000014
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000017#include <string>
18#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000019
James Ward47fce872020-09-10 11:57:28 +010020#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000021
Matteo Martincigh21350152018-11-28 16:22:22 +000022using namespace armnnUtils;
23
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
27//---------------------------------------------------------------
28DataType GetBiasDataType(DataType inputDataType)
29{
30 switch (inputDataType)
31 {
telsoa01c577f2c2018-08-31 09:22:23 +010032 case DataType::Float16:
33 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000034 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000035 case DataType::Float32:
36 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000037 case DataType::QAsymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000038 case DataType::QAsymmU8:
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// Android NDK does not support std::to_string, so values are stringified via an output stream.
/// Converts any stream-insertable value to its textual representation.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream stream;
    stream << value;
    return stream.str();
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000108
109//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110void ValidateTensorNumElements(const TensorInfo& tensor,
111 std::string const& descName,
112 unsigned int numElements,
113 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100114{
115 if (tensor.GetNumElements() != numElements)
116 {
117 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100118 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100119 tensorName + " tensor.");
120 }
121}
122
123//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000124void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
125 const std::string& descName, std::string const& tensorName)
126{
127 if (tensor.GetDataType() != dataType)
128 {
129 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
130 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
131 }
132}
133
Derek Lambertid466a542020-01-22 15:37:29 +0000134void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
135{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100136 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000137 {
138 throw InvalidArgumentException(descName +
139 ": Expected data type which supports per-axis quantization scheme but got " +
140 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
141 }
Derek Lambertid466a542020-01-22 15:37:29 +0000142}
143
telsoa014fcda012018-03-09 14:13:49 +0000144//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100145void ValidateTensorQuantizationSpace(const TensorInfo& first,
146 const TensorInfo& second,
147 const std::string& descName,
148 std::string const& firstName,
149 std::string const& secondName)
150{
151 if (!first.IsQuantized() ||
152 !second.IsQuantized())
153 {
154 // Not a quantized type, ignore the validation
155 return;
156 }
157
158 DataType firstDataType = first.GetDataType();
159 DataType secondDataType = second.GetDataType();
160
161 if (firstDataType != secondDataType)
162 {
163 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
164 " must be of the same quantized type, " +
165 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
166 secondName + " is " + GetDataTypeName(secondDataType));
167 }
168
169 if (!first.IsTypeSpaceMatch(second))
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must have the same quantization space, " +
173 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
174 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
175 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
176 " and scale " + to_string(second.GetQuantizationScale()));
177 }
178}
179
180//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100181void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100182 const TensorInfo& weightsTensorInfo,
183 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000184{
185 if (biasTensor.GetQuantizationOffset() != 0)
186 {
187 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
188 to_string(biasTensor.GetQuantizationOffset()));
189 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000190
James Conroy8502ade2020-11-12 19:26:29 +0000191 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000192 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000193 // Validate per-axis quantization scales
194 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
195 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
196
197 if (weightScales.size() != biasScales.size())
198 {
199 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000200 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
201 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
202 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000203 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
204 }
telsoa014fcda012018-03-09 14:13:49 +0000205 }
206}
207
208//---------------------------------------------------------------
209void ValidateTensors(const std::vector<ITensorHandle*>& vec,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100210 unsigned int numExpected,
211 const std::string& descName,
212 const std::string& varName)
telsoa014fcda012018-03-09 14:13:49 +0000213{
214 if (vec.empty() && numExpected > 0)
215 {
216 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
217 }
218
219 for (unsigned int i = 0; i < numExpected; ++i)
220 {
221 if (!vec[i])
222 {
223 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
224 }
225 }
226}
227
228//---------------------------------------------------------------
229void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
230 const TensorInfo& second,
231 const TensorInfo& output,
232 std::string const& descName,
233 std::string const& firstName,
234 std::string const& secondName)
235{
236 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
237 // broadcasted.
238 if (first.GetNumDimensions() != second.GetNumDimensions())
239 {
240 throw InvalidArgumentException(descName + ": Tensors "
241 + firstName + " & " + secondName
242 + " must have the same number of dimensions in order to be broadcasted");
243 }
244 uint32_t numDims = first.GetNumDimensions();
245 std::vector<uint32_t> outputDims(numDims, 0u);
246 for (uint32_t i = 0; i < numDims; i++)
247 {
248 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
249 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
250 if (dimsNotEqual && dimsNotOne)
251 {
252 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
253 }
254 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
255 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100256 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000257 if (broadcastShape != output.GetShape())
258 {
259 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
260 + firstName + " & " + secondName
261 + " does not match the output shape");
262 }
263}
264
265//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100266void ValidateDataTypes(const TensorInfo& info,
267 const std::vector<armnn::DataType>& supportedTypes,
268 std::string const& descName)
269{
270 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
271 if (iterator == supportedTypes.end())
272 {
273 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
274 }
275}
276
James Conroy4d1ff582019-06-10 17:06:39 +0100277//---------------------------------------------------------------
278void ValidateTensorDataTypesMatch(const TensorInfo& first,
279 const TensorInfo& second,
280 std::string const& descName,
281 std::string const& firstName,
282 std::string const& secondName)
283{
284 if (first.GetDataType() != second.GetDataType())
285 {
286 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
287 " must have identical data types.");
288 }
289}
290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100291//---------------------------------------------------------------
292void ValidateTensorNumElementsMatch(const TensorInfo& first,
293 const TensorInfo& second,
294 std::string const& descName,
295 std::string const& firstName,
296 std::string const& secondName)
297{
298 if (first.GetNumElements() != second.GetNumElements())
299 {
300 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
301 " must have the same number of elements.");
302 }
303}
304
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000305void ValidateWeightDataType(const TensorInfo& inputInfo,
306 const TensorInfo& weightInfo,
307 const std::string& descName)
308{
309 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000310 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000311 {
312 const std::vector<DataType> validTypes =
313 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000314 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100315 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100316 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000317 };
318
319 ValidateDataTypes(weightInfo, validTypes, descName);
320 }
321 else
322 {
323 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
324 }
325}
326
327void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
328 const std::string& descName,
329 const std::string& tensorName)
330{
331 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
332 if (!quantizationDim.has_value())
333 {
James Ward47fce872020-09-10 11:57:28 +0100334 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
335 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000336 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000337}
338
339void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
340 const std::string& descName,
341 const std::string& tensorName)
342{
343 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
344 if (quantizationOffset != 0)
345 {
James Ward47fce872020-09-10 11:57:28 +0100346 throw InvalidArgumentException(fmt::format(
347 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
348 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000349 }
350}
351
352void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
353 const TensorInfo& outputInfo,
354 const TensorInfo& weightInfo,
355 const Optional<TensorInfo>& optionalBiasInfo,
356 const std::string& descName)
357{
358 if (weightInfo.HasPerAxisQuantization())
359 {
360 const DataType inputDataType = inputInfo.GetDataType();
361 const DataType outputDataType = outputInfo.GetDataType();
362
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364
365 if (!canHavePerAxisQuantization)
366 {
James Ward47fce872020-09-10 11:57:28 +0100367 throw InvalidArgumentException(fmt::format(
368 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
369 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000370 }
371
Derek Lambertid466a542020-01-22 15:37:29 +0000372
373 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
375 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
376
377 if (optionalBiasInfo.has_value())
378 {
379 const TensorInfo& biasInfo = optionalBiasInfo.value();
380 if (!biasInfo.HasPerAxisQuantization())
381 {
James Ward47fce872020-09-10 11:57:28 +0100382 throw InvalidArgumentException(fmt::format(
383 "{}: Per-axis quantization parameters not set on bias tensor, "
384 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000385 }
386
387 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
388 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
389 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
390 }
391 }
392}
393
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100394} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000395
Mike Kelly80512b02022-05-16 23:10:42 +0100396//---------------------------------------------------------------
397void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
398 std::string const& descName,
399 unsigned int numDimensions,
400 std::string const& tensorName) const
401{
402 // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
403 // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
404 // numDimensions.
405 if (m_AllowExpandedDims)
406 {
407 unsigned int squeezedDims = 0;
408
409 for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
410 {
411 if (tensor.GetShape()[i] != 1)
412 {
413 ++squeezedDims;
414 }
415 }
416 if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
417 {
418 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
419 to_string(tensor.GetNumDimensions()) + " dimensions for " +
420 tensorName + " tensor.");
421 }
422 }
423 else
424 {
425 if (tensor.GetNumDimensions() != numDimensions)
426 {
427 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
428 to_string(tensor.GetNumDimensions()) + " dimensions for " +
429 tensorName + " tensor.");
430 }
431 }
432}
433
434//---------------------------------------------------------------
435void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100436 unsigned int numDimension,
437 unsigned int numElements,
438 std::string const& tensorName) const
Mike Kelly80512b02022-05-16 23:10:42 +0100439{
440 const std::string functionName{"ValidateTensorNumDimNumElem"};
441 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
442 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
443}
444
445//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000446void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
447 unsigned int numExpectedIn, unsigned int numExpectedOut) const
448{
449 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
450 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
451}
452
453//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100454void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
455{
456 const std::string descriptorName{"MapQueueDescriptor"};
457
458 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100459 ValidateNumOutputs(workloadInfo, descriptorName, 0);
460
461 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
462 {
463 if (!m_Inputs[i])
464 {
465 throw InvalidArgumentException(
466 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
467 }
468 }
469}
470
471//---------------------------------------------------------------
472void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
473{
474 const std::string descriptorName{"UnmapQueueDescriptor"};
475
476 ValidateNumInputs(workloadInfo, descriptorName, 1);
477 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100478
479 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
480 {
481 if (!m_Inputs[i])
482 {
483 throw InvalidArgumentException(
484 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
485 }
486 }
487}
488
489//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000490void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
491{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100494 ValidateNumInputs(workloadInfo, descriptorName, 1);
495 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100497 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
498 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
499
500 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
501 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000502
503 if (m_Inputs.size() != m_Outputs.size())
504 {
James Ward47fce872020-09-10 11:57:28 +0100505 throw InvalidArgumentException(fmt::format(
506 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
507 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000508 }
509
510 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
511 {
512 if (!m_Inputs[i])
513 {
James Ward47fce872020-09-10 11:57:28 +0100514 throw InvalidArgumentException(fmt::format(
515 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000516 }
517
518 if (!m_Outputs[i])
519 {
James Ward47fce872020-09-10 11:57:28 +0100520 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000521 }
522 }
523}
524
Derek Lambertif674aa02019-08-01 15:56:25 +0100525//---------------------------------------------------------------
526void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
527{
528 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
529 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
530
531 if (workloadInfo.m_InputTensorInfos.size() != 1)
532 {
James Ward47fce872020-09-10 11:57:28 +0100533 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
534 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100535
536 }
537
538 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
539 {
James Ward47fce872020-09-10 11:57:28 +0100540 throw InvalidArgumentException(fmt::format(
541 "Number of input infos ({0}) does not match the number of output infos ({1})",
542 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100543 }
544
545 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
546 {
547 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
548 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
549 {
James Ward47fce872020-09-10 11:57:28 +0100550 throw InvalidArgumentException(fmt::format(
551 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100552 }
553 }
554
555 if (m_Inputs.size() != 1)
556 {
James Ward47fce872020-09-10 11:57:28 +0100557 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100558 }
559
560 if (m_Inputs.size() != m_Outputs.size())
561 {
James Ward47fce872020-09-10 11:57:28 +0100562 throw InvalidArgumentException(fmt::format(
563 "Number of inputs ({0}) does not match the number of outputs ({1})",
564 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100565 }
566
567 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
568 {
569 if (!m_Inputs[i])
570 {
James Ward47fce872020-09-10 11:57:28 +0100571 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573
574 if (!m_Outputs[i])
575 {
James Ward47fce872020-09-10 11:57:28 +0100576 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100577 }
578 }
579}
580
581//---------------------------------------------------------------
582void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
583{
584 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100585
Derek Lambertif674aa02019-08-01 15:56:25 +0100586 if (m_Inputs.size() != 1)
587 {
James Ward47fce872020-09-10 11:57:28 +0100588 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100589 }
590
591 if (m_Outputs.size() != 0)
592 {
James Ward47fce872020-09-10 11:57:28 +0100593 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100594 }
595
596 if (!m_Inputs[0])
597 {
James Ward47fce872020-09-10 11:57:28 +0100598 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100599 }
600}
601
602//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000603void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
604{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100605 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100606
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100607 ValidateNumInputs(workloadInfo, descriptorName, 1);
608 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100609
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100610 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
611 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100612
613 std::vector<DataType> supportedTypes =
614 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000615 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100616 DataType::Float16,
617 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000618 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000619 DataType::QAsymmU8,
620 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100621 };
622
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100623 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
624 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
625 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000626}
627
Nikhil Rajee391d52019-09-05 17:50:44 +0100628void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
629{
630 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
631
632 ValidateNumInputs(workloadInfo, descriptorName, 1);
633 ValidateNumOutputs(workloadInfo, descriptorName, 1);
634
635 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
636 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
637
Inki Daed4619e22020-09-10 15:33:54 +0900638 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
639 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100640 {
Inki Daed4619e22020-09-10 15:33:54 +0900641 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100642 }
643
James Conroyd47a0642019-09-17 14:22:06 +0100644 std::vector<DataType> supportedInputTypes =
645 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000646 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100647 DataType::Float16,
648 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100649 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000650 DataType::QAsymmU8,
651 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900652 DataType::Signed32,
653 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100654 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100655
James Conroyd47a0642019-09-17 14:22:06 +0100656 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100657
658 auto inputShape = inputTensorInfo.GetShape();
659 auto outputShape = outputTensorInfo.GetShape();
660
661 auto inputNumDimensions = inputShape.GetNumDimensions();
662 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
663
664 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
665
666 // 1D input shape results in scalar output shape
667 if (inputShape.GetNumDimensions() == 1)
668 {
669 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
670 {
671 throw InvalidArgumentException(descriptorName + outputShapeError);
672 }
673 }
674 else
675 {
676 for (unsigned int i = 0; i < unsignedAxis; ++i)
677 {
678 if (outputShape[i] != inputShape[i])
679 {
680 throw InvalidArgumentException(descriptorName + outputShapeError);
681 }
682 }
683
684 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
685 {
686 if (outputShape[i - 1] != inputShape[i])
687 {
688 throw InvalidArgumentException(descriptorName + outputShapeError);
689 }
690 }
691 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100692}
693
mathad01b392e982021-04-07 12:07:30 +0100694void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
695{
696 const std::string descriptorName{"CastQueueDescriptor"};
697
698 ValidateNumInputs(workloadInfo, descriptorName, 1);
699 ValidateNumOutputs(workloadInfo, descriptorName, 1);
700
701 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
703
704 std::vector<DataType> supportedTypes =
705 {
706 DataType::BFloat16,
707 DataType::Float16,
708 DataType::Float32,
709 DataType::QAsymmS8,
710 DataType::QAsymmU8,
711 DataType::QSymmS8,
712 DataType::QSymmS16,
713 DataType::Signed32,
714 DataType::Signed64
715 };
716
717 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
718 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
719}
720
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100721void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
722{
723 const std::string descriptorName{"SoftmaxQueueDescriptor"};
724
725 ValidateNumInputs(workloadInfo, descriptorName, 1);
726 ValidateNumOutputs(workloadInfo, descriptorName, 1);
727
728 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
729 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
730
731 std::vector<DataType> supportedTypes =
732 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000733 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100734 DataType::Float16,
735 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000736 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000737 DataType::QAsymmU8,
738 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100739 };
740
741 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
742 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
743 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
744}
745
telsoa014fcda012018-03-09 14:13:49 +0000746void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
747{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100748 const std::string descriptorName{"SplitterQueueDescriptor"};
749
750 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000751
Ruomei Yan25339c32019-05-28 16:48:20 +0100752 // Check the supported data types
753 std::vector<DataType> supportedTypes =
754 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000755 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100756 DataType::Float32,
757 DataType::Float16,
758 DataType::Boolean,
759 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100760 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000761 DataType::QAsymmU8,
762 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100763 };
764
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100765 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
766 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
769 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
770
771 const std::string outputName = "output_" + std::to_string(i);
772 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100773 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100774
telsoa014fcda012018-03-09 14:13:49 +0000775 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
776 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100777 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000778 }
779
780 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
781 {
782 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100783 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000784 "has to match number of workloadInfo.m_OutputTensorInfos. "
785 "Number of windows: " +
786 to_string(m_ViewOrigins.size()) +
787 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
788 }
789
telsoa01c577f2c2018-08-31 09:22:23 +0100790 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000791 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
792 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
793 {
telsoa01c577f2c2018-08-31 09:22:23 +0100794 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000795 ViewOrigin const& e = m_ViewOrigins[w];
796 if (e.m_Origin.size() != inputDims)
797 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100798 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000799 "have the same dimensionality as the input tensor. "
800 "Window origin (index: " +
801 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
802 " dimensions, the input "
803 "tensor has " +
804 to_string(inputDims) + " dimensions.");
805 }
806 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
807 {
808 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
809 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
810 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100811 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000812 "be smaller or equal than the size of the input in that coord.");
813 }
814 }
815 }
816}
817
Jim Flynne242f2d2019-05-22 14:24:13 +0100818void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000819{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100820 const std::string descriptorName{"ConcatQueueDescriptor"};
821
822 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000823
824 if (m_Inputs.size() <= 0)
825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100826 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000827 }
828 if (m_Outputs.size() <= 0)
829 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100830 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000831 }
832
833 if (workloadInfo.m_InputTensorInfos.size() <= 0)
834 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100835 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000836 }
837 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
838 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100839 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000840 }
841
Nikhil Raj8599a412018-11-19 14:51:07 +0000842 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
843 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100844 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000845 }
846
847 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
848 {
849 return;
850 }
851
telsoa014fcda012018-03-09 14:13:49 +0000852 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
853 {
854 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100855 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000856 "has to match number of workloadInfo.m_InputTensorInfos. "
857 "Number of windows: " +
858 to_string(m_ViewOrigins.size()) +
859 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
860 }
861
telsoa01c577f2c2018-08-31 09:22:23 +0100862 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000863 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
864 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
865 {
telsoa01c577f2c2018-08-31 09:22:23 +0100866 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000867 ViewOrigin const& e = m_ViewOrigins[w];
868 if (e.m_Origin.size() != outputDims)
869 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100870 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000871 "have the same dimensionality as the output tensor. "
872 "Window origin (index: " +
873 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
874 " dimensions, the output "
875 "tensor has " +
876 to_string(outputDims) + " dimensions.");
877 }
telsoa01c577f2c2018-08-31 09:22:23 +0100878 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000879 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
880 {
881 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
882 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
883 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100884 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000885 "be smaller or equal than the size of the output in that coord.");
886 }
887 }
888 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100889
890 // Check the supported data types
891 std::vector<DataType> supportedTypes =
892 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000893 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100894 DataType::Float32,
895 DataType::Float16,
896 DataType::Boolean,
897 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100898 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000899 DataType::QAsymmU8,
900 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100901 };
902
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100903 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
904 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100905 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100906 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
907 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
908
909 const std::string inputName = "input_" + std::to_string(i);
910 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100911 }
telsoa014fcda012018-03-09 14:13:49 +0000912}
913
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100914void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
915{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100916 const std::string descriptorName{"StackQueueDescriptor"};
917
918 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100919
920 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
921 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100922 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100923 }
924
925 // All inputs must have the same shape, which is defined in parameters
926 const TensorShape& inputShape = m_Parameters.m_InputShape;
927 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
928 {
929 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
930 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100931 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100932 }
933 }
934
Matthew Jacksondba634f2019-08-15 15:14:18 +0100935 if (inputShape.GetNumDimensions() > 4)
936 {
937 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
938 }
939
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100940 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
941 // since the output tensor has an additional dimension.
942 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
943 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100944 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100945 "than the number of input dimensions.");
946 }
947
948 // Output shape must be as inferred from the input shape
949 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
950 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
951 {
952 if (outputShape[i] != inputShape[i])
953 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100954 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100955 "match shape inferred from input tensor.");
956 }
957 }
958
959 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
960 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100962 "match shape inferred from input tensor.");
963 }
964
965 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
966 {
967 if (outputShape[i] != inputShape[i-1])
968 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100970 "match shape inferred from input tensor.");
971 }
972 }
973
Matthew Jacksondba634f2019-08-15 15:14:18 +0100974 if (outputShape.GetNumDimensions() > 5)
975 {
976 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
977 }
978
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100979 // Check the supported data types
980 std::vector<DataType> supportedTypes =
981 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000982 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100983 DataType::Float32,
984 DataType::Float16,
985 DataType::Boolean,
986 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100987 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000988 DataType::QAsymmU8,
989 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100990 };
991
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100992 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100994 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100995 {
996 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
997 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100998 descriptorName,
999 "input_0",
1000 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001001 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001002
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001003 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1004 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001005 descriptorName,
1006 "input_0",
1007 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001008}
1009
Ryan OSheaec6c6802020-06-05 17:17:06 +01001010void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1011{
1012 const std::string descriptorName{"FillQueueDescriptor"};
1013
1014 ValidateNumInputs(workloadInfo, descriptorName, 1);
1015 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1016
1017 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1018 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1019
1020 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1021
1022 std::vector<DataType> supportedTypes =
1023 {
1024 DataType::BFloat16,
1025 DataType::Float32,
1026 DataType::Float16,
1027 DataType::Signed32
1028 };
1029
1030 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1031}
1032
telsoa014fcda012018-03-09 14:13:49 +00001033void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1034{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001035 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001036
Matthew Sloyan81beae32021-07-13 19:46:11 +01001037 uint32_t numInputs = 2;
1038 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001039 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001040 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001041 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001042
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001043 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1045
1046 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1047 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1048
1049 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1050
1051 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001052 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001053 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001054 }
1055
Matthew Sloyan81beae32021-07-13 19:46:11 +01001056 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001057 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001058
1059 if (m_Parameters.m_BiasEnabled)
1060 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001061 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001062 // Validates type and quantization values.
Ryan OSheaf183acd2023-07-06 11:41:25 +01001063 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1065 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001066 }
1067
Francis Murtagh46c09d02019-05-28 08:15:28 +01001068 // Check the supported data types
1069 std::vector<DataType> supportedTypes =
1070 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001071 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001072 DataType::Float32,
1073 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001074 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001075 DataType::QAsymmU8,
1076 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001077 };
1078
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001080
1081 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1082 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1083 {
1084 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1085 {
1086 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1087 "for BFloat16 input.");
1088 }
1089 }
1090 else
1091 {
1092 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1093 }
telsoa014fcda012018-03-09 14:13:49 +00001094}
1095
Teresa Charlin9145e382023-08-17 18:44:58 +01001096void FusedQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
1097{
1098 // This is internally generated, so it should not need validation.
1099}
1100
telsoa014fcda012018-03-09 14:13:49 +00001101void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1102{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001103 const std::string descriptorName{"NormalizationQueueDescriptor"};
1104
1105 ValidateNumInputs(workloadInfo, descriptorName, 1);
1106 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1107
1108 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1109 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001110
1111 // Check the supported data types
1112 std::vector<DataType> supportedTypes =
1113 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001114 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001115 DataType::Float16,
1116 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001117 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001118 DataType::QAsymmU8,
1119 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001120 };
1121
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001122 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001123
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001124 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001125
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001127}
1128
1129void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1130{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001131 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001132
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001133 ValidateNumInputs(workloadInfo, descriptorName, 2);
1134 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1135
1136 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1137 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1138 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1139
1140 std::vector<DataType> supportedTypes =
1141 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001142 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001143 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001144 DataType::Float16,
1145 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001146 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001147 DataType::QSymmS16,
1148 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001149 };
1150
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001151 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1152 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1153 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001154
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1156 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001157
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001158 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1159 inputTensorInfo1,
1160 outputTensorInfo,
1161 descriptorName,
1162 "input_0",
1163 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001164}
1165
telsoa014fcda012018-03-09 14:13:49 +00001166void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1167{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001168 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001169
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001170 ValidateNumInputs(workloadInfo, descriptorName, 2);
1171 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1172
1173 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1174 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1175 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1176
1177 std::vector<DataType> supportedTypes =
1178 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001179 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001180 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001181 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001182 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001183 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001184 DataType::QSymmS16,
1185 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001186 };
1187
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001188 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1189 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1190 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1193 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001194
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001195 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1196 inputTensorInfo1,
1197 outputTensorInfo,
1198 descriptorName,
1199 "input_0",
1200 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001201}
1202
1203void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1204{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001205 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001206
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001207 ValidateNumInputs(workloadInfo, descriptorName, 1);
1208 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1209
1210 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1211 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001212
1213 std::vector<DataType> supportedTypes =
1214 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001215 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001216 DataType::Float16,
1217 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001218 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001219 DataType::QAsymmU8,
1220 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001221 };
1222
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001223 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1224 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001225
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001226 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001227 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001228
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229 ValidatePointer(m_Mean, descriptorName, "mean");
1230 ValidatePointer(m_Variance, descriptorName, "variance");
1231 ValidatePointer(m_Beta, descriptorName, "beta");
1232 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001233
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001234 const TensorInfo& mean = m_Mean->GetTensorInfo();
1235 const TensorInfo& variance = m_Variance->GetTensorInfo();
1236 const TensorInfo& beta = m_Beta->GetTensorInfo();
1237 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001238
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001239 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1240 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1241 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1242 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001243
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001244 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1245 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1246 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001247}
1248
1249void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1250{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001251 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001252
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001253 uint32_t numInputs = 2;
1254 if (m_Parameters.m_BiasEnabled)
1255 {
1256 numInputs = 3;
1257 }
1258
1259 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001261
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001262 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1263 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001265 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1266 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001267
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001268 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001269
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001270 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001271
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001272 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001273
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001274 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001275 if (m_Parameters.m_BiasEnabled)
1276 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001277 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001278 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001279
1280 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001281 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001282 }
1283
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001284 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1285 {
1286 throw InvalidArgumentException(
1287 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1288 "cannot be either negative or 0.",
1289 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1290 }
1291
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001292 ValidatePerAxisQuantization(inputTensorInfo,
1293 outputTensorInfo,
1294 weightTensorInfo,
1295 optionalBiasTensorInfo,
1296 descriptorName);
1297
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001298 std::vector<DataType> supportedTypes =
1299 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001300 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001301 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001302 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001303 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001304 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001305 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001306 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001307 };
1308
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001309 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001310
1311 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1312 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1313 {
1314 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1315 {
1316 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1317 "for BFloat16 input.");
1318 }
1319 }
1320 else
1321 {
1322 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1323 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001324}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001325
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001326void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1327{
1328 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1329
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001330 uint32_t numInputs = 2;
1331 if (m_Parameters.m_BiasEnabled)
1332 {
1333 numInputs = 3;
1334 }
1335 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001336 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1337
1338 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1339 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1340
1341 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1342 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1343
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001344 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001345 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1346
1347 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1348
1349 Optional<TensorInfo> optionalBiasTensorInfo;
1350 if (m_Parameters.m_BiasEnabled)
1351 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001352 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001353 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1354
1355 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001356 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001357 }
1358
1359 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1360 {
1361 throw InvalidArgumentException(
1362 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1363 "cannot be either negative or 0.",
1364 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1365 }
1366
1367 ValidatePerAxisQuantization(inputTensorInfo,
1368 outputTensorInfo,
1369 weightTensorInfo,
1370 optionalBiasTensorInfo,
1371 descriptorName);
1372
1373 std::vector<DataType> supportedTypes =
1374 {
1375 DataType::BFloat16,
1376 DataType::Float16,
1377 DataType::Float32,
1378 DataType::QAsymmS8,
1379 DataType::QAsymmU8,
1380 DataType::QSymmS16,
1381 DataType::QSymmS8
1382 };
1383
1384 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1385 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1386}
1387
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001388void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1389{
1390 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1391
Cathal Corbett06902652022-04-14 17:55:11 +01001392 uint32_t numInputs = 2;
1393 if (m_Parameters.m_BiasEnabled)
1394 {
1395 numInputs = 3;
1396 }
1397
1398 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001399 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1400
1401 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1402 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1403
1404 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1405 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1406
Cathal Corbett06902652022-04-14 17:55:11 +01001407 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001408 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1409
1410 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1411 {
1412 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001413 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1414 "cannot be smaller than 1.",
1415 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001416 }
1417
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001418 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1419 {
1420 throw InvalidArgumentException(
1421 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1422 "cannot be either negative or 0.",
1423 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1424 }
1425
Jan Eilers53ef7952021-06-02 12:01:25 +01001426 if (weightTensorInfo.GetShape()[0] != 1)
1427 {
1428 throw InvalidArgumentException(fmt::format(
1429 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1430 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1431 descriptorName,
1432 weightTensorInfo.GetShape()[0],
1433 weightTensorInfo.GetShape()[1],
1434 weightTensorInfo.GetShape()[2],
1435 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001436 }
1437
Cathal Corbett4b19d222022-05-11 20:12:17 +01001438 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1439 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1440 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1441 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1442
1443 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1444 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1445 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1446
1447 if (!(validRefFormat || validAclFormat))
1448 {
1449 throw InvalidArgumentException(fmt::format(
1450 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1451 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1452 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1453 descriptorName,
1454 numOutputChannels,
1455 weightTensorInfo.GetShape()[0],
1456 weightTensorInfo.GetShape()[1],
1457 weightTensorInfo.GetShape()[2],
1458 weightTensorInfo.GetShape()[3]));
1459 }
1460
Teresa Charlind8df0262019-11-11 12:28:15 +00001461 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001462
Teresa Charlind8df0262019-11-11 12:28:15 +00001463 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001464 if (m_Parameters.m_BiasEnabled)
1465 {
Cathal Corbett06902652022-04-14 17:55:11 +01001466 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001467 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001468
Ryan OSheaf183acd2023-07-06 11:41:25 +01001469 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001470 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1471 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001472 ValidatePerAxisQuantization(inputTensorInfo,
1473 outputTensorInfo,
1474 weightTensorInfo,
1475 optionalBiasTensorInfo,
1476 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001477
1478 std::vector<DataType> supportedTypes =
1479 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001480 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001481 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001482 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001483 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001484 DataType::QAsymmU8,
1485 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001486 };
1487
1488 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1489 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001490}
1491
1492void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1493{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 const std::string descriptorName{"PermuteQueueDescriptor"};
1495
1496 ValidateNumInputs(workloadInfo, descriptorName, 1);
1497 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001498
1499 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1500
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001501 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1502 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001503
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1505 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001506
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001507 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001508 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001509 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001510 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001511 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1512 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1513 "must match dst dimension " + to_string(mapping[i]) +
1514 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001515 }
1516 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001517
1518 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001519}
1520
1521void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1522{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001523 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 ValidateNumInputs(workloadInfo, descriptorName, 1);
1526 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1527
1528 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1529 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1530
1531 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1532 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001533
1534 std::vector<DataType> supportedTypes =
1535 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001536 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001537 DataType::Float32,
1538 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001539 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001540 DataType::QAsymmU8,
1541 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001542 };
1543
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001544 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1545 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001546}
1547
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001548void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1549{
1550 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1551
1552 ValidateNumInputs(workloadInfo, descriptorName, 1);
1553 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1554
1555 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1556 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1557
1558 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1559 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1560
1561 std::vector<DataType> supportedTypes =
1562 {
1563 DataType::BFloat16,
1564 DataType::Float32,
1565 DataType::Float16,
1566 DataType::QAsymmS8,
1567 DataType::QAsymmU8,
1568 DataType::QSymmS16
1569 };
1570
1571 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1572 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1573}
1574
Teresa Charlin970f43b2019-07-01 13:51:07 +01001575void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1576{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001577 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001578
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001579 ValidateNumInputs(workloadInfo, descriptorName, 1);
1580 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1581
1582 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1583 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1584
1585 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1586 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001587
1588 std::vector<DataType> supportedTypes =
1589 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001590 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001591 DataType::Float16,
1592 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001593 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001594 DataType::QAsymmU8,
Teresa Charlince655882023-11-21 15:44:13 +00001595 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001596 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001597 };
1598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001599 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1600 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001601
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001602 // Resize only changes width and height: batch and channel count must match.
1603 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1604 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001605 if (inputBatchSize != outputBatchSize)
1606 {
1607 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001608 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1609 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001610 }
1611
1612 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001613 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1614 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001615 if (inputChannelCount != outputChannelCount)
1616 {
1617 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001618 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1619 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001620 }
1621}
1622
Teresa Charlin79a06a52023-07-13 17:16:45 +01001623void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
1624{
Tianle Cheng988354d2023-06-28 13:20:47 +01001625 const std::string descriptorName{"ReverseV2QueueDescriptor"};
1626
Tracy Narinebb8d7592023-07-13 16:50:54 +01001627 // Backend restriction
1628 const unsigned int maxDimensions = 4;
1629
1630 ValidateNumInputs(workloadInfo, descriptorName, 2);
Tianle Cheng988354d2023-06-28 13:20:47 +01001631 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1632
1633 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
Tracy Narinebb8d7592023-07-13 16:50:54 +01001634 const TensorInfo& axisTensorInfo = workloadInfo.m_InputTensorInfos[1];
Tianle Cheng988354d2023-06-28 13:20:47 +01001635 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1636
Tracy Narinebb8d7592023-07-13 16:50:54 +01001637 const auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
1638 if (inputTensorNumDimensions > maxDimensions)
Tianle Cheng988354d2023-06-28 13:20:47 +01001639 {
1640 throw InvalidArgumentException(descriptorName +
1641 ": Input tensors with rank greater than " +
Tracy Narinebb8d7592023-07-13 16:50:54 +01001642 std::to_string(maxDimensions) + " are not supported.");
1643 }
1644
1645 const auto axisTensorNumDimensions = axisTensorInfo.GetNumDimensions();
1646 if (axisTensorNumDimensions > maxDimensions)
1647 {
1648 throw InvalidArgumentException(descriptorName +
1649 ": More than " + std::to_string(maxDimensions) + " axes cannot be specified.");
1650 }
1651
1652 if (axisTensorNumDimensions > inputTensorNumDimensions)
1653 {
1654 throw InvalidArgumentException(descriptorName +
1655 ": More axes specified than the number of axes on the input tensor.");
Tianle Cheng988354d2023-06-28 13:20:47 +01001656 }
1657
1658 std::vector<DataType> supportedTypes =
1659 {
1660 DataType::BFloat16,
1661 DataType::Float16,
1662 DataType::Float32,
1663 DataType::QAsymmS8,
1664 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01001665 DataType::QSymmS8,
1666 DataType::QSymmS16,
1667 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01001668 };
1669
1670 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Tracy Narinebb8d7592023-07-13 16:50:54 +01001671
1672 std::vector<DataType> axisSupportedTypes =
1673 {
1674 DataType::Signed32,
1675 };
1676
1677 ValidateDataTypes(axisTensorInfo, axisSupportedTypes, descriptorName);
1678
Tianle Cheng988354d2023-06-28 13:20:47 +01001679 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1680 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Tianle Cheng988354d2023-06-28 13:20:47 +01001681}
1682
telsoa014fcda012018-03-09 14:13:49 +00001683void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1684{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001685 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001686
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001687 ValidateNumInputs(workloadInfo, descriptorName, 1);
1688 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1689
1690 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1691 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1692
1693 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1694 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1695
1696 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1697
telsoa014fcda012018-03-09 14:13:49 +00001698 if (m_Parameters.m_Min > m_Parameters.m_Max)
1699 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001701 }
telsoa014fcda012018-03-09 14:13:49 +00001702}
1703
Kevin Mayce5045a2019-10-02 14:07:47 +01001704void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1705{
1706 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1707
1708 ValidateNumInputs(workloadInfo, descriptorName, 1);
1709 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1710
1711 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1712 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1713
1714 if (inputTensorInfo.GetNumDimensions() > 4)
1715 {
1716 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1717 }
1718
1719 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1720
1721 // Check the supported data types
1722 std::vector<DataType> supportedTypes =
1723 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001724 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001725 DataType::Float32,
1726 DataType::Float16
1727 };
1728
1729 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001730 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001731}
1732
telsoa014fcda012018-03-09 14:13:49 +00001733void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1734{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001735 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001736
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001737 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001738 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1739
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001740 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1741 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1742
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001743 if (inputTensorInfo.GetNumDimensions() > 4)
1744 {
1745 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1746 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001747
1748 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001749
1750 // Check the supported data types
1751 std::vector<DataType> supportedTypes =
1752 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001753 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001754 DataType::Float32,
1755 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001756 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001757 DataType::QAsymmU8,
1758 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001759 };
1760
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001761 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001762 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1763}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001764
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001765void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1766{
1767 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1768
1769 ValidateNumInputs(workloadInfo, descriptorName, 1);
1770 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1771
1772 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1773 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1774
1775 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1776
1777 std::vector<DataType> supportedTypes =
1778 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001779 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001780 DataType::Float32,
1781 DataType::Float16,
1782 };
1783
1784 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001785 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001786}
1787
1788void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1789{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001790 const std::string descriptorName{"ConstantQueueDescriptor"};
1791
1792 ValidateNumInputs(workloadInfo, descriptorName, 0);
1793 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001794
1795 if (!m_LayerOutput)
1796 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001797 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001798 }
1799
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1801 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001802
1803 // Check the supported data types
1804 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001805 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001806 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001807 DataType::Float32,
1808 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001809 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001810 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001811 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001812 DataType::QSymmS16,
1813 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001814 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001815
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001816 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001817}
1818
1819void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1820{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001821 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001822
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001823 ValidateNumInputs(workloadInfo, descriptorName, 1);
1824 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1825
1826 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1827 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1828
1829 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001830
1831 // Check the supported data types
1832 std::vector<DataType> supportedTypes =
1833 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001834 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001835 DataType::Float32,
1836 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001837 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001838 DataType::QAsymmU8,
1839 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001840 DataType::Signed32,
1841 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001842 };
1843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1845 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001846}
1847
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001848void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1849{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001850 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001851
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001852 ValidateNumInputs(workloadInfo, descriptorName, 1);
1853 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1854
1855 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1856 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1857
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001858 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1859 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1861 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001862 }
1863
Teresa Charlinf77cab52023-06-01 16:15:13 +01001864 if (m_Parameters.m_BlockShape.size() == 2)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001865 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001866 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1867 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1868 }
1869 else if (m_Parameters.m_BlockShape.size() == 1)
1870 {
1871 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
1872 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
1873 }
1874 else
1875 {
1876 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001877 }
1878
Teresa Charlinf77cab52023-06-01 16:15:13 +01001879 // Check input + padding and output have the same number of elements
1880 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1881 const unsigned int inputHeight = inputTensorInfo.GetShape()[dimensionIndices.GetHeightIndex()] +
1882 m_Parameters.m_PadList[0].first + m_Parameters.m_PadList[0].second;
1883 const unsigned int inputWidth = (inputTensorInfo.GetNumDimensions() == 3) ? 1 :
1884 inputTensorInfo.GetShape()[dimensionIndices.GetWidthIndex()] +
1885 m_Parameters.m_PadList[1].first + m_Parameters.m_PadList[1].second;
1886
1887 const int channelsIndex_int = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : -1;
1888 const unsigned int channelsIndex = channelsIndex_int < 0 ?
1889 static_cast<unsigned int>(channelsIndex_int) + inputTensorInfo.GetNumDimensions()
1890 : static_cast<unsigned int>(channelsIndex_int);
1891
1892 const unsigned int numInputElements = inputTensorInfo.GetShape()[0] *
1893 inputHeight *
1894 inputWidth *
1895 inputTensorInfo.GetShape()[channelsIndex];
1896
1897 if (outputTensorInfo.GetNumElements() != numInputElements)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001898 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001899 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1900 to_string(numInputElements) + " after padding but output tensor has " +
1901 to_string(outputTensorInfo.GetNumElements()) + " elements.");
1902 }
1903
1904 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
1905 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
1906 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
1907 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
1908 {
1909 unsigned int spatialDimension = firstSpatialDimension + i;
1910 auto inputSize = inputTensorInfo.GetShape()[spatialDimension] +
1911 m_Parameters.m_PadList[i].first +
1912 m_Parameters.m_PadList[i].second;
1913 if (inputSize % m_Parameters.m_BlockShape[i] != 0)
1914 {
1915 throw InvalidArgumentException(descriptorName + ": Input dimension size after padding must be "
1916 "divisible by Block Shape in dimension: " + to_string(spatialDimension) + ".");
1917 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001918 }
nikraj01120522a2019-05-31 11:33:07 +01001919
1920 std::vector<DataType> supportedTypes =
1921 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001922 DataType::BFloat16,
1923 DataType::Float16,
1924 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001925 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001926 DataType::QAsymmU8,
1927 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001928 };
1929
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001930 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1931 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001932}
1933
Keith Davisa57eccb2019-06-14 17:33:22 +01001934void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1935{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001936 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001937
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001938 ValidateNumInputs(workloadInfo, descriptorName, 1);
1939 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001940
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001941 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1942 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1943
1944 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1945 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001946
1947 std::vector<DataType> supportedTypes =
1948 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001949 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001950 DataType::Float32,
1951 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001952 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001953 DataType::QAsymmU8,
1954 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001955 };
1956
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001957 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1958 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001959
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001960 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1961
1962 if (m_Parameters.m_BlockSize == 0)
1963 {
1964 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1965 }
1966
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001967 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1968 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1969 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1970 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001972 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001973 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001974 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001975 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1976 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001977 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001978
1979 const TensorShape& outputShape = outputTensorInfo.GetShape();
1980 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1981 {
1982 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1983 "must be divisible by the square of block size." );
1984 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001985}
1986
telsoa014fcda012018-03-09 14:13:49 +00001987void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1988{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001989 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001990
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001991 ValidateNumInputs(workloadInfo, descriptorName, 1);
1992 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1993
1994 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1995 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001996
1997 std::vector<DataType> supportedTypes =
1998 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001999 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002000 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002001 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01002002 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01002003 };
2004
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002005 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01002006 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2007 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2008 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00002009}
2010
telsoa01c577f2c2018-08-31 09:22:23 +01002011void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2012{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002013 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
2014
2015 const std::string descriptorName{"LstmQueueDescriptor"};
2016
2017 // check dimensions of all inputs and outputs
2018 if (workloadInfo.m_InputTensorInfos.size() != 3)
2019 {
2020 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
2021 }
2022 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2023 {
2024 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
2025 }
2026
2027 std::vector<DataType> supportedTypes =
2028 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002029 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01002030 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002031 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002032 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002033 };
2034
Jan Eilers38e05bd2019-06-26 13:10:09 +01002035 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002036 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
2037
Jan Eilers38e05bd2019-06-26 13:10:09 +01002038 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002039 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002040 {
2041 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2042 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002043 descriptorName,
2044 "input_0",
2045 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002046 }
2047 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002048 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002049 {
2050 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2051 workloadInfo.m_OutputTensorInfos[i],
2052 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002053 "input_0",
2054 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002055 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002056
janeil0117d8d852019-11-15 15:00:16 +00002057 // Making sure clipping parameters have valid values.
2058 // == 0 means no clipping
2059 // > 0 means clipping
2060 if (m_Parameters.m_ClippingThresCell < 0.0f)
2061 {
2062 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2063 }
2064 if (m_Parameters.m_ClippingThresProj < 0.0f)
2065 {
2066 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2067 }
2068
Jan Eilers38e05bd2019-06-26 13:10:09 +01002069 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002070 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2071 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2072 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2073 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2074 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2075 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2076
Jan Eilers38e05bd2019-06-26 13:10:09 +01002077 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002078 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2079 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002080 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002081 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2082 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002083 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002084 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2085 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002086 // scratchBufferTensor
2087 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002088 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2089 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002090 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002091 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2092 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002093 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002094 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2095 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002096 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002097 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2098 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002099
Jan Eilers38e05bd2019-06-26 13:10:09 +01002100 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2101 if ( m_InputToInputWeights )
2102 {
2103 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2104 (n_cell * n_input), "InputLayerNormWeights");
2105 }
2106
2107 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2108 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2109 (n_cell * n_input), "InputToForgetWeights");
2110
2111 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2112 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2113 (n_cell * n_input), "InputToCellWeights");
2114
2115 if ( m_RecurrentToInputWeights )
2116 {
2117 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2118 (n_cell * n_output), "RecurrentToInputWeights");
2119 }
2120
2121 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2122 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2123 (n_cell * n_output), "RecurrentToForgetWeights");
2124
2125 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2126 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2127 (n_cell * n_output), "RecurrentToCellWeights");
2128
2129 // Make sure the input-gate's parameters are either both present (regular
2130 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2131 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2132 !m_Parameters.m_CifgEnabled) ||
2133 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2134 m_Parameters.m_CifgEnabled));
2135 if (!cifg_weights_all_or_none)
2136 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002137 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2138 "RecurrentToInputWeights must either both be present (regular LSTM) "
2139 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2140 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002141 }
2142
2143 if ( m_CellToInputWeights )
2144 {
2145 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2146 n_cell, "CellToInputWeights");
2147 }
2148 if ( m_CellToForgetWeights )
2149 {
2150 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2151 n_cell, "CellToForgetWeights");
2152 }
2153 if ( m_CellToOutputWeights )
2154 {
2155 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2156 n_cell, "CellToOutputWeights");
2157 }
2158
2159 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2160 bool peephole_weights_all_or_none =
2161 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2162 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2163 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2164 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2165 if (!peephole_weights_all_or_none)
2166 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002167 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002168 }
2169
2170 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2171 if (m_Parameters.m_CifgEnabled)
2172 {
2173 if (m_InputGateBias)
2174 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002175 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002176 }
2177 }
2178 else
2179 {
2180 if (!m_InputGateBias)
2181 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002182 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2183 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002184 }
2185 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2186 n_cell, "InputGateBias");
2187 }
2188
2189 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2190 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2191
2192 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2193 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2194
2195 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2196 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2197
2198 if (m_ProjectionWeights)
2199 {
2200 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2201 (n_cell * n_output), "ProjectionWeights");
2202 }
2203 if (m_ProjectionBias)
2204 {
2205 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2206 }
2207
2208 // Making sure the projection tensors are consistent:
2209 // 1) If projection weight is not present, then projection bias should not be
2210 // present.
2211 // 2) If projection weight is present, then projection bias is optional.
2212 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2213 !m_Parameters.m_ProjectionEnabled)
2214 || (m_ProjectionWeights && !m_ProjectionBias &&
2215 m_Parameters.m_ProjectionEnabled)
2216 || (m_ProjectionWeights && m_ProjectionBias &&
2217 m_Parameters.m_ProjectionEnabled));
2218 if (!projecton_tensors_consistent)
2219 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002220 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002221 }
2222
2223 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2224 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2225 // either all have values or none of them have values. Layer normalization is used when the values of all the
2226 // layer normalization weights are present
2227 if (m_InputLayerNormWeights)
2228 {
2229 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2230 }
2231 if (m_ForgetLayerNormWeights)
2232 {
2233 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2234 }
2235 if (m_CellLayerNormWeights)
2236 {
2237 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2238 }
2239 if (m_OutputLayerNormWeights)
2240 {
2241 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2242 }
2243
Jan Eilers38e05bd2019-06-26 13:10:09 +01002244 if (m_Parameters.m_LayerNormEnabled)
2245 {
2246 if (!m_Parameters.m_CifgEnabled)
2247 {
2248 if (!m_InputLayerNormWeights)
2249 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002250 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2251 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002252 }
2253 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2254 1, n_cell, "InputLayerNormWeights");
2255 }
2256 else if (m_InputLayerNormWeights)
2257 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002258 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2259 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002260 }
2261
2262 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2263 "ForgetLayerNormWeights");
2264 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2265
2266 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2267 "OutputLayerNormWeights");
2268 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2269
2270 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2271 "CellLayerNormWeights");
2272 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2273 }
2274 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2275 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002276 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2277 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002278 }
telsoa01c577f2c2018-08-31 09:22:23 +01002279}
2280
2281void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2282{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002283 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002284
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002285 ValidateNumInputs(workloadInfo, descriptorName, 1);
2286 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2287
2288 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2289 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2290
2291 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002292 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002293 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002294 }
2295
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002296 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002297 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002299 }
2300
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002301 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002302}
2303
2304void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2305{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002307
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002308 ValidateNumInputs(workloadInfo, descriptorName, 1);
2309 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2310
2311 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2312 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2313
2314 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002315 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002316 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002317 }
2318
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002319 if (outputTensorInfo.GetDataType() != DataType::Float32)
2320 {
2321 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2322 }
2323
2324 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002325}
2326
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002327void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2328{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002329 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002330
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002331 ValidateNumInputs(workloadInfo, descriptorName, 2);
2332 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2333
2334 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2335 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2336 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2337
2338 std::vector<DataType> supportedTypes =
2339 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002340 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002341 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002342 DataType::Float32,
2343 DataType::QAsymmS8,
2344 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002345 DataType::QSymmS16,
2346 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002347 };
2348
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002349 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2350 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2351 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002352
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002353 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2354 inputTensorInfo1,
2355 outputTensorInfo,
2356 descriptorName,
2357 "input_0",
2358 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002359}
2360
David Beckc2044fe2018-09-05 15:00:38 +01002361void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2362{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002363 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002364
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002365 ValidateNumInputs(workloadInfo, descriptorName, 2);
2366 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2367
2368 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2369 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2370 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2371
2372 std::vector<DataType> supportedTypes =
2373 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002374 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002375 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002376 DataType::Float32,
2377 DataType::QAsymmS8,
2378 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002379 DataType::QSymmS16,
2380 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002381 };
2382
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002383 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2384 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2385 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002386
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002387 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2388 inputTensorInfo1,
2389 outputTensorInfo,
2390 descriptorName,
2391 "input_0",
2392 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002393}
2394
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002395void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2396{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002397 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002398
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002399 ValidateNumInputs(workloadInfo, descriptorName, 2);
2400 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2401
2402 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2403 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2404 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2405
2406 std::vector<DataType> supportedTypes =
2407 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002408 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002409 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002410 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002411 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002412 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002413 DataType::QSymmS16,
2414 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002415 };
2416
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2418 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2419 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002420
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002421 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2422 inputTensorInfo1,
2423 outputTensorInfo,
2424 descriptorName,
2425 "input_0",
2426 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002427}
2428
narpra01a6bf9122018-09-10 09:50:09 +01002429void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2430{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002432
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002433 ValidateNumInputs(workloadInfo, descriptorName, 1);
2434 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2435
2436 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2437 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002438
2439 std::vector<DataType> supportedTypes =
2440 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002441 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002442 DataType::Float32,
2443 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002444 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002445 DataType::QAsymmU8,
2446 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002447 };
narpra01eb061912018-09-10 17:35:27 +01002448
James Conroy4d1ff582019-06-10 17:06:39 +01002449 // First check if input tensor data type is supported, then
2450 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002451 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2452 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002453
narpra0132b90462018-09-13 11:07:48 +01002454 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002455 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002456 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002457 }
narpra0132b90462018-09-13 11:07:48 +01002458 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002459 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002460 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002461 }
2462 else
2463 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002464 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002465 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002466 ValidateTensorNumDimensions(outputTensorInfo,
2467 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002468 outputDim > 0 ? outputDim : 1,
2469 "output");
2470 }
narpra01a6bf9122018-09-10 09:50:09 +01002471}
2472
jimfly012c9322a2018-09-19 10:59:49 +01002473void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2474{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002475 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002476
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002477 ValidateNumInputs(workloadInfo, descriptorName, 1);
2478 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2479
2480 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2481 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002482
jimfly012c9322a2018-09-19 10:59:49 +01002483 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002484 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2485
jimfly012c9322a2018-09-19 10:59:49 +01002486 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002487 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2488 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2489 "as there are dimensions in the input tensor that is " +
2490 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2491 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002492 }
2493}
2494
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002495void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2496{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002497 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002498
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002499 ValidateNumInputs(workloadInfo, descriptorName, 1);
2500 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002501
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002502 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2503 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2504
Sadik Armagan2208b602019-07-31 16:36:27 +01002505 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002506 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002507 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002508 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002509 DataType::Float16,
2510 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002511 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002512 DataType::QAsymmU8,
2513 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002514 };
2515
2516 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002517
Keith Davis0c2eeac2020-02-11 16:51:50 +00002518 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002519 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002520 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002521 }
2522}
2523
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002524void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2525{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002526 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002527
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002528 ValidateNumInputs(workloadInfo, descriptorName, 1);
2529 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002530
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002531 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2532 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002533
Teresa Charlinf77cab52023-06-01 16:15:13 +01002534 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_Crops.size())
2535 {
2536 throw InvalidArgumentException(descriptorName + ": Crops must contain the same number of "
2537 "dimensions as Block Shape.");
2538 }
2539
2540 if (m_Parameters.m_BlockShape.size() == 2)
2541 {
2542 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2543 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2544 }
2545 else if (m_Parameters.m_BlockShape.size() == 1)
2546 {
2547 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
2548 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
2549 }
2550 else
2551 {
2552 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
2553 }
2554
2555 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
2556 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
2557 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
2558 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
2559 {
2560 unsigned int spatialDimension = firstSpatialDimension + i;
2561 unsigned int cropSize = m_Parameters.m_Crops[i].first + m_Parameters.m_Crops[i].second;
2562 unsigned int outputSize = inputTensorInfo.GetShape()[spatialDimension] * m_Parameters.m_BlockShape[i];
2563 if (cropSize > outputSize)
2564 {
2565 throw InvalidArgumentException(descriptorName + ": CropSize must be less than or equal to the uncropped"
2566 "outputSize in dimension: " + to_string(spatialDimension) + ".");
2567 }
2568 }
2569
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002570 std::vector<DataType> supportedTypes =
2571 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002572 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002573 DataType::Float32,
2574 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002575 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002576 DataType::QAsymmU8,
2577 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002578 };
2579
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2581 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002582}
2583
Conor Kennedy430b5d82018-11-14 15:28:28 +00002584void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2585{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002586 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002587
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002588 ValidateNumInputs(workloadInfo, descriptorName, 1);
2589 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2590
2591 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2592 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002593
2594 std::vector<DataType> supportedTypes =
2595 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002596 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002597 DataType::Float16,
2598 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002599 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002600 DataType::QAsymmU8,
2601 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002602 };
2603
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002604 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2605 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002606
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002607 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002609 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002610 if (rank > 4)
2611 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002612 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002613 }
2614
Conor Kennedy430b5d82018-11-14 15:28:28 +00002615 // Begin, End & Stride length must be of rank(input0)
2616 if (m_Parameters.m_Begin.size() != rank)
2617 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002618 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002619 }
2620
2621 if (m_Parameters.m_End.size() != rank)
2622 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002623 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002624 }
2625
2626 if (m_Parameters.m_Stride.size() != rank)
2627 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002628 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002629 }
2630
2631 // Stride entries must be non-zero
2632 for (auto& stride : m_Parameters.m_Stride)
2633 {
2634 if (stride == 0)
2635 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002636 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002637 }
2638 }
2639}
2640
kevmay0190539692018-11-29 08:40:19 +00002641void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2642{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002643 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002644
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002645 ValidateNumInputs(workloadInfo, descriptorName, 2);
2646 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2647
2648 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2649 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2650 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2651
2652 std::vector<DataType> supportedTypes =
2653 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002654 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002655 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002656 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002657 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002658 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002659 DataType::QSymmS16,
2660 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002661 };
2662
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002663 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2664 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2665 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002666
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002667 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2668 inputTensorInfo1,
2669 outputTensorInfo,
2670 descriptorName,
2671 "input_0",
2672 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002673}
2674
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002675void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2676{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002677 const std::string descriptorName{"DebugQueueDescriptor"};
2678
2679 ValidateNumInputs(workloadInfo, descriptorName, 1);
2680 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002681}
2682
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002683void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2684{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002685 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002686
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002687 ValidateNumInputs(workloadInfo, descriptorName, 2);
2688 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002689
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002690 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2691 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2692 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2693
2694 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2695 inputTensorInfo1,
2696 outputTensorInfo,
2697 descriptorName,
2698 "input_0",
2699 "input_1");
2700
2701 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002702 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002703 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002704 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002705}
2706
FrancisMurtagh878f0232018-12-19 10:56:15 +00002707void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2708{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002709 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002710
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002711 ValidateNumInputs(workloadInfo, descriptorName, 2);
2712 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002713
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002714 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2715 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2716 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2717
2718 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2719 inputTensorInfo1,
2720 outputTensorInfo,
2721 descriptorName,
2722 "input_0",
2723 "input_1");
2724
2725 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002726 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002727 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002728 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002729}
2730
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002731void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2732{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002733 const std::string descriptorName{"RsqrtQueueDescriptor"};
2734
2735 ValidateNumInputs(workloadInfo, descriptorName, 1);
2736 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2737
2738 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2739 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2740
2741 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002742
2743 std::vector<DataType> supportedTypes =
2744 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002745 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002746 DataType::Float16,
2747 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002748 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002749 DataType::QAsymmU8,
2750 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002751 };
2752
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002753 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2754 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002755}
2756
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002757void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2758{
2759 const std::string descriptorName{"GatherNdQueueDescriptor"};
2760
2761 ValidateNumInputs(workloadInfo, descriptorName, 2);
2762 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2763
2764 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2765 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2766 {
2767 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2768 }
2769
2770 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2771 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2772
2773 std::vector<DataType> supportedTypes =
2774 {
2775 DataType::BFloat16,
2776 DataType::Float16,
2777 DataType::Float32,
2778 DataType::QAsymmS8,
2779 DataType::QAsymmU8,
2780 DataType::QSymmS16,
2781 DataType::Signed32,
2782 };
2783
2784 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2785
2786 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2787
2788 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2789 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2790}
2791
narpra01b89b05f2019-01-16 09:53:09 +00002792void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2793{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002794 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002795
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002796 ValidateNumInputs(workloadInfo, descriptorName, 2);
2797 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002798
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002799 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2800 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002801 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002802 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002803 }
2804
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002805 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2806 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2807
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002808 std::vector<DataType> supportedTypes =
2809 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002810 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002811 DataType::Float16,
2812 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002813 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002814 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002815 DataType::QSymmS16,
2816 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002817 };
2818
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002819 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002820
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002821 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002822
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002823 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2824 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002825}
2826
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002827void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2828{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002829 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2830
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002831 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002832
2833 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2834 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002835 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002836 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2837 }
2838
2839 if (m_Anchors == nullptr)
2840 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002841 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002842 }
2843
2844 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002845 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2846 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2847
2848 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002849 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002850 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2851 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002852
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002853 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2854 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2855 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002856
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002857 const std::vector<DataType> supportedInputTypes =
2858 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002859 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002860 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002861 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002862 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002863 DataType::QAsymmU8,
2864 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002865 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002866
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002867 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2868 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2869 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2870
2871 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2872 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2873 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2874 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2875
2876 // NOTE: Output is always Float32 regardless of input type
2877 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2878 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2879 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2880 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002881
2882 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2883 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002884 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002885 "must be positive and less than or equal to 1.");
2886 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002887
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002888 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2889 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002890 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002891 "should be equal to number of classes + 1.");
2892 }
2893}
2894
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002895void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2896{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002897 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002898
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002899 ValidateNumInputs(workloadInfo, descriptorName, 1);
2900 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2901
2902 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2903 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2904
Teresa Charlin07307f32022-05-15 14:07:05 +01002905 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002906 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002907 DataType::QAsymmS8,
2908 DataType::QAsymmU8,
2909 DataType::QSymmS8,
2910 DataType::QSymmS16,
2911 DataType::Float16
2912 };
2913 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002914
Teresa Charlin07307f32022-05-15 14:07:05 +01002915 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002916 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002917 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002918 DataType::Float32,
2919 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002920 };
2921
Teresa Charlin07307f32022-05-15 14:07:05 +01002922 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002923}
2924
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002925void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2926{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002927 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002928
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002929 ValidateNumInputs(workloadInfo, descriptorName, 2);
2930 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002932 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2933 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2934 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002935
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002936 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2937 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2938
2939 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2940 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002941}
2942
Keith Davis3ae3f972021-05-21 16:33:48 +01002943void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2944{
2945 const std::string& descriptorName{"ShapeQueueDescriptor"};
2946
2947 ValidateNumInputs(workloadInfo, descriptorName, 1);
2948 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2949
2950 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2951 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2952
2953 std::vector<DataType> supportedTypes =
2954 {
2955 DataType::BFloat16,
2956 DataType::Float16,
2957 DataType::Float32,
2958 DataType::QAsymmS8,
2959 DataType::QAsymmU8,
Keith Davis3ae3f972021-05-21 16:33:48 +01002960 DataType::QSymmS8,
2961 DataType::QSymmS16,
2962 DataType::Signed32
2963 };
2964
2965 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2966 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2967}
2968
Sadik Armaganeff363d2019-04-05 15:25:46 +01002969void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2970{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002971 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002972
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002973 ValidateNumInputs(workloadInfo, descriptorName, 2);
2974 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2975
2976 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2977 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2978
2979 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2980 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2981
2982 std::vector<DataType> supportedTypes =
2983 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002984 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002985 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002986 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002987 DataType::QAsymmU8,
2988 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002989 };
2990
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002991 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2992 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002994 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2995 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002996
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002997 ValidateTensorShapesMatch(inputTensorInfo0,
2998 outputTensorInfo0,
2999 descriptorName,
3000 "input_0",
3001 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003002
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003003 ValidateTensorShapesMatch(inputTensorInfo0,
3004 outputTensorInfo1,
3005 descriptorName,
3006 "input_0",
3007 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003008}
3009
Derek Lamberti901ea112019-12-10 22:07:09 +00003010void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00003011{
Teresa Charlin9145e382023-08-17 18:44:58 +01003012 // This is internally generated, so it should not need validation.
Matteo Martincigh49124022019-01-11 13:25:59 +00003013}
3014
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003015void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3016{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003017 const std::string& descriptorName{"PreluQueueDescriptor"};
3018
3019 ValidateNumInputs(workloadInfo, descriptorName, 2);
3020 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3021
3022 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3023 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
3024 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003025
3026 std::vector<DataType> supportedTypes
3027 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003028 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003029 DataType::Float16,
3030 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003031 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003032 DataType::QAsymmU8,
3033 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003034 };
3035
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003036 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3037 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003038
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003039 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003040
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003041 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
3042 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003043
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003044 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3045 alphaTensorInfo,
3046 outputTensorInfo,
3047 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003048 "input",
3049 "alpha");
3050}
3051
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003052void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3053{
3054 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3055
3056 ValidateNumInputs(workloadInfo, descriptorName, 1);
3057 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3058
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003059 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3060 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3061
3062 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3063 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003064
3065 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003066
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003067 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3068 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003069
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003070 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3071
3072 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003073 if (m_Parameters.m_BiasEnabled)
3074 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003075 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003076
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003077 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3078 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003079
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003080 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01003081 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003082 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003083
3084 ValidatePerAxisQuantization(inputTensorInfo,
3085 outputTensorInfo,
3086 weightTensorInfo,
3087 optionalBiasTensorInfo,
3088 descriptorName);
3089
3090 std::vector<DataType> supportedTypes =
3091 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003092 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003093 DataType::Float32,
3094 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003095 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003096 DataType::QAsymmU8,
3097 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003098 };
3099
3100 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3101 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003102}
3103
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003104void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3105{
3106 const std::string descriptorName{"TransposeQueueDescriptor"};
3107
3108 ValidateNumInputs(workloadInfo, descriptorName, 1);
3109 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3110
3111 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3112
3113 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3114 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3115
3116 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3117 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3118
3119 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3120 {
3121 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3122 {
3123 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3124 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3125 "must match dst dimension " + to_string(i) +
3126 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3127 }
3128 }
3129
3130 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3131}
3132
Simon Obute51f67772021-09-03 15:50:13 +01003133void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3134{
3135 const std::string descriptorName{"TransposeQueueDescriptor"};
3136
3137 ValidateNumInputs(workloadInfo, descriptorName, 1);
3138 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3139
3140 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3141 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3142
3143 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3144}
3145
James Conroy4f1f8992020-04-29 20:01:10 +01003146void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3147{
3148 const std::string descriptorName{"QLstmQueueDescriptor"};
3149
3150 // Validate number of inputs/outputs
3151 ValidateNumInputs(workloadInfo, descriptorName, 3);
3152 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3153
3154 // Input/output tensor info
3155 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3156 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3157 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3158
3159 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3160 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3161 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3162
3163 // Supported types for various tensors in QLSTM
3164 std::vector<DataType> inputOutputSupportedTypes =
3165 {
3166 DataType::QAsymmS8
3167 };
3168
3169 std::vector<DataType> cellStateSupportedTypes =
3170 {
3171 DataType::QSymmS16
3172 };
3173
3174 std::vector<DataType> weightsSupportedTypes =
3175 {
3176 DataType::QSymmS8
3177 };
3178
3179 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3180 {
3181 DataType::QSymmS16
3182 };
3183
3184 std::vector<DataType> biasSupportedTypes =
3185 {
3186 DataType::Signed32
3187 };
3188
3189 // Validate types of input/output tensors
3190 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3191 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3192 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3193
3194 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3195 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3196 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3197
3198 // Validate matching types of input/output tensors
3199 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3200 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3201 "outputStateIn", "outputStateOut");
3202 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3203
3204 // Infer number of batches, number of units, input size and output size from tensor dimensions
3205 const uint32_t numBatches = inputInfo.GetShape()[0];
3206 const uint32_t inputSize = inputInfo.GetShape()[1];
3207 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3208 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3209
3210 // Validate number of dimensions and number of elements for input/output tensors
3211 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3212 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3213 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3214
3215 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3216 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3217 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3218
3219 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3220 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3221 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3222 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3223
3224 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3225 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3226 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3227
3228 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3229 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3230 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3231
3232 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3233 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3234 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3235 " RecurrentToForgetWeights");
3236
3237 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3238 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3239 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3240
3241 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3242 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3243 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3244
3245 // Validate data types for MANDATORY weights tensors (all should match each other)
3246 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3247
3248 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3249 "inputToForgetWeights", "inputToCellWeights");
3250 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3251 "inputToForgetWeights", "inputToOutputWeights");
3252
3253 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3254 "inputToForgetWeights", "recurrentToForgeteights");
3255 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3256 "inputToForgetWeights", "recurrentToCellWeights");
3257 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3258 "inputToForgetWeights", "recurrentToOutputWeights");
3259
3260 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3261 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3262 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3263 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3264
3265 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3266 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3267 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3268
3269 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3270 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3271 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3272
3273 // Validate data types for MANDATORY bias tensors
3274 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3275
3276 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3277 "forgetGateBias", "cellBias");
3278 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3279 "forgetGateBias", "outputGateBias");
3280
3281 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3282 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3283 !m_Parameters.m_CifgEnabled) ||
3284 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3285 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3286
3287 if (!allCifgParamsPresentOrNot)
3288 {
3289 throw InvalidArgumentException(descriptorName +
3290 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3291 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3292 "set appropriately.");
3293 }
3294
3295 if (!m_Parameters.m_CifgEnabled)
3296 {
3297 // Validate number of dimensions and number of elements
3298 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3299 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3300
3301 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3302 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3303 " RecurrentToInputWeights");
3304
3305 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3306 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3307
3308 // Validate data types
3309 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3310 "inputToForgetWeights", "inputToInputWeights");
3311 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3312 "inputToForgetWeights", "recurrentToInputWeights");
3313 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3314 "forgetGateBias", "inputGateBias");
3315 }
3316
3317 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3318 bool allPeepholeWeightsPresentOrNot =
3319 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3320 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3321 || (!m_CellToInputWeights && !m_CellToForgetWeights
3322 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3323
3324 if (!allPeepholeWeightsPresentOrNot)
3325 {
3326 throw InvalidArgumentException(descriptorName +
3327 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3328 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3329 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3330 "appropriately.");
3331 }
3332
3333 if (m_Parameters.m_PeepholeEnabled)
3334 {
3335 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3336 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3337 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3338
3339 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3340 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3341 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3342 "cellToForgetWeight", "cellToOutputWeights");
3343
3344 if (!m_Parameters.m_CifgEnabled)
3345 {
3346 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3347 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3348 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3349 "cellToForgetWeights", "cellToInputWeights");
3350 }
3351 }
3352
3353 // Validate OPTIONAL params: Layer Norm Weights
3354 bool allLayerNormWeightsPresentOrNot =
3355 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3356 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3357 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3358 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3359
3360 if (!allLayerNormWeightsPresentOrNot)
3361 {
3362 throw InvalidArgumentException(descriptorName +
3363 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3364 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3365 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3366 "only be present when Layer Norm is enabled and CIFG is disabled. "
3367 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3368 }
3369
3370 if (m_Parameters.m_LayerNormEnabled)
3371 {
3372 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3373 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3374 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3375
3376 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3377 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3378 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3379 "forgetLayerNormWeights", "cellLayerNormWeights");
3380
3381 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3382 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3383 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3384 "forgetLayerNormWeights", "outputLayerNormWeights");
3385
3386 if (!m_Parameters.m_CifgEnabled)
3387 {
3388 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3389 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3390 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3391 "forgetLayerNormWeights", "inputLayerNormWeights");
3392 }
3393 }
3394
3395 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3396 bool correctProjectionTensorsPresent =
3397 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3398 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3399 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3400
3401 if (!correctProjectionTensorsPresent)
3402 {
3403 throw InvalidArgumentException(descriptorName +
3404 ": If projection is enabled, ProjectionWeights should be present and "
3405 "ProjectionBias is optional. If projection is disabled, neither "
3406 "ProjectionWeights nor ProjectionBias should be present.");
3407 }
3408
3409 if (m_Parameters.m_ProjectionEnabled)
3410 {
3411 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3412 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3413 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3414
3415 if (m_ProjectionBias)
3416 {
3417 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003418 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003419 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3420 }
3421
3422 }
3423 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3424 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3425 throw InvalidArgumentException(descriptorName +
3426 ": If projection is disabled, output quantization info (scale, offset) "
3427 "should match HiddenStateScale and HiddenStateZeroPoint.");
3428 }
3429
3430}
3431
James Conroy9c3cae82019-08-01 16:01:48 +01003432void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3433{
3434 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3435
3436 // Validate number of inputs/outputs
3437 ValidateNumInputs(workloadInfo, descriptorName, 3);
3438 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3439
3440 // Input/output tensor infos
3441 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3442 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3443 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3444
3445 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3446 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3447
3448 std::vector<DataType> inputOutputSupportedTypes =
3449 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003450 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003451 };
3452
3453 std::vector<DataType> cellStateSupportedTypes =
3454 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003455 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003456 };
3457
3458 std::vector<DataType> weightsSupportedTypes =
3459 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003460 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003461 };
3462
3463 std::vector<DataType> biasSupportedTypes =
3464 {
3465 DataType::Signed32
3466 };
3467
3468 // Validate types of input/output tensors
3469 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3470 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3471 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3472
3473 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3474 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3475
3476 // Validate matching types of input/output tensors
3477 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3478 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3479 "outputStateIn", "outputStateOut");
3480 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3481
3482 // Validate matching quantization info for input/output tensors
3483 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3484 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3485 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003486
James Conroy9c3cae82019-08-01 16:01:48 +01003487 // Infer number of batches, input size and output size from tensor dimensions
3488 const uint32_t numBatches = inputInfo.GetShape()[0];
3489 const uint32_t inputSize = inputInfo.GetShape()[1];
3490 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3491
3492 // Validate number of dimensions and number of elements for input/output tensors
3493 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3494 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3495 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3496 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3497 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3498
3499 // Validate number of dimensions and number of elements for weights tensors
3500 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3501 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3502 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3503
3504 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3505 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3506 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3507
3508 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3509 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3510 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3511
3512 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3513 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3514 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3515
3516 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3517 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3518 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3519
3520 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3521 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3522 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3523 " RecurrentToForgetWeights");
3524
3525 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3526 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3527 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3528
3529 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3530 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3531 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3532
3533 // Validate data types for weights tensors (all should match each other)
3534 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3535
3536 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3537 "inputToInputWeights", "inputToForgetWeights");
3538 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3539 "inputToInputWeights", "inputToCellWeights");
3540 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3541 "inputToInputWeights", "inputToOutputWeights");
3542
3543 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3544 "inputToInputWeights", "recurrentToInputWeights");
3545 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3546 "inputToInputWeights", "recurrentToForgeteights");
3547 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3548 "inputToInputWeights", "recurrentToCellWeights");
3549 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3550 "inputToInputWeights", "recurrentToOutputWeights");
3551
3552 // Validate matching quantization info for weight tensors (all should match each other)
3553 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3554 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3555 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3556 descriptorName, "inputToInputWeights", "inputToCellWeights");
3557 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3558 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3559
3560 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3561 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3562 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3563 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3564 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3565 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3566 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3567 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3568
3569 // Validate number of dimensions and number of elements in bias tensors
3570 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3571 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3572 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3573
3574 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3575 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3576 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3577
3578 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3579 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3580 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3581
3582 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3583 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3584 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3585
3586 // Validate data types for bias tensors (all should match each other)
3587 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3588
3589 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3590 "inputGateBias", "forgetGateBias");
3591 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3592 "inputGateBias", "cellBias");
3593 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3594 "inputGateBias", "outputGateBias");
3595
3596 // Validate bias tensor quantization info
Ryan OSheaf183acd2023-07-06 11:41:25 +01003597 ValidateBiasTensorQuantization(inputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3598 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3599 ValidateBiasTensorQuantization(cellBiasInfo, inputToInputWeightsInfo, descriptorName);
3600 ValidateBiasTensorQuantization(outputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
James Conroy9c3cae82019-08-01 16:01:48 +01003601}
3602
Kevin May868eb142019-09-04 17:29:31 +01003603void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3604{
3605 const std::string descriptorName{"AbsQueueDescriptor"};
3606
3607 ValidateNumInputs(workloadInfo, descriptorName, 1);
3608 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3609
3610 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3611 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3612
3613 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3614
3615 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003616 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003617 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003618 DataType::Float16,
3619 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003620 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003621 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003622 DataType::QSymmS16,
3623 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003624 };
Kevin May868eb142019-09-04 17:29:31 +01003625
3626 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3627 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3628}
3629
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003630void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3631{
3632 const std::string descriptorName{"SliceQueueDescriptor"};
3633
3634 ValidateNumInputs(workloadInfo, descriptorName, 1);
3635 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3636
3637 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3638 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3639
3640 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3641
3642 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3643 if (rank > 4)
3644 {
3645 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3646 }
3647
3648 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3649
3650 // Check if m_Begin and m_Size have the expected length
3651 if (m_Parameters.m_Begin.size() != rank)
3652 {
3653 throw InvalidArgumentException(descriptorName +
3654 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3655 }
3656 if (m_Parameters.m_Size.size() != rank)
3657 {
3658 throw InvalidArgumentException(descriptorName +
3659 ": Length of size descriptor must equal rank " + std::to_string(rank));
3660 }
3661
3662 // Check if the shape of the output tensor matches m_Size
3663 const TensorShape& outputShape = outputTensorInfo.GetShape();
3664 for (unsigned int i = 0u; i < rank; ++i)
3665 {
3666 if (m_Parameters.m_Size[i] != outputShape[i])
3667 {
3668 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3669 }
3670 }
3671
3672 // Check if the sum of begin offset and size in a given dimension
3673 // does not exceed the size of corresponding input
3674 const TensorShape& inputShape = inputTensorInfo.GetShape();
3675 for(unsigned int i = 0u; i < rank; ++i)
3676 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003677 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003678 {
3679 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3680 std::to_string(i) + " exceeds input size.");
3681 }
3682 }
3683}
3684
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003685void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3686{
3687 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3688
3689 ValidateNumInputs(workloadInfo, descriptorName, 1);
3690 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3691
3692 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3693 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3694
3695 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3696 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3697
3698 std::vector<DataType> supportedTypes =
3699 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003700 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003701 DataType::Float32,
3702 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003703 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003704 DataType::QAsymmU8,
3705 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003706 };
3707
3708 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3709 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3710
3711 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3712
3713 if (m_Parameters.m_BlockSize == 0)
3714 {
3715 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3716 }
3717
3718 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3719 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3720 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3721 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3722
3723 const TensorShape& outputShape = outputInfo.GetShape();
3724 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3725 {
3726 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3727 "must be divisible by block size.");
3728 }
3729
3730 const TensorShape& inputShape = inputInfo.GetShape();
3731 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3732 {
3733 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3734 "must be divisible by the square of block size." );
3735 }
3736}
3737
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003738void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3739{
3740 const std::string descriptorName{"ComparisonQueueDescriptor"};
3741
3742 ValidateNumInputs(workloadInfo, descriptorName, 2);
3743 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3744
3745 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3746 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3747 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3748
3749 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3750 inputTensorInfo1,
3751 outputTensorInfo,
3752 descriptorName,
3753 "input_0",
3754 "input_1");
3755
3756 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3757 {
3758 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3759 }
3760}
3761
Mike Kelly3ec30772023-03-08 13:47:17 +00003762void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3763{
3764 const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3765
3766 ValidateNumInputs(workloadInfo, descriptorName, 2);
3767 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3768
3769 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3770 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3771 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3772
3773 std::vector<DataType> supportedTypes =
3774 {
3775 DataType::BFloat16,
3776 DataType::Float16,
3777 DataType::Float32,
3778 DataType::QAsymmS8,
3779 DataType::QAsymmU8,
3780 DataType::QSymmS16,
3781 DataType::Signed32
3782 };
3783
3784 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3785 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3786
3787 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3788 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3789}
3790
josh minor4a3c6102020-01-06 16:40:46 -06003791void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3792{
3793 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3794
3795 ValidateNumInputs(workloadInfo, descriptorName, 1);
3796 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3797
3798 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3799 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3800
3801 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3802
3803 std::vector<DataType> supportedTypes =
3804 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003805 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003806 DataType::Float16,
3807 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003808 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003809 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003810 DataType::QSymmS16,
3811 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003812 };
3813
James Conroyaba90cd2020-11-06 16:28:18 +00003814 std::vector<DataType> logicalSupportedTypes =
3815 {
3816 DataType::Boolean
3817 };
3818
3819 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3820 {
3821 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3822 }
3823 else
3824 {
3825 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3826 }
3827
3828
josh minor4a3c6102020-01-06 16:40:46 -06003829 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3830}
3831
Finn Williams2605b232020-06-10 15:53:46 +01003832void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3833{
3834 const std::string descriptorName{"RankQueueDescriptor"};
3835
3836 ValidateNumInputs(workloadInfo, descriptorName, 1);
3837 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3838
3839 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3840 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3841
3842 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3843 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3844
3845 std::vector<DataType> supportedTypes =
3846 {
3847 DataType::BFloat16,
3848 DataType::Float16,
3849 DataType::Float32,
3850 DataType::QAsymmS8,
3851 DataType::QAsymmU8,
3852 DataType::QSymmS8,
3853 DataType::QSymmS16,
3854 DataType::Signed32
3855 };
3856
3857 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3858 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3859}
3860
James Conroyaba90cd2020-11-06 16:28:18 +00003861void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3862{
3863 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3864
3865 ValidateNumInputs(workloadInfo, descriptorName, 2);
3866 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3867
3868 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3869 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3870 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3871
3872 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3873 inputTensorInfo1,
3874 outputTensorInfo,
3875 descriptorName,
3876 "input_0",
3877 "input_1");
3878
3879 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3880 {
3881 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3882 }
3883
3884 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3885 {
3886 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3887 }
3888
3889 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3890 {
3891 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3892 }
3893}
3894
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003895void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3896{
3897 const std::string descriptorName{"ReduceQueueDescriptor"};
3898
3899 ValidateNumInputs(workloadInfo, descriptorName, 1);
3900 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3901
3902 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3903 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3904
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003905 std::vector<DataType> supportedTypes =
3906 {
3907 DataType::BFloat16,
3908 DataType::Float16,
3909 DataType::Float32,
3910 DataType::QAsymmS8,
3911 DataType::QAsymmU8,
3912 DataType::QSymmS16,
3913 DataType::Signed32
3914 };
3915
3916 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3917 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3918}
3919
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003920void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3921{
3922 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3923
3924 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3925
3926 // check dimensions of all inputs and outputs
3927 if (workloadInfo.m_InputTensorInfos.size() != 3)
3928 {
3929 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3930 }
Mike Kelly12994962022-04-21 11:57:09 +01003931 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003932 {
3933 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3934 }
3935
3936 std::vector<DataType> supportedTypes =
3937 {
Mike Kelly12994962022-04-21 11:57:09 +01003938 DataType::Float32,
3939 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003940 };
3941
3942 // check for supported type of one input and match them with all the other input and output
3943 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3944
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003945 // Making sure clipping parameters have valid values.
3946 // == 0 means no clipping
3947 // > 0 means clipping
3948 if (m_Parameters.m_ClippingThresCell < 0.0f)
3949 {
3950 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3951 }
3952 if (m_Parameters.m_ClippingThresProj < 0.0f)
3953 {
3954 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3955 }
3956
3957 unsigned int batchIndx = 0;
3958 unsigned int inputIndx = 1;
3959 uint32_t timeStep = 1;
3960 unsigned int timeIndx = 1;
3961 inputIndx = 2;
3962 if (m_Parameters.m_TimeMajor)
3963 {
3964 batchIndx = 1;
3965 timeIndx = 0;
3966
3967 }
3968 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3969
3970 // Inferring batch size, number of outputs and number of cells from the inputs.
3971 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3972 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3973 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3974 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3975 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3976 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3977
3978 // input tensor
3979 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3980 descriptorName + " input_0");
3981 // outputStateInTensor
3982 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3983 descriptorName + " input_1");
3984 // outputStateInTensor
3985 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3986 descriptorName + " input_2");
3987
3988 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003989 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003990 descriptorName + " output_0");
3991
3992 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3993 if ( m_InputToInputWeights )
3994 {
3995 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3996 (n_cell * n_input), "InputLayerNormWeights");
3997 }
3998
3999 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
4000 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
4001 (n_cell * n_input), "InputToForgetWeights");
4002
4003 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
4004 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
4005 (n_cell * n_input), "InputToCellWeights");
4006
4007 if ( m_RecurrentToInputWeights )
4008 {
4009 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
4010 (n_cell * n_output), "RecurrentToInputWeights");
4011 }
4012
4013 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
4014 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
4015 (n_cell * n_output), "RecurrentToForgetWeights");
4016
4017 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
4018 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
4019 (n_cell * n_output), "RecurrentToCellWeights");
4020
4021 // Make sure the input-gate's parameters are either both present (regular
4022 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
4023 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
4024 !m_Parameters.m_CifgEnabled) ||
4025 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
4026 m_Parameters.m_CifgEnabled));
4027 if (!cifg_weights_all_or_none)
4028 {
4029 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
4030 "RecurrentToInputWeights must either both be present (regular LSTM) "
4031 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
4032 "accordingly.");
4033 }
4034
4035 if ( m_CellToInputWeights )
4036 {
4037 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
4038 n_cell, "CellToInputWeights");
4039 }
4040 if ( m_CellToForgetWeights )
4041 {
4042 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
4043 n_cell, "CellToForgetWeights");
4044 }
4045 if ( m_CellToOutputWeights )
4046 {
4047 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
4048 n_cell, "CellToOutputWeights");
4049 }
4050
4051 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
4052 bool peephole_weights_all_or_none =
4053 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
4054 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
4055 || ( !m_CellToInputWeights && !m_CellToForgetWeights
4056 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
4057 if (!peephole_weights_all_or_none)
4058 {
4059 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
4060 }
4061
4062 // Make sure the input gate bias is present only when not a CIFG-LSTM.
4063 if (m_Parameters.m_CifgEnabled)
4064 {
4065 if (m_InputGateBias)
4066 {
4067 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
4068 }
4069 }
4070 else
4071 {
4072 if (!m_InputGateBias)
4073 {
4074 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4075 "must be present.");
4076 }
4077 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4078 n_cell, "InputGateBias");
4079 }
4080
4081 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4082 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4083
4084 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4085 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4086
4087 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4088 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4089
4090 if (m_ProjectionWeights)
4091 {
4092 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4093 (n_cell * n_output), "ProjectionWeights");
4094 }
4095 if (m_ProjectionBias)
4096 {
4097 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4098 }
4099
4100 // Making sure the projection tensors are consistent:
4101 // 1) If projection weight is not present, then projection bias should not be
4102 // present.
4103 // 2) If projection weight is present, then projection bias is optional.
4104 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4105 !m_Parameters.m_ProjectionEnabled)
4106 || (m_ProjectionWeights && !m_ProjectionBias &&
4107 m_Parameters.m_ProjectionEnabled)
4108 || (m_ProjectionWeights && m_ProjectionBias &&
4109 m_Parameters.m_ProjectionEnabled));
4110 if (!projecton_tensors_consistent)
4111 {
4112 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4113 }
4114
4115 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4116 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4117 // either all have values or none of them have values. Layer normalization is used when the values of all the
4118 // layer normalization weights are present
4119 if (m_InputLayerNormWeights)
4120 {
4121 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4122 }
4123 if (m_ForgetLayerNormWeights)
4124 {
4125 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4126 }
4127 if (m_CellLayerNormWeights)
4128 {
4129 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4130 }
4131 if (m_OutputLayerNormWeights)
4132 {
4133 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4134 }
4135
4136 if (m_Parameters.m_LayerNormEnabled)
4137 {
4138 if (!m_Parameters.m_CifgEnabled)
4139 {
4140 if (!m_InputLayerNormWeights)
4141 {
4142 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4143 "disabled but InputLayerNormWeights are not present");
4144 }
4145 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4146 1, n_cell, "InputLayerNormWeights");
4147 }
4148 else if (m_InputLayerNormWeights)
4149 {
4150 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4151 "enabled");
4152 }
4153
4154 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4155 "ForgetLayerNormWeights");
4156 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4157
4158 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4159 "OutputLayerNormWeights");
4160 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4161
4162 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4163 "CellLayerNormWeights");
4164 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4165 }
4166 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4167 {
4168 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4169 "normalisation weights are present.");
4170 }
4171}
4172
Samuel Yap6b478092022-07-06 15:36:03 +01004173void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4174{
4175 const std::string descriptorName{"BatchMatMulDescriptor"};
4176
4177 ValidateNumInputs(workloadInfo, descriptorName, 2);
4178 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4179
4180 // Inputs must be: both 2D+
4181 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4182 // axes N and I must be the same size
4183
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004184 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4185 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4186 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4187 // Output info has already been inferred
Samuel Yap6b478092022-07-06 15:36:03 +01004188
4189 std::vector<DataType> supportedTypes =
4190 {
4191 DataType::BFloat16,
4192 DataType::Float16,
4193 DataType::Float32,
4194 DataType::QAsymmS8,
4195 DataType::QAsymmU8,
4196 DataType::QSymmS16
4197 };
4198
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004199 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4200 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4201 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
Samuel Yap6b478092022-07-06 15:36:03 +01004202
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004203 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4204 (inputYInfoBeforeParams.GetNumDimensions() < 2))
Samuel Yap6b478092022-07-06 15:36:03 +01004205 {
4206 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4207 }
4208
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004209 TensorInfo inputXInfoAfterParams;
4210 TensorInfo inputYInfoAfterParams;
4211
4212 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4213 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
Samuel Yap6b478092022-07-06 15:36:03 +01004214 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004215 throw InvalidArgumentException(descriptorName +
4216 ": Invalid descriptor parameters - Transpose and Adjoint "
4217 "cannot both be true for a given input tensor.");
4218 }
4219 if(m_Parameters.m_TransposeX)
4220 {
4221 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4222 BatchMatMulDescriptor::GetPermuteVec(
4223 m_Parameters.m_DataLayoutX,
4224 inputXInfoBeforeParams.GetShape()));
4225 }
4226 else if(m_Parameters.m_AdjointX)
4227 {
4228 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4229 inputXInfoBeforeParams.GetShape());
4230 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4231 inputXInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004232 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004233 throw InvalidArgumentException(descriptorName +
4234 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004235 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004236 // Shape remains the same as it's square
4237 inputXInfoAfterParams = inputXInfoBeforeParams;
4238 }
4239 else
4240 {
4241 inputXInfoAfterParams = inputXInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004242 }
4243
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004244 if(m_Parameters.m_TransposeY)
Samuel Yap6b478092022-07-06 15:36:03 +01004245 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004246 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4247 BatchMatMulDescriptor::GetPermuteVec(
4248 m_Parameters.m_DataLayoutY,
4249 inputYInfoBeforeParams.GetShape()));
4250 }
4251 else if(m_Parameters.m_AdjointY)
4252 {
4253 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4254 inputYInfoBeforeParams.GetShape());
4255 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4256 inputYInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004257 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004258 throw InvalidArgumentException(descriptorName +
4259 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004260 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004261 // Shape remains the same as it's square
4262 inputYInfoAfterParams = inputYInfoBeforeParams;
4263 }
4264 else
4265 {
4266 inputYInfoAfterParams = inputYInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004267 }
4268
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004269 switch(m_Parameters.m_DataLayoutX)
4270 {
4271 case DataLayout::NCDHW:
4272 case DataLayout::NDHWC:
4273 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4274 {
4275 throw InvalidArgumentException(descriptorName +
4276 ": Input tensor X does not have the correct "
4277 "number of dimensions for the Data Layout that it has been assigned.");
4278 }
4279 break;
4280 case DataLayout::NCHW:
4281 case DataLayout::NHWC:
4282 default:
4283 break;
4284 }
Samuel Yap6b478092022-07-06 15:36:03 +01004285
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004286 switch(m_Parameters.m_DataLayoutY)
4287 {
4288 case DataLayout::NCDHW:
4289 case DataLayout::NDHWC:
4290 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4291 {
4292 throw InvalidArgumentException(descriptorName +
4293 ": Input tensor Y does not have the correct "
4294 "number of dimensions for the Data Layout that it has been assigned.");
4295 }
4296 break;
4297 case DataLayout::NCHW:
4298 case DataLayout::NHWC:
4299 default:
4300 break;
4301 }
4302
4303 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4304 inputXInfoAfterParams.GetShape());
4305 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4306 inputXInfoBeforeParams.GetShape());
4307
4308 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4309 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
Samuel Yap6b478092022-07-06 15:36:03 +01004310 {
4311 throw InvalidArgumentException(descriptorName +
4312 ": The final axis of input tensor X must be the same size as "
4313 "the second last axis of input tensor Y.");
4314 }
4315
Samuel Yap6b478092022-07-06 15:36:03 +01004316 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4317 // e.g. NHWC isnt compatible with NCHW as of now
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004318 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4319 DataLayout yLayout = m_Parameters.m_DataLayoutY;
Samuel Yap6b478092022-07-06 15:36:03 +01004320
4321 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4322 {
4323 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4324 {
4325 throw InvalidArgumentException(descriptorName +
4326 ": Invalid input tensor data layout combination.");
4327 }
4328 }
4329 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4330 {
4331 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4332 {
4333 throw InvalidArgumentException(descriptorName +
4334 ": Invalid input tensor data layout combination.");
4335 }
4336 }
4337 }
4338
4339 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004340 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4341 inputYInfoAfterParams.GetNumDimensions());
Samuel Yap6b478092022-07-06 15:36:03 +01004342 if(outputTensorDimSize-2 > 0)
4343 {
4344 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4345 DataType::Float32);
4346 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4347 DataType::Float32);
4348 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4349 DataType::Float32);
4350
4351 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4352 {
4353 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4354
4355 for(unsigned int i = 0; i < sizeDiff; i++)
4356 {
4357 axisIndices.insert(axisIndices.begin(), 1);
4358 }
4359
4360 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4361 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004362 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
Samuel Yap6b478092022-07-06 15:36:03 +01004363 }
4364 };
4365
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004366 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4367 inputXInfoAfterParams.GetShape());
4368 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4369 inputYInfoAfterParams.GetShape());
4370
4371 doAxisExtension(axesXNotMul, tiXNotMul);
4372 doAxisExtension(axesYNotMul, tiYNotMul);
Samuel Yap6b478092022-07-06 15:36:03 +01004373
4374 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4375 {
4376 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4377 tiYNotMul.GetShape()[i]);
4378 }
4379
4380 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4381 tiYNotMul,
4382 tiOutNotMul,
4383 descriptorName,
4384 "input_X",
4385 "input_Y");
4386 }
Samuel Yap6b478092022-07-06 15:36:03 +01004387}
4388
Teresa Charlin79a06a52023-07-13 17:16:45 +01004389void TileQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4390{
4391 const std::string& descriptorName{"TileQueueDescriptor"};
4392
4393 ValidateNumInputs(workloadInfo, descriptorName, 1);
4394 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4395
4396 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
4397 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4398
4399 std::vector<DataType> supportedTypes =
4400 {
4401 DataType::Float32,
4402 DataType::Float16,
4403 DataType::QAsymmS8,
4404 DataType::QAsymmU8,
4405 DataType::QSymmS8,
4406 DataType::QSymmS16,
4407 DataType::Signed32
4408 };
4409
4410 // Multiples length must be the same as the number of dimensions in input.
4411 if (m_Parameters.m_Multiples.size() != inputTensorInfo.GetNumDimensions())
4412 {
4413 throw InvalidArgumentException(descriptorName +
4414 ": Multiples length is not same as the number of dimensions in Input.");
4415 }
4416
4417 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
4418 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4419}
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01004420
Idriss Chaouch98e383e2023-08-28 14:28:31 +01004421void BroadcastToQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4422{
4423 const std::string& descriptorName{"BroadcastToQueueDescriptor"};
4424
4425 ValidateNumInputs(workloadInfo, descriptorName, 1);
4426 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4427
4428 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
4429 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4430
4431 std::vector<DataType> supportedTypes =
4432 {
4433 DataType::Float32,
4434 DataType::Float16,
4435 DataType::QAsymmS8,
4436 DataType::QAsymmU8,
4437 DataType::QSymmS8,
4438 DataType::QSymmS16,
4439 DataType::Signed32,
4440 DataType::Signed64
4441 };
4442
4443 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
4444 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4445}
4446
mathad01df9a3222021-04-28 11:42:57 +01004447} // namespace armnn