blob: aa6bb848e58c5f4a3bb54ced7312d9ebb47badc5 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010011#include <armnnUtils/Permute.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010013#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000014
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000017#include <string>
18#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000019
James Ward47fce872020-09-10 11:57:28 +010020#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000021
Matteo Martincigh21350152018-11-28 16:22:22 +000022using namespace armnnUtils;
23
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
27//---------------------------------------------------------------
28DataType GetBiasDataType(DataType inputDataType)
29{
30 switch (inputDataType)
31 {
telsoa01c577f2c2018-08-31 09:22:23 +010032 case DataType::Float16:
33 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000034 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000035 case DataType::Float32:
36 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000037 case DataType::QAsymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000038 case DataType::QAsymmU8:
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// Local replacement for std::to_string, which the Android NDK does not support.
// Formats value with the stream insertion operator, so output follows default stream formatting
// (e.g. floating-point values use the stream's default precision).
template <typename T>
std::string to_string(T value)
{
    std::stringstream buffer;
    buffer << value;
    std::string text = buffer.str();
    return text;
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000108
109//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110void ValidateTensorNumElements(const TensorInfo& tensor,
111 std::string const& descName,
112 unsigned int numElements,
113 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100114{
115 if (tensor.GetNumElements() != numElements)
116 {
117 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100118 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100119 tensorName + " tensor.");
120 }
121}
122
123//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000124void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
125 const std::string& descName, std::string const& tensorName)
126{
127 if (tensor.GetDataType() != dataType)
128 {
129 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
130 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
131 }
132}
133
Derek Lambertid466a542020-01-22 15:37:29 +0000134void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
135{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100136 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000137 {
138 throw InvalidArgumentException(descName +
139 ": Expected data type which supports per-axis quantization scheme but got " +
140 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
141 }
Derek Lambertid466a542020-01-22 15:37:29 +0000142}
143
telsoa014fcda012018-03-09 14:13:49 +0000144//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100145void ValidateTensorQuantizationSpace(const TensorInfo& first,
146 const TensorInfo& second,
147 const std::string& descName,
148 std::string const& firstName,
149 std::string const& secondName)
150{
151 if (!first.IsQuantized() ||
152 !second.IsQuantized())
153 {
154 // Not a quantized type, ignore the validation
155 return;
156 }
157
158 DataType firstDataType = first.GetDataType();
159 DataType secondDataType = second.GetDataType();
160
161 if (firstDataType != secondDataType)
162 {
163 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
164 " must be of the same quantized type, " +
165 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
166 secondName + " is " + GetDataTypeName(secondDataType));
167 }
168
169 if (!first.IsTypeSpaceMatch(second))
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must have the same quantization space, " +
173 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
174 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
175 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
176 " and scale " + to_string(second.GetQuantizationScale()));
177 }
178}
179
180//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100181void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100182 const TensorInfo& weightsTensorInfo,
183 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000184{
185 if (biasTensor.GetQuantizationOffset() != 0)
186 {
187 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
188 to_string(biasTensor.GetQuantizationOffset()));
189 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000190
James Conroy8502ade2020-11-12 19:26:29 +0000191 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000192 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000193 // Validate per-axis quantization scales
194 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
195 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
196
197 if (weightScales.size() != biasScales.size())
198 {
199 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000200 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
201 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
202 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000203 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
204 }
telsoa014fcda012018-03-09 14:13:49 +0000205 }
206}
207
208//---------------------------------------------------------------
209void ValidateTensors(const std::vector<ITensorHandle*>& vec,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100210 unsigned int numExpected,
211 const std::string& descName,
212 const std::string& varName)
telsoa014fcda012018-03-09 14:13:49 +0000213{
214 if (vec.empty() && numExpected > 0)
215 {
216 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
217 }
218
219 for (unsigned int i = 0; i < numExpected; ++i)
220 {
221 if (!vec[i])
222 {
223 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
224 }
225 }
226}
227
228//---------------------------------------------------------------
229void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
230 const TensorInfo& second,
231 const TensorInfo& output,
232 std::string const& descName,
233 std::string const& firstName,
234 std::string const& secondName)
235{
236 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
237 // broadcasted.
238 if (first.GetNumDimensions() != second.GetNumDimensions())
239 {
240 throw InvalidArgumentException(descName + ": Tensors "
241 + firstName + " & " + secondName
242 + " must have the same number of dimensions in order to be broadcasted");
243 }
244 uint32_t numDims = first.GetNumDimensions();
245 std::vector<uint32_t> outputDims(numDims, 0u);
246 for (uint32_t i = 0; i < numDims; i++)
247 {
248 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
249 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
250 if (dimsNotEqual && dimsNotOne)
251 {
252 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
253 }
254 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
255 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100256 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000257 if (broadcastShape != output.GetShape())
258 {
259 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
260 + firstName + " & " + secondName
261 + " does not match the output shape");
262 }
263}
264
265//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100266void ValidateDataTypes(const TensorInfo& info,
267 const std::vector<armnn::DataType>& supportedTypes,
268 std::string const& descName)
269{
270 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
271 if (iterator == supportedTypes.end())
272 {
273 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
274 }
275}
276
James Conroy4d1ff582019-06-10 17:06:39 +0100277//---------------------------------------------------------------
278void ValidateTensorDataTypesMatch(const TensorInfo& first,
279 const TensorInfo& second,
280 std::string const& descName,
281 std::string const& firstName,
282 std::string const& secondName)
283{
284 if (first.GetDataType() != second.GetDataType())
285 {
286 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
287 " must have identical data types.");
288 }
289}
290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100291//---------------------------------------------------------------
292void ValidateTensorNumElementsMatch(const TensorInfo& first,
293 const TensorInfo& second,
294 std::string const& descName,
295 std::string const& firstName,
296 std::string const& secondName)
297{
298 if (first.GetNumElements() != second.GetNumElements())
299 {
300 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
301 " must have the same number of elements.");
302 }
303}
304
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000305void ValidateWeightDataType(const TensorInfo& inputInfo,
306 const TensorInfo& weightInfo,
307 const std::string& descName)
308{
309 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000310 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000311 {
312 const std::vector<DataType> validTypes =
313 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000314 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100315 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100316 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000317 };
318
319 ValidateDataTypes(weightInfo, validTypes, descName);
320 }
321 else
322 {
323 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
324 }
325}
326
327void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
328 const std::string& descName,
329 const std::string& tensorName)
330{
331 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
332 if (!quantizationDim.has_value())
333 {
James Ward47fce872020-09-10 11:57:28 +0100334 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
335 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000336 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000337}
338
339void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
340 const std::string& descName,
341 const std::string& tensorName)
342{
343 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
344 if (quantizationOffset != 0)
345 {
James Ward47fce872020-09-10 11:57:28 +0100346 throw InvalidArgumentException(fmt::format(
347 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
348 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000349 }
350}
351
352void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
353 const TensorInfo& outputInfo,
354 const TensorInfo& weightInfo,
355 const Optional<TensorInfo>& optionalBiasInfo,
356 const std::string& descName)
357{
358 if (weightInfo.HasPerAxisQuantization())
359 {
360 const DataType inputDataType = inputInfo.GetDataType();
361 const DataType outputDataType = outputInfo.GetDataType();
362
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364
365 if (!canHavePerAxisQuantization)
366 {
James Ward47fce872020-09-10 11:57:28 +0100367 throw InvalidArgumentException(fmt::format(
368 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
369 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000370 }
371
Derek Lambertid466a542020-01-22 15:37:29 +0000372
373 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
375 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
376
377 if (optionalBiasInfo.has_value())
378 {
379 const TensorInfo& biasInfo = optionalBiasInfo.value();
380 if (!biasInfo.HasPerAxisQuantization())
381 {
James Ward47fce872020-09-10 11:57:28 +0100382 throw InvalidArgumentException(fmt::format(
383 "{}: Per-axis quantization parameters not set on bias tensor, "
384 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000385 }
386
387 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
388 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
389 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
390 }
391 }
392}
393
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100394} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000395
Mike Kelly80512b02022-05-16 23:10:42 +0100396//---------------------------------------------------------------
397void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
398 std::string const& descName,
399 unsigned int numDimensions,
400 std::string const& tensorName) const
401{
402 // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
403 // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
404 // numDimensions.
405 if (m_AllowExpandedDims)
406 {
407 unsigned int squeezedDims = 0;
408
409 for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
410 {
411 if (tensor.GetShape()[i] != 1)
412 {
413 ++squeezedDims;
414 }
415 }
416 if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
417 {
418 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
419 to_string(tensor.GetNumDimensions()) + " dimensions for " +
420 tensorName + " tensor.");
421 }
422 }
423 else
424 {
425 if (tensor.GetNumDimensions() != numDimensions)
426 {
427 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
428 to_string(tensor.GetNumDimensions()) + " dimensions for " +
429 tensorName + " tensor.");
430 }
431 }
432}
433
434//---------------------------------------------------------------
435void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100436 unsigned int numDimension,
437 unsigned int numElements,
438 std::string const& tensorName) const
Mike Kelly80512b02022-05-16 23:10:42 +0100439{
440 const std::string functionName{"ValidateTensorNumDimNumElem"};
441 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
442 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
443}
444
445//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000446void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
447 unsigned int numExpectedIn, unsigned int numExpectedOut) const
448{
449 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
450 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
451}
452
453//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100454void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
455{
456 const std::string descriptorName{"MapQueueDescriptor"};
457
458 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100459 ValidateNumOutputs(workloadInfo, descriptorName, 0);
460
461 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
462 {
463 if (!m_Inputs[i])
464 {
465 throw InvalidArgumentException(
466 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
467 }
468 }
469}
470
471//---------------------------------------------------------------
472void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
473{
474 const std::string descriptorName{"UnmapQueueDescriptor"};
475
476 ValidateNumInputs(workloadInfo, descriptorName, 1);
477 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100478
479 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
480 {
481 if (!m_Inputs[i])
482 {
483 throw InvalidArgumentException(
484 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
485 }
486 }
487}
488
489//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000490void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
491{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100494 ValidateNumInputs(workloadInfo, descriptorName, 1);
495 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100497 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
498 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
499
500 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
501 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000502
503 if (m_Inputs.size() != m_Outputs.size())
504 {
James Ward47fce872020-09-10 11:57:28 +0100505 throw InvalidArgumentException(fmt::format(
506 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
507 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000508 }
509
510 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
511 {
512 if (!m_Inputs[i])
513 {
James Ward47fce872020-09-10 11:57:28 +0100514 throw InvalidArgumentException(fmt::format(
515 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000516 }
517
518 if (!m_Outputs[i])
519 {
James Ward47fce872020-09-10 11:57:28 +0100520 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000521 }
522 }
523}
524
Derek Lambertif674aa02019-08-01 15:56:25 +0100525//---------------------------------------------------------------
526void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
527{
528 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
529 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
530
531 if (workloadInfo.m_InputTensorInfos.size() != 1)
532 {
James Ward47fce872020-09-10 11:57:28 +0100533 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
534 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100535
536 }
537
538 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
539 {
James Ward47fce872020-09-10 11:57:28 +0100540 throw InvalidArgumentException(fmt::format(
541 "Number of input infos ({0}) does not match the number of output infos ({1})",
542 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100543 }
544
545 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
546 {
547 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
548 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
549 {
James Ward47fce872020-09-10 11:57:28 +0100550 throw InvalidArgumentException(fmt::format(
551 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100552 }
553 }
554
555 if (m_Inputs.size() != 1)
556 {
James Ward47fce872020-09-10 11:57:28 +0100557 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100558 }
559
560 if (m_Inputs.size() != m_Outputs.size())
561 {
James Ward47fce872020-09-10 11:57:28 +0100562 throw InvalidArgumentException(fmt::format(
563 "Number of inputs ({0}) does not match the number of outputs ({1})",
564 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100565 }
566
567 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
568 {
569 if (!m_Inputs[i])
570 {
James Ward47fce872020-09-10 11:57:28 +0100571 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573
574 if (!m_Outputs[i])
575 {
James Ward47fce872020-09-10 11:57:28 +0100576 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100577 }
578 }
579}
580
581//---------------------------------------------------------------
582void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
583{
584 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100585
Derek Lambertif674aa02019-08-01 15:56:25 +0100586 if (m_Inputs.size() != 1)
587 {
James Ward47fce872020-09-10 11:57:28 +0100588 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100589 }
590
591 if (m_Outputs.size() != 0)
592 {
James Ward47fce872020-09-10 11:57:28 +0100593 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100594 }
595
596 if (!m_Inputs[0])
597 {
James Ward47fce872020-09-10 11:57:28 +0100598 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100599 }
600}
601
602//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000603void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
604{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100605 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100606
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100607 ValidateNumInputs(workloadInfo, descriptorName, 1);
608 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100609
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100610 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
611 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100612
613 std::vector<DataType> supportedTypes =
614 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000615 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100616 DataType::Float16,
617 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000618 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000619 DataType::QAsymmU8,
620 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100621 };
622
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100623 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
624 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
625 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000626}
627
Nikhil Rajee391d52019-09-05 17:50:44 +0100628void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
629{
630 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
631
632 ValidateNumInputs(workloadInfo, descriptorName, 1);
633 ValidateNumOutputs(workloadInfo, descriptorName, 1);
634
635 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
636 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
637
Inki Daed4619e22020-09-10 15:33:54 +0900638 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
639 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100640 {
Inki Daed4619e22020-09-10 15:33:54 +0900641 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100642 }
643
James Conroyd47a0642019-09-17 14:22:06 +0100644 std::vector<DataType> supportedInputTypes =
645 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000646 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100647 DataType::Float16,
648 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100649 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000650 DataType::QAsymmU8,
651 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900652 DataType::Signed32,
653 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100654 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100655
James Conroyd47a0642019-09-17 14:22:06 +0100656 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100657
658 auto inputShape = inputTensorInfo.GetShape();
659 auto outputShape = outputTensorInfo.GetShape();
660
661 auto inputNumDimensions = inputShape.GetNumDimensions();
662 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
663
664 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
665
666 // 1D input shape results in scalar output shape
667 if (inputShape.GetNumDimensions() == 1)
668 {
669 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
670 {
671 throw InvalidArgumentException(descriptorName + outputShapeError);
672 }
673 }
674 else
675 {
676 for (unsigned int i = 0; i < unsignedAxis; ++i)
677 {
678 if (outputShape[i] != inputShape[i])
679 {
680 throw InvalidArgumentException(descriptorName + outputShapeError);
681 }
682 }
683
684 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
685 {
686 if (outputShape[i - 1] != inputShape[i])
687 {
688 throw InvalidArgumentException(descriptorName + outputShapeError);
689 }
690 }
691 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100692}
693
mathad01b392e982021-04-07 12:07:30 +0100694void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
695{
696 const std::string descriptorName{"CastQueueDescriptor"};
697
698 ValidateNumInputs(workloadInfo, descriptorName, 1);
699 ValidateNumOutputs(workloadInfo, descriptorName, 1);
700
701 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
703
704 std::vector<DataType> supportedTypes =
705 {
706 DataType::BFloat16,
707 DataType::Float16,
708 DataType::Float32,
709 DataType::QAsymmS8,
710 DataType::QAsymmU8,
711 DataType::QSymmS8,
712 DataType::QSymmS16,
713 DataType::Signed32,
714 DataType::Signed64
715 };
716
717 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
718 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
719}
720
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100721void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
722{
723 const std::string descriptorName{"SoftmaxQueueDescriptor"};
724
725 ValidateNumInputs(workloadInfo, descriptorName, 1);
726 ValidateNumOutputs(workloadInfo, descriptorName, 1);
727
728 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
729 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
730
731 std::vector<DataType> supportedTypes =
732 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000733 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100734 DataType::Float16,
735 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000736 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000737 DataType::QAsymmU8,
738 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100739 };
740
741 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
742 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
743 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
744}
745
telsoa014fcda012018-03-09 14:13:49 +0000746void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
747{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100748 const std::string descriptorName{"SplitterQueueDescriptor"};
749
750 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000751
Ruomei Yan25339c32019-05-28 16:48:20 +0100752 // Check the supported data types
753 std::vector<DataType> supportedTypes =
754 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000755 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100756 DataType::Float32,
757 DataType::Float16,
758 DataType::Boolean,
759 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100760 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000761 DataType::QAsymmU8,
762 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100763 };
764
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100765 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
766 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
769 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
770
771 const std::string outputName = "output_" + std::to_string(i);
772 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100773 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100774
telsoa014fcda012018-03-09 14:13:49 +0000775 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
776 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100777 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000778 }
779
780 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
781 {
782 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100783 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000784 "has to match number of workloadInfo.m_OutputTensorInfos. "
785 "Number of windows: " +
786 to_string(m_ViewOrigins.size()) +
787 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
788 }
789
telsoa01c577f2c2018-08-31 09:22:23 +0100790 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000791 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
792 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
793 {
telsoa01c577f2c2018-08-31 09:22:23 +0100794 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000795 ViewOrigin const& e = m_ViewOrigins[w];
796 if (e.m_Origin.size() != inputDims)
797 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100798 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000799 "have the same dimensionality as the input tensor. "
800 "Window origin (index: " +
801 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
802 " dimensions, the input "
803 "tensor has " +
804 to_string(inputDims) + " dimensions.");
805 }
806 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
807 {
808 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
809 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
810 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100811 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000812 "be smaller or equal than the size of the input in that coord.");
813 }
814 }
815 }
816}
817
Jim Flynne242f2d2019-05-22 14:24:13 +0100818void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000819{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100820 const std::string descriptorName{"ConcatQueueDescriptor"};
821
822 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000823
824 if (m_Inputs.size() <= 0)
825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100826 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000827 }
828 if (m_Outputs.size() <= 0)
829 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100830 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000831 }
832
833 if (workloadInfo.m_InputTensorInfos.size() <= 0)
834 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100835 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000836 }
837 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
838 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100839 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000840 }
841
Nikhil Raj8599a412018-11-19 14:51:07 +0000842 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
843 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100844 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000845 }
846
847 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
848 {
849 return;
850 }
851
telsoa014fcda012018-03-09 14:13:49 +0000852 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
853 {
854 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100855 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000856 "has to match number of workloadInfo.m_InputTensorInfos. "
857 "Number of windows: " +
858 to_string(m_ViewOrigins.size()) +
859 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
860 }
861
telsoa01c577f2c2018-08-31 09:22:23 +0100862 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000863 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
864 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
865 {
telsoa01c577f2c2018-08-31 09:22:23 +0100866 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000867 ViewOrigin const& e = m_ViewOrigins[w];
868 if (e.m_Origin.size() != outputDims)
869 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100870 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000871 "have the same dimensionality as the output tensor. "
872 "Window origin (index: " +
873 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
874 " dimensions, the output "
875 "tensor has " +
876 to_string(outputDims) + " dimensions.");
877 }
telsoa01c577f2c2018-08-31 09:22:23 +0100878 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000879 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
880 {
881 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
882 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
883 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100884 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000885 "be smaller or equal than the size of the output in that coord.");
886 }
887 }
888 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100889
890 // Check the supported data types
891 std::vector<DataType> supportedTypes =
892 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000893 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100894 DataType::Float32,
895 DataType::Float16,
896 DataType::Boolean,
897 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100898 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000899 DataType::QAsymmU8,
900 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100901 };
902
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100903 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
904 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100905 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100906 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
907 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
908
909 const std::string inputName = "input_" + std::to_string(i);
910 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100911 }
telsoa014fcda012018-03-09 14:13:49 +0000912}
913
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100914void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
915{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100916 const std::string descriptorName{"StackQueueDescriptor"};
917
918 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100919
920 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
921 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100922 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100923 }
924
925 // All inputs must have the same shape, which is defined in parameters
926 const TensorShape& inputShape = m_Parameters.m_InputShape;
927 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
928 {
929 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
930 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100931 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100932 }
933 }
934
Matthew Jacksondba634f2019-08-15 15:14:18 +0100935 if (inputShape.GetNumDimensions() > 4)
936 {
937 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
938 }
939
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100940 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
941 // since the output tensor has an additional dimension.
942 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
943 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100944 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100945 "than the number of input dimensions.");
946 }
947
948 // Output shape must be as inferred from the input shape
949 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
950 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
951 {
952 if (outputShape[i] != inputShape[i])
953 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100954 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100955 "match shape inferred from input tensor.");
956 }
957 }
958
959 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
960 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100962 "match shape inferred from input tensor.");
963 }
964
965 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
966 {
967 if (outputShape[i] != inputShape[i-1])
968 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100970 "match shape inferred from input tensor.");
971 }
972 }
973
Matthew Jacksondba634f2019-08-15 15:14:18 +0100974 if (outputShape.GetNumDimensions() > 5)
975 {
976 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
977 }
978
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100979 // Check the supported data types
980 std::vector<DataType> supportedTypes =
981 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000982 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100983 DataType::Float32,
984 DataType::Float16,
985 DataType::Boolean,
986 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100987 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000988 DataType::QAsymmU8,
989 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100990 };
991
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100992 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100994 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100995 {
996 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
997 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100998 descriptorName,
999 "input_0",
1000 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001001 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001002
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001003 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1004 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001005 descriptorName,
1006 "input_0",
1007 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001008}
1009
Ryan OSheaec6c6802020-06-05 17:17:06 +01001010void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1011{
1012 const std::string descriptorName{"FillQueueDescriptor"};
1013
1014 ValidateNumInputs(workloadInfo, descriptorName, 1);
1015 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1016
1017 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1018 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1019
1020 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1021
1022 std::vector<DataType> supportedTypes =
1023 {
1024 DataType::BFloat16,
1025 DataType::Float32,
1026 DataType::Float16,
1027 DataType::Signed32
1028 };
1029
1030 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1031}
1032
telsoa014fcda012018-03-09 14:13:49 +00001033void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1034{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001035 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001036
Matthew Sloyan81beae32021-07-13 19:46:11 +01001037 uint32_t numInputs = 2;
1038 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001039 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001040 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001041 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001042
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001043 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1045
1046 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1047 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1048
1049 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1050
1051 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001052 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001053 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001054 }
1055
Matthew Sloyan81beae32021-07-13 19:46:11 +01001056 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001057 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001058
1059 if (m_Parameters.m_BiasEnabled)
1060 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001061 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001062 // Validates type and quantization values.
Ryan OSheaf183acd2023-07-06 11:41:25 +01001063 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1065 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001066 }
1067
Francis Murtagh46c09d02019-05-28 08:15:28 +01001068 // Check the supported data types
1069 std::vector<DataType> supportedTypes =
1070 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001071 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001072 DataType::Float32,
1073 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001074 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001075 DataType::QAsymmU8,
1076 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001077 };
1078
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001080
1081 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1082 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1083 {
1084 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1085 {
1086 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1087 "for BFloat16 input.");
1088 }
1089 }
1090 else
1091 {
1092 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1093 }
telsoa014fcda012018-03-09 14:13:49 +00001094}
1095
telsoa014fcda012018-03-09 14:13:49 +00001096void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1097{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001098 const std::string descriptorName{"NormalizationQueueDescriptor"};
1099
1100 ValidateNumInputs(workloadInfo, descriptorName, 1);
1101 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1102
1103 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1104 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001105
1106 // Check the supported data types
1107 std::vector<DataType> supportedTypes =
1108 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001109 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001110 DataType::Float16,
1111 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001112 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001113 DataType::QAsymmU8,
1114 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001115 };
1116
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001117 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001118
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001119 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001120
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001121 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001122}
1123
1124void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1125{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001127
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 ValidateNumInputs(workloadInfo, descriptorName, 2);
1129 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1130
1131 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1132 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1133 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1134
1135 std::vector<DataType> supportedTypes =
1136 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001137 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001138 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001139 DataType::Float16,
1140 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001141 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001142 DataType::QSymmS16,
1143 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001144 };
1145
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001146 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1147 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1148 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1151 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001152
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001153 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1154 inputTensorInfo1,
1155 outputTensorInfo,
1156 descriptorName,
1157 "input_0",
1158 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001159}
1160
telsoa014fcda012018-03-09 14:13:49 +00001161void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1162{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001163 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001164
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001165 ValidateNumInputs(workloadInfo, descriptorName, 2);
1166 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1167
1168 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1169 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1170 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1171
1172 std::vector<DataType> supportedTypes =
1173 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001174 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001175 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001176 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001177 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001178 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001179 DataType::QSymmS16,
1180 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001181 };
1182
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001183 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1184 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1185 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001186
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001187 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1188 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001189
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001190 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1191 inputTensorInfo1,
1192 outputTensorInfo,
1193 descriptorName,
1194 "input_0",
1195 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001196}
1197
1198void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1199{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001200 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001201
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001202 ValidateNumInputs(workloadInfo, descriptorName, 1);
1203 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1204
1205 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1206 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001207
1208 std::vector<DataType> supportedTypes =
1209 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001210 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001211 DataType::Float16,
1212 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001213 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001214 DataType::QAsymmU8,
1215 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001216 };
1217
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001218 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1219 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001220
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001221 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001223
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001224 ValidatePointer(m_Mean, descriptorName, "mean");
1225 ValidatePointer(m_Variance, descriptorName, "variance");
1226 ValidatePointer(m_Beta, descriptorName, "beta");
1227 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001228
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001229 const TensorInfo& mean = m_Mean->GetTensorInfo();
1230 const TensorInfo& variance = m_Variance->GetTensorInfo();
1231 const TensorInfo& beta = m_Beta->GetTensorInfo();
1232 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001233
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001234 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1235 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1236 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1237 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001238
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001239 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1240 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1241 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001242}
1243
1244void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1245{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001246 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001247
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001248 uint32_t numInputs = 2;
1249 if (m_Parameters.m_BiasEnabled)
1250 {
1251 numInputs = 3;
1252 }
1253
1254 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001256
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1258 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1261 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001262
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001263 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001265 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001266
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001267 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001268
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001269 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001270 if (m_Parameters.m_BiasEnabled)
1271 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001272 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001273 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274
1275 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001276 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001277 }
1278
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001279 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1280 {
1281 throw InvalidArgumentException(
1282 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1283 "cannot be either negative or 0.",
1284 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1285 }
1286
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001287 ValidatePerAxisQuantization(inputTensorInfo,
1288 outputTensorInfo,
1289 weightTensorInfo,
1290 optionalBiasTensorInfo,
1291 descriptorName);
1292
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001293 std::vector<DataType> supportedTypes =
1294 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001295 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001296 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001297 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001298 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001299 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001300 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001301 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001302 };
1303
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001304 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001305
1306 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1307 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1308 {
1309 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1310 {
1311 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1312 "for BFloat16 input.");
1313 }
1314 }
1315 else
1316 {
1317 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1318 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001319}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001320
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001321void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1322{
1323 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1324
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001325 uint32_t numInputs = 2;
1326 if (m_Parameters.m_BiasEnabled)
1327 {
1328 numInputs = 3;
1329 }
1330 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001331 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1332
1333 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1334 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1335
1336 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1337 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1338
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001339 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001340 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1341
1342 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1343
1344 Optional<TensorInfo> optionalBiasTensorInfo;
1345 if (m_Parameters.m_BiasEnabled)
1346 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001347 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001348 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1349
1350 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001351 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001352 }
1353
1354 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1355 {
1356 throw InvalidArgumentException(
1357 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1358 "cannot be either negative or 0.",
1359 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1360 }
1361
1362 ValidatePerAxisQuantization(inputTensorInfo,
1363 outputTensorInfo,
1364 weightTensorInfo,
1365 optionalBiasTensorInfo,
1366 descriptorName);
1367
1368 std::vector<DataType> supportedTypes =
1369 {
1370 DataType::BFloat16,
1371 DataType::Float16,
1372 DataType::Float32,
1373 DataType::QAsymmS8,
1374 DataType::QAsymmU8,
1375 DataType::QSymmS16,
1376 DataType::QSymmS8
1377 };
1378
1379 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1380 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1381}
1382
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001383void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1384{
1385 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1386
Cathal Corbett06902652022-04-14 17:55:11 +01001387 uint32_t numInputs = 2;
1388 if (m_Parameters.m_BiasEnabled)
1389 {
1390 numInputs = 3;
1391 }
1392
1393 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001394 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1395
1396 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1397 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1398
1399 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1400 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1401
Cathal Corbett06902652022-04-14 17:55:11 +01001402 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1404
1405 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1406 {
1407 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001408 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1409 "cannot be smaller than 1.",
1410 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001411 }
1412
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001413 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1414 {
1415 throw InvalidArgumentException(
1416 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1417 "cannot be either negative or 0.",
1418 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1419 }
1420
Jan Eilers53ef7952021-06-02 12:01:25 +01001421 if (weightTensorInfo.GetShape()[0] != 1)
1422 {
1423 throw InvalidArgumentException(fmt::format(
1424 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1425 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1426 descriptorName,
1427 weightTensorInfo.GetShape()[0],
1428 weightTensorInfo.GetShape()[1],
1429 weightTensorInfo.GetShape()[2],
1430 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001431 }
1432
Cathal Corbett4b19d222022-05-11 20:12:17 +01001433 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1434 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1435 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1436 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1437
1438 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1439 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1440 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1441
1442 if (!(validRefFormat || validAclFormat))
1443 {
1444 throw InvalidArgumentException(fmt::format(
1445 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1446 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1447 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1448 descriptorName,
1449 numOutputChannels,
1450 weightTensorInfo.GetShape()[0],
1451 weightTensorInfo.GetShape()[1],
1452 weightTensorInfo.GetShape()[2],
1453 weightTensorInfo.GetShape()[3]));
1454 }
1455
Teresa Charlind8df0262019-11-11 12:28:15 +00001456 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457
Teresa Charlind8df0262019-11-11 12:28:15 +00001458 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001459 if (m_Parameters.m_BiasEnabled)
1460 {
Cathal Corbett06902652022-04-14 17:55:11 +01001461 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001462 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001463
Ryan OSheaf183acd2023-07-06 11:41:25 +01001464 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001465 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1466 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001467 ValidatePerAxisQuantization(inputTensorInfo,
1468 outputTensorInfo,
1469 weightTensorInfo,
1470 optionalBiasTensorInfo,
1471 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001472
1473 std::vector<DataType> supportedTypes =
1474 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001475 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001476 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001477 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001478 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001479 DataType::QAsymmU8,
1480 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001481 };
1482
1483 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1484 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001485}
1486
1487void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1488{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001489 const std::string descriptorName{"PermuteQueueDescriptor"};
1490
1491 ValidateNumInputs(workloadInfo, descriptorName, 1);
1492 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001493
1494 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1495
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1497 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001498
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1500 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001501
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001502 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001503 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001505 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1507 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1508 "must match dst dimension " + to_string(mapping[i]) +
1509 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001510 }
1511 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001512
1513 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001514}
1515
1516void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1517{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001519
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001520 ValidateNumInputs(workloadInfo, descriptorName, 1);
1521 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1522
1523 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1524 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1525
1526 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1527 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001528
1529 std::vector<DataType> supportedTypes =
1530 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001531 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001532 DataType::Float32,
1533 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001534 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001535 DataType::QAsymmU8,
1536 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001537 };
1538
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001539 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1540 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001541}
1542
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001543void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1544{
1545 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1546
1547 ValidateNumInputs(workloadInfo, descriptorName, 1);
1548 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1549
1550 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1551 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1552
1553 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1554 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1555
1556 std::vector<DataType> supportedTypes =
1557 {
1558 DataType::BFloat16,
1559 DataType::Float32,
1560 DataType::Float16,
1561 DataType::QAsymmS8,
1562 DataType::QAsymmU8,
1563 DataType::QSymmS16
1564 };
1565
1566 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1567 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1568}
1569
Teresa Charlin970f43b2019-07-01 13:51:07 +01001570void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1571{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001572 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001574 ValidateNumInputs(workloadInfo, descriptorName, 1);
1575 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1576
1577 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1578 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1579
1580 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1581 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001582
1583 std::vector<DataType> supportedTypes =
1584 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001585 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001586 DataType::Float16,
1587 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001588 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001589 DataType::QAsymmU8,
1590 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001591 };
1592
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001593 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1594 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001595
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001596 // Resize only changes width and height: batch and channel count must match.
1597 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1598 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001599 if (inputBatchSize != outputBatchSize)
1600 {
1601 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001602 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1603 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001604 }
1605
1606 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001607 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1608 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001609 if (inputChannelCount != outputChannelCount)
1610 {
1611 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001612 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1613 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001614 }
1615}
1616
Teresa Charlin79a06a52023-07-13 17:16:45 +01001617void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
1618{
Tianle Cheng988354d2023-06-28 13:20:47 +01001619 const std::string descriptorName{"ReverseV2QueueDescriptor"};
1620
Tracy Narinebb8d7592023-07-13 16:50:54 +01001621 // Backend restriction
1622 const unsigned int maxDimensions = 4;
1623
1624 ValidateNumInputs(workloadInfo, descriptorName, 2);
Tianle Cheng988354d2023-06-28 13:20:47 +01001625 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1626
1627 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
Tracy Narinebb8d7592023-07-13 16:50:54 +01001628 const TensorInfo& axisTensorInfo = workloadInfo.m_InputTensorInfos[1];
Tianle Cheng988354d2023-06-28 13:20:47 +01001629 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1630
Tracy Narinebb8d7592023-07-13 16:50:54 +01001631 const auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
1632 if (inputTensorNumDimensions > maxDimensions)
Tianle Cheng988354d2023-06-28 13:20:47 +01001633 {
1634 throw InvalidArgumentException(descriptorName +
1635 ": Input tensors with rank greater than " +
Tracy Narinebb8d7592023-07-13 16:50:54 +01001636 std::to_string(maxDimensions) + " are not supported.");
1637 }
1638
1639 const auto axisTensorNumDimensions = axisTensorInfo.GetNumDimensions();
1640 if (axisTensorNumDimensions > maxDimensions)
1641 {
1642 throw InvalidArgumentException(descriptorName +
1643 ": More than " + std::to_string(maxDimensions) + " axes cannot be specified.");
1644 }
1645
1646 if (axisTensorNumDimensions > inputTensorNumDimensions)
1647 {
1648 throw InvalidArgumentException(descriptorName +
1649 ": More axes specified than the number of axes on the input tensor.");
Tianle Cheng988354d2023-06-28 13:20:47 +01001650 }
1651
1652 std::vector<DataType> supportedTypes =
1653 {
1654 DataType::BFloat16,
1655 DataType::Float16,
1656 DataType::Float32,
1657 DataType::QAsymmS8,
1658 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01001659 DataType::QSymmS8,
1660 DataType::QSymmS16,
1661 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01001662 };
1663
1664 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Tracy Narinebb8d7592023-07-13 16:50:54 +01001665
1666 std::vector<DataType> axisSupportedTypes =
1667 {
1668 DataType::Signed32,
1669 };
1670
1671 ValidateDataTypes(axisTensorInfo, axisSupportedTypes, descriptorName);
1672
Tianle Cheng988354d2023-06-28 13:20:47 +01001673 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1674 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Tianle Cheng988354d2023-06-28 13:20:47 +01001675}
1676
telsoa014fcda012018-03-09 14:13:49 +00001677void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1678{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001679 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001680
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001681 ValidateNumInputs(workloadInfo, descriptorName, 1);
1682 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1683
1684 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1685 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1686
1687 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1688 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1689
1690 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1691
telsoa014fcda012018-03-09 14:13:49 +00001692 if (m_Parameters.m_Min > m_Parameters.m_Max)
1693 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001694 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001695 }
telsoa014fcda012018-03-09 14:13:49 +00001696}
1697
Kevin Mayce5045a2019-10-02 14:07:47 +01001698void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1699{
1700 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1701
1702 ValidateNumInputs(workloadInfo, descriptorName, 1);
1703 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1704
1705 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1706 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1707
1708 if (inputTensorInfo.GetNumDimensions() > 4)
1709 {
1710 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1711 }
1712
1713 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1714
1715 // Check the supported data types
1716 std::vector<DataType> supportedTypes =
1717 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001718 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001719 DataType::Float32,
1720 DataType::Float16
1721 };
1722
1723 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001724 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001725}
1726
telsoa014fcda012018-03-09 14:13:49 +00001727void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1728{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001729 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001730
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001731 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001732 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1733
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1735 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1736
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001737 if (inputTensorInfo.GetNumDimensions() > 4)
1738 {
1739 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1740 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001741
1742 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001743
1744 // Check the supported data types
1745 std::vector<DataType> supportedTypes =
1746 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001747 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001748 DataType::Float32,
1749 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001750 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001751 DataType::QAsymmU8,
1752 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001753 };
1754
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001755 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001756 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1757}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001758
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001759void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1760{
1761 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1762
1763 ValidateNumInputs(workloadInfo, descriptorName, 1);
1764 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1765
1766 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1767 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1768
1769 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1770
1771 std::vector<DataType> supportedTypes =
1772 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001773 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001774 DataType::Float32,
1775 DataType::Float16,
1776 };
1777
1778 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001780}
1781
1782void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1783{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 const std::string descriptorName{"ConstantQueueDescriptor"};
1785
1786 ValidateNumInputs(workloadInfo, descriptorName, 0);
1787 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001788
1789 if (!m_LayerOutput)
1790 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001791 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001792 }
1793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001794 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1795 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001796
1797 // Check the supported data types
1798 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001799 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001800 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001801 DataType::Float32,
1802 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001803 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001804 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001805 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001806 DataType::QSymmS16,
1807 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001808 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001809
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001810 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001811}
1812
1813void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1814{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001816
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001817 ValidateNumInputs(workloadInfo, descriptorName, 1);
1818 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1819
1820 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1821 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1822
1823 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001824
1825 // Check the supported data types
1826 std::vector<DataType> supportedTypes =
1827 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001828 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001829 DataType::Float32,
1830 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001831 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001832 DataType::QAsymmU8,
1833 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001834 DataType::Signed32,
1835 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001836 };
1837
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001838 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1839 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001840}
1841
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001842void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1843{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001845
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 ValidateNumInputs(workloadInfo, descriptorName, 1);
1847 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1848
1849 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1850 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1851
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001852 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1853 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001854 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1855 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001856 }
1857
Teresa Charlinf77cab52023-06-01 16:15:13 +01001858 if (m_Parameters.m_BlockShape.size() == 2)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001859 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001860 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1861 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1862 }
1863 else if (m_Parameters.m_BlockShape.size() == 1)
1864 {
1865 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
1866 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
1867 }
1868 else
1869 {
1870 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001871 }
1872
Teresa Charlinf77cab52023-06-01 16:15:13 +01001873 // Check input + padding and output have the same number of elements
1874 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1875 const unsigned int inputHeight = inputTensorInfo.GetShape()[dimensionIndices.GetHeightIndex()] +
1876 m_Parameters.m_PadList[0].first + m_Parameters.m_PadList[0].second;
1877 const unsigned int inputWidth = (inputTensorInfo.GetNumDimensions() == 3) ? 1 :
1878 inputTensorInfo.GetShape()[dimensionIndices.GetWidthIndex()] +
1879 m_Parameters.m_PadList[1].first + m_Parameters.m_PadList[1].second;
1880
1881 const int channelsIndex_int = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : -1;
1882 const unsigned int channelsIndex = channelsIndex_int < 0 ?
1883 static_cast<unsigned int>(channelsIndex_int) + inputTensorInfo.GetNumDimensions()
1884 : static_cast<unsigned int>(channelsIndex_int);
1885
1886 const unsigned int numInputElements = inputTensorInfo.GetShape()[0] *
1887 inputHeight *
1888 inputWidth *
1889 inputTensorInfo.GetShape()[channelsIndex];
1890
1891 if (outputTensorInfo.GetNumElements() != numInputElements)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001892 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001893 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1894 to_string(numInputElements) + " after padding but output tensor has " +
1895 to_string(outputTensorInfo.GetNumElements()) + " elements.");
1896 }
1897
1898 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
1899 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
1900 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
1901 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
1902 {
1903 unsigned int spatialDimension = firstSpatialDimension + i;
1904 auto inputSize = inputTensorInfo.GetShape()[spatialDimension] +
1905 m_Parameters.m_PadList[i].first +
1906 m_Parameters.m_PadList[i].second;
1907 if (inputSize % m_Parameters.m_BlockShape[i] != 0)
1908 {
1909 throw InvalidArgumentException(descriptorName + ": Input dimension size after padding must be "
1910 "divisible by Block Shape in dimension: " + to_string(spatialDimension) + ".");
1911 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001912 }
nikraj01120522a2019-05-31 11:33:07 +01001913
1914 std::vector<DataType> supportedTypes =
1915 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001916 DataType::BFloat16,
1917 DataType::Float16,
1918 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001919 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001920 DataType::QAsymmU8,
1921 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001922 };
1923
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001924 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1925 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001926}
1927
Keith Davisa57eccb2019-06-14 17:33:22 +01001928void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1929{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001930 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001932 ValidateNumInputs(workloadInfo, descriptorName, 1);
1933 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001934
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001935 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1936 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1937
1938 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1939 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001940
1941 std::vector<DataType> supportedTypes =
1942 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001943 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001944 DataType::Float32,
1945 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001946 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001947 DataType::QAsymmU8,
1948 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001949 };
1950
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001951 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1952 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001953
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001954 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1955
1956 if (m_Parameters.m_BlockSize == 0)
1957 {
1958 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1959 }
1960
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001961 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1962 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1963 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1964 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001966 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001967 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001968 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001969 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1970 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001971 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001972
1973 const TensorShape& outputShape = outputTensorInfo.GetShape();
1974 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1975 {
1976 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1977 "must be divisible by the square of block size." );
1978 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001979}
1980
telsoa014fcda012018-03-09 14:13:49 +00001981void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1982{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001983 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001984
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001985 ValidateNumInputs(workloadInfo, descriptorName, 1);
1986 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1987
1988 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1989 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001990
1991 std::vector<DataType> supportedTypes =
1992 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001993 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001994 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001995 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001996 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001997 };
1998
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001999 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01002000 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2001 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2002 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00002003}
2004
telsoa01c577f2c2018-08-31 09:22:23 +01002005void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2006{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
2008
2009 const std::string descriptorName{"LstmQueueDescriptor"};
2010
2011 // check dimensions of all inputs and outputs
2012 if (workloadInfo.m_InputTensorInfos.size() != 3)
2013 {
2014 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
2015 }
2016 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2017 {
2018 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
2019 }
2020
2021 std::vector<DataType> supportedTypes =
2022 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002023 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01002024 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002025 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002026 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002027 };
2028
Jan Eilers38e05bd2019-06-26 13:10:09 +01002029 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002030 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
2031
Jan Eilers38e05bd2019-06-26 13:10:09 +01002032 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002033 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002034 {
2035 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2036 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002037 descriptorName,
2038 "input_0",
2039 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002040 }
2041 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002042 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002043 {
2044 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2045 workloadInfo.m_OutputTensorInfos[i],
2046 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002047 "input_0",
2048 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002049 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002050
janeil0117d8d852019-11-15 15:00:16 +00002051 // Making sure clipping parameters have valid values.
2052 // == 0 means no clipping
2053 // > 0 means clipping
2054 if (m_Parameters.m_ClippingThresCell < 0.0f)
2055 {
2056 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2057 }
2058 if (m_Parameters.m_ClippingThresProj < 0.0f)
2059 {
2060 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2061 }
2062
Jan Eilers38e05bd2019-06-26 13:10:09 +01002063 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002064 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2065 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2066 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2067 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2068 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2069 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2070
Jan Eilers38e05bd2019-06-26 13:10:09 +01002071 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002072 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2073 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002074 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002075 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2076 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002077 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002078 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2079 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002080 // scratchBufferTensor
2081 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002082 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2083 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002084 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002085 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2086 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002087 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002088 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2089 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002090 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002091 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2092 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002093
Jan Eilers38e05bd2019-06-26 13:10:09 +01002094 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2095 if ( m_InputToInputWeights )
2096 {
2097 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2098 (n_cell * n_input), "InputLayerNormWeights");
2099 }
2100
2101 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2102 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2103 (n_cell * n_input), "InputToForgetWeights");
2104
2105 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2106 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2107 (n_cell * n_input), "InputToCellWeights");
2108
2109 if ( m_RecurrentToInputWeights )
2110 {
2111 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2112 (n_cell * n_output), "RecurrentToInputWeights");
2113 }
2114
2115 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2116 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2117 (n_cell * n_output), "RecurrentToForgetWeights");
2118
2119 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2120 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2121 (n_cell * n_output), "RecurrentToCellWeights");
2122
2123 // Make sure the input-gate's parameters are either both present (regular
2124 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2125 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2126 !m_Parameters.m_CifgEnabled) ||
2127 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2128 m_Parameters.m_CifgEnabled));
2129 if (!cifg_weights_all_or_none)
2130 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002131 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2132 "RecurrentToInputWeights must either both be present (regular LSTM) "
2133 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2134 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002135 }
2136
2137 if ( m_CellToInputWeights )
2138 {
2139 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2140 n_cell, "CellToInputWeights");
2141 }
2142 if ( m_CellToForgetWeights )
2143 {
2144 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2145 n_cell, "CellToForgetWeights");
2146 }
2147 if ( m_CellToOutputWeights )
2148 {
2149 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2150 n_cell, "CellToOutputWeights");
2151 }
2152
2153 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2154 bool peephole_weights_all_or_none =
2155 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2156 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2157 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2158 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2159 if (!peephole_weights_all_or_none)
2160 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002161 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002162 }
2163
2164 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2165 if (m_Parameters.m_CifgEnabled)
2166 {
2167 if (m_InputGateBias)
2168 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002169 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002170 }
2171 }
2172 else
2173 {
2174 if (!m_InputGateBias)
2175 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002176 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2177 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002178 }
2179 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2180 n_cell, "InputGateBias");
2181 }
2182
2183 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2184 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2185
2186 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2187 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2188
2189 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2190 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2191
2192 if (m_ProjectionWeights)
2193 {
2194 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2195 (n_cell * n_output), "ProjectionWeights");
2196 }
2197 if (m_ProjectionBias)
2198 {
2199 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2200 }
2201
2202 // Making sure the projection tensors are consistent:
2203 // 1) If projection weight is not present, then projection bias should not be
2204 // present.
2205 // 2) If projection weight is present, then projection bias is optional.
2206 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2207 !m_Parameters.m_ProjectionEnabled)
2208 || (m_ProjectionWeights && !m_ProjectionBias &&
2209 m_Parameters.m_ProjectionEnabled)
2210 || (m_ProjectionWeights && m_ProjectionBias &&
2211 m_Parameters.m_ProjectionEnabled));
2212 if (!projecton_tensors_consistent)
2213 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002214 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002215 }
2216
2217 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2218 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2219 // either all have values or none of them have values. Layer normalization is used when the values of all the
2220 // layer normalization weights are present
2221 if (m_InputLayerNormWeights)
2222 {
2223 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2224 }
2225 if (m_ForgetLayerNormWeights)
2226 {
2227 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2228 }
2229 if (m_CellLayerNormWeights)
2230 {
2231 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2232 }
2233 if (m_OutputLayerNormWeights)
2234 {
2235 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2236 }
2237
Jan Eilers38e05bd2019-06-26 13:10:09 +01002238 if (m_Parameters.m_LayerNormEnabled)
2239 {
2240 if (!m_Parameters.m_CifgEnabled)
2241 {
2242 if (!m_InputLayerNormWeights)
2243 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002244 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2245 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002246 }
2247 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2248 1, n_cell, "InputLayerNormWeights");
2249 }
2250 else if (m_InputLayerNormWeights)
2251 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002252 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2253 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002254 }
2255
2256 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2257 "ForgetLayerNormWeights");
2258 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2259
2260 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2261 "OutputLayerNormWeights");
2262 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2263
2264 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2265 "CellLayerNormWeights");
2266 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2267 }
2268 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2269 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002270 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2271 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002272 }
telsoa01c577f2c2018-08-31 09:22:23 +01002273}
2274
2275void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2276{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002277 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002278
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002279 ValidateNumInputs(workloadInfo, descriptorName, 1);
2280 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2281
2282 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2283 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2284
2285 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002286 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002287 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002288 }
2289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002291 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002292 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002293 }
2294
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002295 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002296}
2297
2298void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2299{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002301
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002302 ValidateNumInputs(workloadInfo, descriptorName, 1);
2303 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2304
2305 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2306 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2307
2308 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002309 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002310 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002311 }
2312
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002313 if (outputTensorInfo.GetDataType() != DataType::Float32)
2314 {
2315 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2316 }
2317
2318 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002319}
2320
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002321void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2322{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002323 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002325 ValidateNumInputs(workloadInfo, descriptorName, 2);
2326 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2327
2328 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2329 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2330 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2331
2332 std::vector<DataType> supportedTypes =
2333 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002334 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002335 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002336 DataType::Float32,
2337 DataType::QAsymmS8,
2338 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002339 DataType::QSymmS16,
2340 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002341 };
2342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002343 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2344 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2345 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002346
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002347 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2348 inputTensorInfo1,
2349 outputTensorInfo,
2350 descriptorName,
2351 "input_0",
2352 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002353}
2354
David Beckc2044fe2018-09-05 15:00:38 +01002355void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2356{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002357 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002358
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002359 ValidateNumInputs(workloadInfo, descriptorName, 2);
2360 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2361
2362 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2363 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2364 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2365
2366 std::vector<DataType> supportedTypes =
2367 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002368 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002369 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002370 DataType::Float32,
2371 DataType::QAsymmS8,
2372 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002373 DataType::QSymmS16,
2374 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002375 };
2376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2378 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2379 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002380
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002381 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2382 inputTensorInfo1,
2383 outputTensorInfo,
2384 descriptorName,
2385 "input_0",
2386 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002387}
2388
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002389void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2390{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002391 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 ValidateNumInputs(workloadInfo, descriptorName, 2);
2394 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2395
2396 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2397 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2398 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2399
2400 std::vector<DataType> supportedTypes =
2401 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002402 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002403 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002404 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002405 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002406 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002407 DataType::QSymmS16,
2408 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002409 };
2410
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002411 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2412 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2413 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002414
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002415 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2416 inputTensorInfo1,
2417 outputTensorInfo,
2418 descriptorName,
2419 "input_0",
2420 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002421}
2422
narpra01a6bf9122018-09-10 09:50:09 +01002423void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2424{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002425 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002426
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002427 ValidateNumInputs(workloadInfo, descriptorName, 1);
2428 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2429
2430 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2431 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002432
2433 std::vector<DataType> supportedTypes =
2434 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002435 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002436 DataType::Float32,
2437 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002438 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002439 DataType::QAsymmU8,
2440 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002441 };
narpra01eb061912018-09-10 17:35:27 +01002442
James Conroy4d1ff582019-06-10 17:06:39 +01002443 // First check if input tensor data type is supported, then
2444 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002445 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2446 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002447
narpra0132b90462018-09-13 11:07:48 +01002448 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002449 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002450 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002451 }
narpra0132b90462018-09-13 11:07:48 +01002452 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002453 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002454 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002455 }
2456 else
2457 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002458 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002459 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002460 ValidateTensorNumDimensions(outputTensorInfo,
2461 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002462 outputDim > 0 ? outputDim : 1,
2463 "output");
2464 }
narpra01a6bf9122018-09-10 09:50:09 +01002465}
2466
jimfly012c9322a2018-09-19 10:59:49 +01002467void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2468{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002469 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002470
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002471 ValidateNumInputs(workloadInfo, descriptorName, 1);
2472 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2473
2474 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2475 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002476
jimfly012c9322a2018-09-19 10:59:49 +01002477 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002478 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2479
jimfly012c9322a2018-09-19 10:59:49 +01002480 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002481 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2482 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2483 "as there are dimensions in the input tensor that is " +
2484 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2485 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002486 }
2487}
2488
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002489void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2490{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002491 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002492
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002493 ValidateNumInputs(workloadInfo, descriptorName, 1);
2494 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002495
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002496 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2497 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2498
Sadik Armagan2208b602019-07-31 16:36:27 +01002499 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002500 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002501 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002502 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002503 DataType::Float16,
2504 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002505 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002506 DataType::QAsymmU8,
2507 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002508 };
2509
2510 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002511
Keith Davis0c2eeac2020-02-11 16:51:50 +00002512 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002513 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002514 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002515 }
2516}
2517
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002518void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2519{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002520 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002522 ValidateNumInputs(workloadInfo, descriptorName, 1);
2523 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002525 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2526 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002527
Teresa Charlinf77cab52023-06-01 16:15:13 +01002528 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_Crops.size())
2529 {
2530 throw InvalidArgumentException(descriptorName + ": Crops must contain the same number of "
2531 "dimensions as Block Shape.");
2532 }
2533
2534 if (m_Parameters.m_BlockShape.size() == 2)
2535 {
2536 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2537 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2538 }
2539 else if (m_Parameters.m_BlockShape.size() == 1)
2540 {
2541 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
2542 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
2543 }
2544 else
2545 {
2546 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
2547 }
2548
2549 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
2550 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
2551 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
2552 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
2553 {
2554 unsigned int spatialDimension = firstSpatialDimension + i;
2555 unsigned int cropSize = m_Parameters.m_Crops[i].first + m_Parameters.m_Crops[i].second;
2556 unsigned int outputSize = inputTensorInfo.GetShape()[spatialDimension] * m_Parameters.m_BlockShape[i];
2557 if (cropSize > outputSize)
2558 {
2559 throw InvalidArgumentException(descriptorName + ": CropSize must be less than or equal to the uncropped"
2560 "outputSize in dimension: " + to_string(spatialDimension) + ".");
2561 }
2562 }
2563
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002564 std::vector<DataType> supportedTypes =
2565 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002566 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002567 DataType::Float32,
2568 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002569 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002570 DataType::QAsymmU8,
2571 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002572 };
2573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002574 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2575 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002576}
2577
Conor Kennedy430b5d82018-11-14 15:28:28 +00002578void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2579{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002581
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002582 ValidateNumInputs(workloadInfo, descriptorName, 1);
2583 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2584
2585 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2586 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002587
2588 std::vector<DataType> supportedTypes =
2589 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002590 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002591 DataType::Float16,
2592 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002593 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002594 DataType::QAsymmU8,
2595 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002596 };
2597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002598 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2599 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002600
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002601 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002603 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002604 if (rank > 4)
2605 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002606 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002607 }
2608
Conor Kennedy430b5d82018-11-14 15:28:28 +00002609 // Begin, End & Stride length must be of rank(input0)
2610 if (m_Parameters.m_Begin.size() != rank)
2611 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002612 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002613 }
2614
2615 if (m_Parameters.m_End.size() != rank)
2616 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002617 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002618 }
2619
2620 if (m_Parameters.m_Stride.size() != rank)
2621 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002622 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002623 }
2624
2625 // Stride entries must be non-zero
2626 for (auto& stride : m_Parameters.m_Stride)
2627 {
2628 if (stride == 0)
2629 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002630 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002631 }
2632 }
2633}
2634
kevmay0190539692018-11-29 08:40:19 +00002635void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2636{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002637 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002639 ValidateNumInputs(workloadInfo, descriptorName, 2);
2640 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2641
2642 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2643 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2644 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2645
2646 std::vector<DataType> supportedTypes =
2647 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002648 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002649 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002650 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002651 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002652 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002653 DataType::QSymmS16,
2654 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002655 };
2656
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002657 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2658 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2659 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002660
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002661 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2662 inputTensorInfo1,
2663 outputTensorInfo,
2664 descriptorName,
2665 "input_0",
2666 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002667}
2668
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002669void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2670{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002671 const std::string descriptorName{"DebugQueueDescriptor"};
2672
2673 ValidateNumInputs(workloadInfo, descriptorName, 1);
2674 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002675}
2676
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002677void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2678{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002679 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002680
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002681 ValidateNumInputs(workloadInfo, descriptorName, 2);
2682 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002683
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002684 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2685 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2686 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2687
2688 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2689 inputTensorInfo1,
2690 outputTensorInfo,
2691 descriptorName,
2692 "input_0",
2693 "input_1");
2694
2695 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002696 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002697 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002698 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002699}
2700
FrancisMurtagh878f0232018-12-19 10:56:15 +00002701void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2702{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002703 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002704
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002705 ValidateNumInputs(workloadInfo, descriptorName, 2);
2706 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002707
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002708 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2709 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2710 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2711
2712 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2713 inputTensorInfo1,
2714 outputTensorInfo,
2715 descriptorName,
2716 "input_0",
2717 "input_1");
2718
2719 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002720 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002721 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002722 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002723}
2724
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002725void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2726{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002727 const std::string descriptorName{"RsqrtQueueDescriptor"};
2728
2729 ValidateNumInputs(workloadInfo, descriptorName, 1);
2730 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2731
2732 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2733 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2734
2735 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002736
2737 std::vector<DataType> supportedTypes =
2738 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002739 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002740 DataType::Float16,
2741 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002742 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002743 DataType::QAsymmU8,
2744 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002745 };
2746
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002747 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2748 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002749}
2750
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002751void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2752{
2753 const std::string descriptorName{"GatherNdQueueDescriptor"};
2754
2755 ValidateNumInputs(workloadInfo, descriptorName, 2);
2756 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2757
2758 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2759 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2760 {
2761 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2762 }
2763
2764 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2765 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2766
2767 std::vector<DataType> supportedTypes =
2768 {
2769 DataType::BFloat16,
2770 DataType::Float16,
2771 DataType::Float32,
2772 DataType::QAsymmS8,
2773 DataType::QAsymmU8,
2774 DataType::QSymmS16,
2775 DataType::Signed32,
2776 };
2777
2778 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2779
2780 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2781
2782 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2783 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2784}
2785
narpra01b89b05f2019-01-16 09:53:09 +00002786void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2787{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002788 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002790 ValidateNumInputs(workloadInfo, descriptorName, 2);
2791 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002792
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002793 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2794 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002795 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002796 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002797 }
2798
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002799 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2800 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2801
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002802 std::vector<DataType> supportedTypes =
2803 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002804 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002805 DataType::Float16,
2806 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002807 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002808 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002809 DataType::QSymmS16,
2810 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002811 };
2812
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002813 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002814
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002815 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002816
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002817 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2818 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002819}
2820
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002821void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2822{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002823 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2824
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002825 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002826
2827 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2828 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002829 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002830 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2831 }
2832
2833 if (m_Anchors == nullptr)
2834 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002835 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002836 }
2837
2838 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002839 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2840 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2841
2842 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002843 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002844 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2845 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002846
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002847 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2848 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2849 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002850
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002851 const std::vector<DataType> supportedInputTypes =
2852 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002853 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002854 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002855 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002856 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002857 DataType::QAsymmU8,
2858 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002859 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002860
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002861 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2862 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2863 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2864
2865 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2866 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2867 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2868 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2869
2870 // NOTE: Output is always Float32 regardless of input type
2871 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2872 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2873 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2874 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002875
2876 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2877 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002878 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002879 "must be positive and less than or equal to 1.");
2880 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002881
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002882 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2883 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002884 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002885 "should be equal to number of classes + 1.");
2886 }
2887}
2888
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002889void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2890{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002891 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002892
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002893 ValidateNumInputs(workloadInfo, descriptorName, 1);
2894 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2895
2896 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2897 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2898
Teresa Charlin07307f32022-05-15 14:07:05 +01002899 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002900 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002901 DataType::QAsymmS8,
2902 DataType::QAsymmU8,
2903 DataType::QSymmS8,
2904 DataType::QSymmS16,
2905 DataType::Float16
2906 };
2907 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002908
Teresa Charlin07307f32022-05-15 14:07:05 +01002909 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002910 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002911 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002912 DataType::Float32,
2913 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002914 };
2915
Teresa Charlin07307f32022-05-15 14:07:05 +01002916 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002917}
2918
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002919void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2920{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002921 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002922
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002923 ValidateNumInputs(workloadInfo, descriptorName, 2);
2924 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002925
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002926 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2927 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2928 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002929
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002930 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2931 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2932
2933 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2934 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002935}
2936
Keith Davis3ae3f972021-05-21 16:33:48 +01002937void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2938{
2939 const std::string& descriptorName{"ShapeQueueDescriptor"};
2940
2941 ValidateNumInputs(workloadInfo, descriptorName, 1);
2942 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2943
2944 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2945 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2946
2947 std::vector<DataType> supportedTypes =
2948 {
2949 DataType::BFloat16,
2950 DataType::Float16,
2951 DataType::Float32,
2952 DataType::QAsymmS8,
2953 DataType::QAsymmU8,
Keith Davis3ae3f972021-05-21 16:33:48 +01002954 DataType::QSymmS8,
2955 DataType::QSymmS16,
2956 DataType::Signed32
2957 };
2958
2959 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2960 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2961}
2962
Sadik Armaganeff363d2019-04-05 15:25:46 +01002963void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2964{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002965 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002966
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002967 ValidateNumInputs(workloadInfo, descriptorName, 2);
2968 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2969
2970 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2971 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2972
2973 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2974 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2975
2976 std::vector<DataType> supportedTypes =
2977 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002978 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002979 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002980 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002981 DataType::QAsymmU8,
2982 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002983 };
2984
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002985 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2986 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002987
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002988 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2989 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002990
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002991 ValidateTensorShapesMatch(inputTensorInfo0,
2992 outputTensorInfo0,
2993 descriptorName,
2994 "input_0",
2995 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002996
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002997 ValidateTensorShapesMatch(inputTensorInfo0,
2998 outputTensorInfo1,
2999 descriptorName,
3000 "input_0",
3001 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003002}
3003
Derek Lamberti901ea112019-12-10 22:07:09 +00003004void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00003005{
3006 // This is internally generated so it should not need validation.
3007}
3008
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003009void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3010{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003011 const std::string& descriptorName{"PreluQueueDescriptor"};
3012
3013 ValidateNumInputs(workloadInfo, descriptorName, 2);
3014 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3015
3016 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3017 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
3018 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003019
3020 std::vector<DataType> supportedTypes
3021 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003022 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003023 DataType::Float16,
3024 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003025 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003026 DataType::QAsymmU8,
3027 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003028 };
3029
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003030 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3031 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003032
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003033 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003034
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003035 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
3036 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003037
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003038 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3039 alphaTensorInfo,
3040 outputTensorInfo,
3041 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003042 "input",
3043 "alpha");
3044}
3045
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003046void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3047{
3048 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3049
3050 ValidateNumInputs(workloadInfo, descriptorName, 1);
3051 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3052
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003053 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3054 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3055
3056 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3057 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003058
3059 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003060
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003061 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3062 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003063
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003064 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3065
3066 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003067 if (m_Parameters.m_BiasEnabled)
3068 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003069 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003070
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003071 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3072 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003073
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003074 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01003075 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003076 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003077
3078 ValidatePerAxisQuantization(inputTensorInfo,
3079 outputTensorInfo,
3080 weightTensorInfo,
3081 optionalBiasTensorInfo,
3082 descriptorName);
3083
3084 std::vector<DataType> supportedTypes =
3085 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003086 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003087 DataType::Float32,
3088 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003089 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003090 DataType::QAsymmU8,
3091 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003092 };
3093
3094 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3095 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003096}
3097
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003098void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3099{
3100 const std::string descriptorName{"TransposeQueueDescriptor"};
3101
3102 ValidateNumInputs(workloadInfo, descriptorName, 1);
3103 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3104
3105 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3106
3107 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3108 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3109
3110 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3111 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3112
3113 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3114 {
3115 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3116 {
3117 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3118 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3119 "must match dst dimension " + to_string(i) +
3120 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3121 }
3122 }
3123
3124 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3125}
3126
Simon Obute51f67772021-09-03 15:50:13 +01003127void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3128{
3129 const std::string descriptorName{"TransposeQueueDescriptor"};
3130
3131 ValidateNumInputs(workloadInfo, descriptorName, 1);
3132 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3133
3134 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3135 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3136
3137 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3138}
3139
James Conroy4f1f8992020-04-29 20:01:10 +01003140void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3141{
3142 const std::string descriptorName{"QLstmQueueDescriptor"};
3143
3144 // Validate number of inputs/outputs
3145 ValidateNumInputs(workloadInfo, descriptorName, 3);
3146 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3147
3148 // Input/output tensor info
3149 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3150 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3151 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3152
3153 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3154 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3155 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3156
3157 // Supported types for various tensors in QLSTM
3158 std::vector<DataType> inputOutputSupportedTypes =
3159 {
3160 DataType::QAsymmS8
3161 };
3162
3163 std::vector<DataType> cellStateSupportedTypes =
3164 {
3165 DataType::QSymmS16
3166 };
3167
3168 std::vector<DataType> weightsSupportedTypes =
3169 {
3170 DataType::QSymmS8
3171 };
3172
3173 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3174 {
3175 DataType::QSymmS16
3176 };
3177
3178 std::vector<DataType> biasSupportedTypes =
3179 {
3180 DataType::Signed32
3181 };
3182
3183 // Validate types of input/output tensors
3184 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3185 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3186 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3187
3188 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3189 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3190 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3191
3192 // Validate matching types of input/output tensors
3193 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3194 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3195 "outputStateIn", "outputStateOut");
3196 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3197
3198 // Infer number of batches, number of units, input size and output size from tensor dimensions
3199 const uint32_t numBatches = inputInfo.GetShape()[0];
3200 const uint32_t inputSize = inputInfo.GetShape()[1];
3201 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3202 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3203
3204 // Validate number of dimensions and number of elements for input/output tensors
3205 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3206 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3207 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3208
3209 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3210 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3211 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3212
3213 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3214 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3215 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3216 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3217
3218 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3219 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3220 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3221
3222 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3223 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3224 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3225
3226 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3227 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3228 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3229 " RecurrentToForgetWeights");
3230
3231 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3232 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3233 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3234
3235 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3236 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3237 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3238
3239 // Validate data types for MANDATORY weights tensors (all should match each other)
3240 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3241
3242 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3243 "inputToForgetWeights", "inputToCellWeights");
3244 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3245 "inputToForgetWeights", "inputToOutputWeights");
3246
3247 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3248 "inputToForgetWeights", "recurrentToForgeteights");
3249 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3250 "inputToForgetWeights", "recurrentToCellWeights");
3251 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3252 "inputToForgetWeights", "recurrentToOutputWeights");
3253
3254 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3255 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3256 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3257 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3258
3259 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3260 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3261 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3262
3263 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3264 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3265 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3266
3267 // Validate data types for MANDATORY bias tensors
3268 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3269
3270 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3271 "forgetGateBias", "cellBias");
3272 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3273 "forgetGateBias", "outputGateBias");
3274
3275 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3276 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3277 !m_Parameters.m_CifgEnabled) ||
3278 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3279 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3280
3281 if (!allCifgParamsPresentOrNot)
3282 {
3283 throw InvalidArgumentException(descriptorName +
3284 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3285 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3286 "set appropriately.");
3287 }
3288
3289 if (!m_Parameters.m_CifgEnabled)
3290 {
3291 // Validate number of dimensions and number of elements
3292 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3293 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3294
3295 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3296 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3297 " RecurrentToInputWeights");
3298
3299 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3300 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3301
3302 // Validate data types
3303 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3304 "inputToForgetWeights", "inputToInputWeights");
3305 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3306 "inputToForgetWeights", "recurrentToInputWeights");
3307 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3308 "forgetGateBias", "inputGateBias");
3309 }
3310
3311 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3312 bool allPeepholeWeightsPresentOrNot =
3313 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3314 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3315 || (!m_CellToInputWeights && !m_CellToForgetWeights
3316 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3317
3318 if (!allPeepholeWeightsPresentOrNot)
3319 {
3320 throw InvalidArgumentException(descriptorName +
3321 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3322 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3323 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3324 "appropriately.");
3325 }
3326
3327 if (m_Parameters.m_PeepholeEnabled)
3328 {
3329 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3330 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3331 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3332
3333 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3334 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3335 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3336 "cellToForgetWeight", "cellToOutputWeights");
3337
3338 if (!m_Parameters.m_CifgEnabled)
3339 {
3340 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3341 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3342 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3343 "cellToForgetWeights", "cellToInputWeights");
3344 }
3345 }
3346
3347 // Validate OPTIONAL params: Layer Norm Weights
3348 bool allLayerNormWeightsPresentOrNot =
3349 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3350 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3351 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3352 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3353
3354 if (!allLayerNormWeightsPresentOrNot)
3355 {
3356 throw InvalidArgumentException(descriptorName +
3357 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3358 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3359 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3360 "only be present when Layer Norm is enabled and CIFG is disabled. "
3361 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3362 }
3363
3364 if (m_Parameters.m_LayerNormEnabled)
3365 {
3366 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3367 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3368 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3369
3370 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3371 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3372 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3373 "forgetLayerNormWeights", "cellLayerNormWeights");
3374
3375 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3376 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3377 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3378 "forgetLayerNormWeights", "outputLayerNormWeights");
3379
3380 if (!m_Parameters.m_CifgEnabled)
3381 {
3382 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3383 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3384 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3385 "forgetLayerNormWeights", "inputLayerNormWeights");
3386 }
3387 }
3388
3389 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3390 bool correctProjectionTensorsPresent =
3391 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3392 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3393 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3394
3395 if (!correctProjectionTensorsPresent)
3396 {
3397 throw InvalidArgumentException(descriptorName +
3398 ": If projection is enabled, ProjectionWeights should be present and "
3399 "ProjectionBias is optional. If projection is disabled, neither "
3400 "ProjectionWeights nor ProjectionBias should be present.");
3401 }
3402
3403 if (m_Parameters.m_ProjectionEnabled)
3404 {
3405 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3406 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3407 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3408
3409 if (m_ProjectionBias)
3410 {
3411 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003412 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003413 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3414 }
3415
3416 }
3417 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3418 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3419 throw InvalidArgumentException(descriptorName +
3420 ": If projection is disabled, output quantization info (scale, offset) "
3421 "should match HiddenStateScale and HiddenStateZeroPoint.");
3422 }
3423
3424}
3425
James Conroy9c3cae82019-08-01 16:01:48 +01003426void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3427{
3428 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3429
3430 // Validate number of inputs/outputs
3431 ValidateNumInputs(workloadInfo, descriptorName, 3);
3432 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3433
3434 // Input/output tensor infos
3435 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3436 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3437 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3438
3439 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3440 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3441
3442 std::vector<DataType> inputOutputSupportedTypes =
3443 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003444 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003445 };
3446
3447 std::vector<DataType> cellStateSupportedTypes =
3448 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003449 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003450 };
3451
3452 std::vector<DataType> weightsSupportedTypes =
3453 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003454 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003455 };
3456
3457 std::vector<DataType> biasSupportedTypes =
3458 {
3459 DataType::Signed32
3460 };
3461
3462 // Validate types of input/output tensors
3463 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3464 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3465 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3466
3467 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3468 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3469
3470 // Validate matching types of input/output tensors
3471 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3472 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3473 "outputStateIn", "outputStateOut");
3474 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3475
3476 // Validate matching quantization info for input/output tensors
3477 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3478 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3479 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003480
James Conroy9c3cae82019-08-01 16:01:48 +01003481 // Infer number of batches, input size and output size from tensor dimensions
3482 const uint32_t numBatches = inputInfo.GetShape()[0];
3483 const uint32_t inputSize = inputInfo.GetShape()[1];
3484 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3485
3486 // Validate number of dimensions and number of elements for input/output tensors
3487 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3488 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3489 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3490 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3491 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3492
3493 // Validate number of dimensions and number of elements for weights tensors
3494 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3495 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3496 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3497
3498 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3499 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3500 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3501
3502 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3503 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3504 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3505
3506 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3507 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3508 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3509
3510 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3511 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3512 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3513
3514 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3515 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3516 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3517 " RecurrentToForgetWeights");
3518
3519 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3520 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3521 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3522
3523 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3524 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3525 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3526
3527 // Validate data types for weights tensors (all should match each other)
3528 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3529
3530 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3531 "inputToInputWeights", "inputToForgetWeights");
3532 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3533 "inputToInputWeights", "inputToCellWeights");
3534 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3535 "inputToInputWeights", "inputToOutputWeights");
3536
3537 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3538 "inputToInputWeights", "recurrentToInputWeights");
3539 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3540 "inputToInputWeights", "recurrentToForgeteights");
3541 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3542 "inputToInputWeights", "recurrentToCellWeights");
3543 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3544 "inputToInputWeights", "recurrentToOutputWeights");
3545
3546 // Validate matching quantization info for weight tensors (all should match each other)
3547 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3548 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3549 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3550 descriptorName, "inputToInputWeights", "inputToCellWeights");
3551 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3552 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3553
3554 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3555 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3556 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3557 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3558 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3559 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3560 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3561 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3562
3563 // Validate number of dimensions and number of elements in bias tensors
3564 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3565 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3566 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3567
3568 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3569 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3570 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3571
3572 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3573 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3574 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3575
3576 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3577 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3578 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3579
3580 // Validate data types for bias tensors (all should match each other)
3581 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3582
3583 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3584 "inputGateBias", "forgetGateBias");
3585 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3586 "inputGateBias", "cellBias");
3587 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3588 "inputGateBias", "outputGateBias");
3589
3590 // Validate bias tensor quantization info
Ryan OSheaf183acd2023-07-06 11:41:25 +01003591 ValidateBiasTensorQuantization(inputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3592 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3593 ValidateBiasTensorQuantization(cellBiasInfo, inputToInputWeightsInfo, descriptorName);
3594 ValidateBiasTensorQuantization(outputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
James Conroy9c3cae82019-08-01 16:01:48 +01003595}
3596
Kevin May868eb142019-09-04 17:29:31 +01003597void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3598{
3599 const std::string descriptorName{"AbsQueueDescriptor"};
3600
3601 ValidateNumInputs(workloadInfo, descriptorName, 1);
3602 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3603
3604 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3605 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3606
3607 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3608
3609 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003610 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003611 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003612 DataType::Float16,
3613 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003614 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003615 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003616 DataType::QSymmS16,
3617 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003618 };
Kevin May868eb142019-09-04 17:29:31 +01003619
3620 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3621 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3622}
3623
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003624void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3625{
3626 const std::string descriptorName{"SliceQueueDescriptor"};
3627
3628 ValidateNumInputs(workloadInfo, descriptorName, 1);
3629 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3630
3631 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3632 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3633
3634 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3635
3636 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3637 if (rank > 4)
3638 {
3639 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3640 }
3641
3642 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3643
3644 // Check if m_Begin and m_Size have the expected length
3645 if (m_Parameters.m_Begin.size() != rank)
3646 {
3647 throw InvalidArgumentException(descriptorName +
3648 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3649 }
3650 if (m_Parameters.m_Size.size() != rank)
3651 {
3652 throw InvalidArgumentException(descriptorName +
3653 ": Length of size descriptor must equal rank " + std::to_string(rank));
3654 }
3655
3656 // Check if the shape of the output tensor matches m_Size
3657 const TensorShape& outputShape = outputTensorInfo.GetShape();
3658 for (unsigned int i = 0u; i < rank; ++i)
3659 {
3660 if (m_Parameters.m_Size[i] != outputShape[i])
3661 {
3662 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3663 }
3664 }
3665
3666 // Check if the sum of begin offset and size in a given dimension
3667 // does not exceed the size of corresponding input
3668 const TensorShape& inputShape = inputTensorInfo.GetShape();
3669 for(unsigned int i = 0u; i < rank; ++i)
3670 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003671 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003672 {
3673 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3674 std::to_string(i) + " exceeds input size.");
3675 }
3676 }
3677}
3678
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003679void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3680{
3681 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3682
3683 ValidateNumInputs(workloadInfo, descriptorName, 1);
3684 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3685
3686 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3687 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3688
3689 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3690 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3691
3692 std::vector<DataType> supportedTypes =
3693 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003694 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003695 DataType::Float32,
3696 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003697 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003698 DataType::QAsymmU8,
3699 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003700 };
3701
3702 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3703 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3704
3705 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3706
3707 if (m_Parameters.m_BlockSize == 0)
3708 {
3709 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3710 }
3711
3712 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3713 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3714 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3715 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3716
3717 const TensorShape& outputShape = outputInfo.GetShape();
3718 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3719 {
3720 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3721 "must be divisible by block size.");
3722 }
3723
3724 const TensorShape& inputShape = inputInfo.GetShape();
3725 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3726 {
3727 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3728 "must be divisible by the square of block size." );
3729 }
3730}
3731
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003732void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3733{
3734 const std::string descriptorName{"ComparisonQueueDescriptor"};
3735
3736 ValidateNumInputs(workloadInfo, descriptorName, 2);
3737 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3738
3739 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3740 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3741 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3742
3743 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3744 inputTensorInfo1,
3745 outputTensorInfo,
3746 descriptorName,
3747 "input_0",
3748 "input_1");
3749
3750 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3751 {
3752 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3753 }
3754}
3755
Mike Kelly3ec30772023-03-08 13:47:17 +00003756void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3757{
3758 const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3759
3760 ValidateNumInputs(workloadInfo, descriptorName, 2);
3761 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3762
3763 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3764 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3765 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3766
3767 std::vector<DataType> supportedTypes =
3768 {
3769 DataType::BFloat16,
3770 DataType::Float16,
3771 DataType::Float32,
3772 DataType::QAsymmS8,
3773 DataType::QAsymmU8,
3774 DataType::QSymmS16,
3775 DataType::Signed32
3776 };
3777
3778 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3779 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3780
3781 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3782 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3783}
3784
josh minor4a3c6102020-01-06 16:40:46 -06003785void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3786{
3787 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3788
3789 ValidateNumInputs(workloadInfo, descriptorName, 1);
3790 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3791
3792 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3793 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3794
3795 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3796
3797 std::vector<DataType> supportedTypes =
3798 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003799 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003800 DataType::Float16,
3801 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003802 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003803 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003804 DataType::QSymmS16,
3805 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003806 };
3807
James Conroyaba90cd2020-11-06 16:28:18 +00003808 std::vector<DataType> logicalSupportedTypes =
3809 {
3810 DataType::Boolean
3811 };
3812
3813 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3814 {
3815 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3816 }
3817 else
3818 {
3819 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3820 }
3821
3822
josh minor4a3c6102020-01-06 16:40:46 -06003823 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3824}
3825
Finn Williams2605b232020-06-10 15:53:46 +01003826void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3827{
3828 const std::string descriptorName{"RankQueueDescriptor"};
3829
3830 ValidateNumInputs(workloadInfo, descriptorName, 1);
3831 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3832
3833 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3834 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3835
3836 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3837 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3838
3839 std::vector<DataType> supportedTypes =
3840 {
3841 DataType::BFloat16,
3842 DataType::Float16,
3843 DataType::Float32,
3844 DataType::QAsymmS8,
3845 DataType::QAsymmU8,
3846 DataType::QSymmS8,
3847 DataType::QSymmS16,
3848 DataType::Signed32
3849 };
3850
3851 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3852 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3853}
3854
James Conroyaba90cd2020-11-06 16:28:18 +00003855void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3856{
3857 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3858
3859 ValidateNumInputs(workloadInfo, descriptorName, 2);
3860 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3861
3862 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3863 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3864 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3865
3866 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3867 inputTensorInfo1,
3868 outputTensorInfo,
3869 descriptorName,
3870 "input_0",
3871 "input_1");
3872
3873 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3874 {
3875 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3876 }
3877
3878 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3879 {
3880 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3881 }
3882
3883 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3884 {
3885 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3886 }
3887}
3888
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003889void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3890{
3891 const std::string descriptorName{"ReduceQueueDescriptor"};
3892
3893 ValidateNumInputs(workloadInfo, descriptorName, 1);
3894 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3895
3896 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3897 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3898
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003899 std::vector<DataType> supportedTypes =
3900 {
3901 DataType::BFloat16,
3902 DataType::Float16,
3903 DataType::Float32,
3904 DataType::QAsymmS8,
3905 DataType::QAsymmU8,
3906 DataType::QSymmS16,
3907 DataType::Signed32
3908 };
3909
3910 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3911 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3912}
3913
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003914void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3915{
3916 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3917
3918 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3919
3920 // check dimensions of all inputs and outputs
3921 if (workloadInfo.m_InputTensorInfos.size() != 3)
3922 {
3923 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3924 }
Mike Kelly12994962022-04-21 11:57:09 +01003925 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003926 {
3927 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3928 }
3929
3930 std::vector<DataType> supportedTypes =
3931 {
Mike Kelly12994962022-04-21 11:57:09 +01003932 DataType::Float32,
3933 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003934 };
3935
3936 // check for supported type of one input and match them with all the other input and output
3937 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3938
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003939 // Making sure clipping parameters have valid values.
3940 // == 0 means no clipping
3941 // > 0 means clipping
3942 if (m_Parameters.m_ClippingThresCell < 0.0f)
3943 {
3944 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3945 }
3946 if (m_Parameters.m_ClippingThresProj < 0.0f)
3947 {
3948 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3949 }
3950
3951 unsigned int batchIndx = 0;
3952 unsigned int inputIndx = 1;
3953 uint32_t timeStep = 1;
3954 unsigned int timeIndx = 1;
3955 inputIndx = 2;
3956 if (m_Parameters.m_TimeMajor)
3957 {
3958 batchIndx = 1;
3959 timeIndx = 0;
3960
3961 }
3962 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3963
3964 // Inferring batch size, number of outputs and number of cells from the inputs.
3965 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3966 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3967 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3968 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3969 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3970 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3971
3972 // input tensor
3973 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3974 descriptorName + " input_0");
3975 // outputStateInTensor
3976 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3977 descriptorName + " input_1");
3978 // outputStateInTensor
3979 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3980 descriptorName + " input_2");
3981
3982 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003983 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003984 descriptorName + " output_0");
3985
3986 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3987 if ( m_InputToInputWeights )
3988 {
3989 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3990 (n_cell * n_input), "InputLayerNormWeights");
3991 }
3992
3993 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3994 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3995 (n_cell * n_input), "InputToForgetWeights");
3996
3997 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3998 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3999 (n_cell * n_input), "InputToCellWeights");
4000
4001 if ( m_RecurrentToInputWeights )
4002 {
4003 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
4004 (n_cell * n_output), "RecurrentToInputWeights");
4005 }
4006
4007 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
4008 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
4009 (n_cell * n_output), "RecurrentToForgetWeights");
4010
4011 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
4012 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
4013 (n_cell * n_output), "RecurrentToCellWeights");
4014
4015 // Make sure the input-gate's parameters are either both present (regular
4016 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
4017 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
4018 !m_Parameters.m_CifgEnabled) ||
4019 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
4020 m_Parameters.m_CifgEnabled));
4021 if (!cifg_weights_all_or_none)
4022 {
4023 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
4024 "RecurrentToInputWeights must either both be present (regular LSTM) "
4025 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
4026 "accordingly.");
4027 }
4028
4029 if ( m_CellToInputWeights )
4030 {
4031 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
4032 n_cell, "CellToInputWeights");
4033 }
4034 if ( m_CellToForgetWeights )
4035 {
4036 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
4037 n_cell, "CellToForgetWeights");
4038 }
4039 if ( m_CellToOutputWeights )
4040 {
4041 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
4042 n_cell, "CellToOutputWeights");
4043 }
4044
4045 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
4046 bool peephole_weights_all_or_none =
4047 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
4048 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
4049 || ( !m_CellToInputWeights && !m_CellToForgetWeights
4050 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
4051 if (!peephole_weights_all_or_none)
4052 {
4053 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
4054 }
4055
4056 // Make sure the input gate bias is present only when not a CIFG-LSTM.
4057 if (m_Parameters.m_CifgEnabled)
4058 {
4059 if (m_InputGateBias)
4060 {
4061 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
4062 }
4063 }
4064 else
4065 {
4066 if (!m_InputGateBias)
4067 {
4068 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4069 "must be present.");
4070 }
4071 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4072 n_cell, "InputGateBias");
4073 }
4074
4075 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4076 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4077
4078 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4079 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4080
4081 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4082 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4083
4084 if (m_ProjectionWeights)
4085 {
4086 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4087 (n_cell * n_output), "ProjectionWeights");
4088 }
4089 if (m_ProjectionBias)
4090 {
4091 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4092 }
4093
4094 // Making sure the projection tensors are consistent:
4095 // 1) If projection weight is not present, then projection bias should not be
4096 // present.
4097 // 2) If projection weight is present, then projection bias is optional.
4098 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4099 !m_Parameters.m_ProjectionEnabled)
4100 || (m_ProjectionWeights && !m_ProjectionBias &&
4101 m_Parameters.m_ProjectionEnabled)
4102 || (m_ProjectionWeights && m_ProjectionBias &&
4103 m_Parameters.m_ProjectionEnabled));
4104 if (!projecton_tensors_consistent)
4105 {
4106 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4107 }
4108
4109 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4110 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4111 // either all have values or none of them have values. Layer normalization is used when the values of all the
4112 // layer normalization weights are present
4113 if (m_InputLayerNormWeights)
4114 {
4115 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4116 }
4117 if (m_ForgetLayerNormWeights)
4118 {
4119 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4120 }
4121 if (m_CellLayerNormWeights)
4122 {
4123 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4124 }
4125 if (m_OutputLayerNormWeights)
4126 {
4127 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4128 }
4129
4130 if (m_Parameters.m_LayerNormEnabled)
4131 {
4132 if (!m_Parameters.m_CifgEnabled)
4133 {
4134 if (!m_InputLayerNormWeights)
4135 {
4136 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4137 "disabled but InputLayerNormWeights are not present");
4138 }
4139 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4140 1, n_cell, "InputLayerNormWeights");
4141 }
4142 else if (m_InputLayerNormWeights)
4143 {
4144 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4145 "enabled");
4146 }
4147
4148 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4149 "ForgetLayerNormWeights");
4150 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4151
4152 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4153 "OutputLayerNormWeights");
4154 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4155
4156 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4157 "CellLayerNormWeights");
4158 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4159 }
4160 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4161 {
4162 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4163 "normalisation weights are present.");
4164 }
4165}
4166
Samuel Yap6b478092022-07-06 15:36:03 +01004167void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4168{
4169 const std::string descriptorName{"BatchMatMulDescriptor"};
4170
4171 ValidateNumInputs(workloadInfo, descriptorName, 2);
4172 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4173
4174 // Inputs must be: both 2D+
4175 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4176 // axes N and I must be the same size
4177
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004178 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4179 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4180 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4181 // Output info has already been inferred
Samuel Yap6b478092022-07-06 15:36:03 +01004182
4183 std::vector<DataType> supportedTypes =
4184 {
4185 DataType::BFloat16,
4186 DataType::Float16,
4187 DataType::Float32,
4188 DataType::QAsymmS8,
4189 DataType::QAsymmU8,
4190 DataType::QSymmS16
4191 };
4192
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004193 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4194 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4195 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
Samuel Yap6b478092022-07-06 15:36:03 +01004196
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004197 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4198 (inputYInfoBeforeParams.GetNumDimensions() < 2))
Samuel Yap6b478092022-07-06 15:36:03 +01004199 {
4200 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4201 }
4202
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004203 TensorInfo inputXInfoAfterParams;
4204 TensorInfo inputYInfoAfterParams;
4205
4206 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4207 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
Samuel Yap6b478092022-07-06 15:36:03 +01004208 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004209 throw InvalidArgumentException(descriptorName +
4210 ": Invalid descriptor parameters - Transpose and Adjoint "
4211 "cannot both be true for a given input tensor.");
4212 }
4213 if(m_Parameters.m_TransposeX)
4214 {
4215 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4216 BatchMatMulDescriptor::GetPermuteVec(
4217 m_Parameters.m_DataLayoutX,
4218 inputXInfoBeforeParams.GetShape()));
4219 }
4220 else if(m_Parameters.m_AdjointX)
4221 {
4222 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4223 inputXInfoBeforeParams.GetShape());
4224 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4225 inputXInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004226 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004227 throw InvalidArgumentException(descriptorName +
4228 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004229 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004230 // Shape remains the same as it's square
4231 inputXInfoAfterParams = inputXInfoBeforeParams;
4232 }
4233 else
4234 {
4235 inputXInfoAfterParams = inputXInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004236 }
4237
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004238 if(m_Parameters.m_TransposeY)
Samuel Yap6b478092022-07-06 15:36:03 +01004239 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004240 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4241 BatchMatMulDescriptor::GetPermuteVec(
4242 m_Parameters.m_DataLayoutY,
4243 inputYInfoBeforeParams.GetShape()));
4244 }
4245 else if(m_Parameters.m_AdjointY)
4246 {
4247 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4248 inputYInfoBeforeParams.GetShape());
4249 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4250 inputYInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004251 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004252 throw InvalidArgumentException(descriptorName +
4253 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004254 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004255 // Shape remains the same as it's square
4256 inputYInfoAfterParams = inputYInfoBeforeParams;
4257 }
4258 else
4259 {
4260 inputYInfoAfterParams = inputYInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004261 }
4262
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004263 switch(m_Parameters.m_DataLayoutX)
4264 {
4265 case DataLayout::NCDHW:
4266 case DataLayout::NDHWC:
4267 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4268 {
4269 throw InvalidArgumentException(descriptorName +
4270 ": Input tensor X does not have the correct "
4271 "number of dimensions for the Data Layout that it has been assigned.");
4272 }
4273 break;
4274 case DataLayout::NCHW:
4275 case DataLayout::NHWC:
4276 default:
4277 break;
4278 }
Samuel Yap6b478092022-07-06 15:36:03 +01004279
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004280 switch(m_Parameters.m_DataLayoutY)
4281 {
4282 case DataLayout::NCDHW:
4283 case DataLayout::NDHWC:
4284 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4285 {
4286 throw InvalidArgumentException(descriptorName +
4287 ": Input tensor Y does not have the correct "
4288 "number of dimensions for the Data Layout that it has been assigned.");
4289 }
4290 break;
4291 case DataLayout::NCHW:
4292 case DataLayout::NHWC:
4293 default:
4294 break;
4295 }
4296
4297 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4298 inputXInfoAfterParams.GetShape());
4299 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4300 inputXInfoBeforeParams.GetShape());
4301
4302 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4303 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
Samuel Yap6b478092022-07-06 15:36:03 +01004304 {
4305 throw InvalidArgumentException(descriptorName +
4306 ": The final axis of input tensor X must be the same size as "
4307 "the second last axis of input tensor Y.");
4308 }
4309
Samuel Yap6b478092022-07-06 15:36:03 +01004310 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4311 // e.g. NHWC isnt compatible with NCHW as of now
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004312 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4313 DataLayout yLayout = m_Parameters.m_DataLayoutY;
Samuel Yap6b478092022-07-06 15:36:03 +01004314
4315 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4316 {
4317 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4318 {
4319 throw InvalidArgumentException(descriptorName +
4320 ": Invalid input tensor data layout combination.");
4321 }
4322 }
4323 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4324 {
4325 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4326 {
4327 throw InvalidArgumentException(descriptorName +
4328 ": Invalid input tensor data layout combination.");
4329 }
4330 }
4331 }
4332
4333 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004334 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4335 inputYInfoAfterParams.GetNumDimensions());
Samuel Yap6b478092022-07-06 15:36:03 +01004336 if(outputTensorDimSize-2 > 0)
4337 {
4338 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4339 DataType::Float32);
4340 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4341 DataType::Float32);
4342 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4343 DataType::Float32);
4344
4345 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4346 {
4347 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4348
4349 for(unsigned int i = 0; i < sizeDiff; i++)
4350 {
4351 axisIndices.insert(axisIndices.begin(), 1);
4352 }
4353
4354 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4355 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004356 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
Samuel Yap6b478092022-07-06 15:36:03 +01004357 }
4358 };
4359
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004360 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4361 inputXInfoAfterParams.GetShape());
4362 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4363 inputYInfoAfterParams.GetShape());
4364
4365 doAxisExtension(axesXNotMul, tiXNotMul);
4366 doAxisExtension(axesYNotMul, tiYNotMul);
Samuel Yap6b478092022-07-06 15:36:03 +01004367
4368 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4369 {
4370 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4371 tiYNotMul.GetShape()[i]);
4372 }
4373
4374 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4375 tiYNotMul,
4376 tiOutNotMul,
4377 descriptorName,
4378 "input_X",
4379 "input_Y");
4380 }
Samuel Yap6b478092022-07-06 15:36:03 +01004381}
4382
Teresa Charlin79a06a52023-07-13 17:16:45 +01004383void TileQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4384{
4385 const std::string& descriptorName{"TileQueueDescriptor"};
4386
4387 ValidateNumInputs(workloadInfo, descriptorName, 1);
4388 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4389
4390 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
4391 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4392
4393 std::vector<DataType> supportedTypes =
4394 {
4395 DataType::Float32,
4396 DataType::Float16,
4397 DataType::QAsymmS8,
4398 DataType::QAsymmU8,
4399 DataType::QSymmS8,
4400 DataType::QSymmS16,
4401 DataType::Signed32
4402 };
4403
4404 // Multiples length must be the same as the number of dimensions in input.
4405 if (m_Parameters.m_Multiples.size() != inputTensorInfo.GetNumDimensions())
4406 {
4407 throw InvalidArgumentException(descriptorName +
4408 ": Multiples length is not same as the number of dimensions in Input.");
4409 }
4410
4411 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
4412 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4413}
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01004414
mathad01df9a3222021-04-28 11:42:57 +01004415} // namespace armnn