//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <armnn/backends/TensorHandle.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadInfo.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <armnnUtils/TensorUtils.hpp>
#include <armnnUtils/Permute.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/Logging.hpp>

#include <algorithm>
#include <iomanip>
#include <string>
#include <sstream>

#include <fmt/format.h>

using namespace armnnUtils;

namespace armnn
{

//---------------------------------------------------------------
DataType GetBiasDataType(DataType inputDataType)
{
    switch (inputDataType)
    {
        case DataType::Float16:
            return DataType::Float16;
        case DataType::BFloat16:
        case DataType::Float32:
            return DataType::Float32;
        case DataType::QAsymmS8:
        case DataType::QAsymmU8:
        case DataType::QSymmS8:
        case DataType::QSymmS16:
            return DataType::Signed32;
        default:
            ARMNN_ASSERT_MSG(false, "Invalid input data type");
            return DataType::Float32;
    }
}
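// Illustrative note on the mapping above: quantized 8- and 16-bit input types map to
// a 32-bit bias type, so for example GetBiasDataType(DataType::QAsymmU8) returns
// DataType::Signed32, while GetBiasDataType(DataType::Float16) returns DataType::Float16.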

namespace
{

//---------------------------------------------------------------
// The Android NDK does not support the std::to_string function.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream os;
    os << value;
    return os.str();
}

//---------------------------------------------------------------
void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
{
    if (!ptr)
    {
        throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
                                       paramName + " parameter must be set.");
    }
}

//---------------------------------------------------------------
void ValidateTensorShapesMatch(const TensorInfo& first,
                               const TensorInfo& second,
                               std::string const& descName,
                               std::string const& firstName,
                               std::string const& secondName)
{
    if (first.GetShape() != second.GetShape())
    {
        throw InvalidArgumentException(descName + ": "
                                       + firstName + " & " + secondName + " must have identical shapes");
    }
}

//---------------------------------------------------------------
void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
{
    if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
    {
        throw InvalidArgumentException(descName +
                                       ": Requires exactly " + to_string(expectedSize) + " input(s). " +
                                       to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
    }
}

//---------------------------------------------------------------
void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
{
    if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
    {
        throw InvalidArgumentException(descName +
                                       ": Requires exactly " + to_string(expectedSize) + " output(s). " +
                                       to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
    }
}

//---------------------------------------------------------------

//---------------------------------------------------------------
void ValidateTensorNumElements(const TensorInfo& tensor,
                               std::string const& descName,
                               unsigned int numElements,
                               std::string const& tensorName)
{
    if (tensor.GetNumElements() != numElements)
    {
        throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
                                       to_string(tensor.GetNumElements()) + " elements for " +
                                       tensorName + " tensor.");
    }
}

//---------------------------------------------------------------
void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
    const std::string& descName, std::string const& tensorName)
{
    if (tensor.GetDataType() != dataType)
    {
        throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
            GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
    }
}

void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
{
    if (tensor.GetDataType() != DataType::QSymmS8)
    {
        throw InvalidArgumentException(descName +
            ": Expected data type which supports per-axis quantization scheme but got " +
            GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
    }
}

//---------------------------------------------------------------
void ValidateTensorQuantizationSpace(const TensorInfo& first,
                                     const TensorInfo& second,
                                     const std::string& descName,
                                     std::string const& firstName,
                                     std::string const& secondName)
{
    if (!first.IsQuantized() ||
        !second.IsQuantized())
    {
        // Not a quantized type, ignore the validation
        return;
    }

    DataType firstDataType = first.GetDataType();
    DataType secondDataType = second.GetDataType();

    if (firstDataType != secondDataType)
    {
        throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
                                       " must be of the same quantized type, " +
                                       firstName + " is " + GetDataTypeName(firstDataType) + ", " +
                                       secondName + " is " + GetDataTypeName(secondDataType));
    }

    if (!first.IsTypeSpaceMatch(second))
    {
        throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
                                       " must have the same quantization space, " +
                                       firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
                                       " and scale " + to_string(first.GetQuantizationScale()) + ", " +
                                       secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
                                       " and scale " + to_string(second.GetQuantizationScale()));
    }
}

//---------------------------------------------------------------
void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
                                    const TensorInfo& inputTensorInfo,
                                    const TensorInfo& weightsTensorInfo,
                                    const std::string& descName)
{
    // Helper lambda function to validate a single bias quantization scale value
    auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
    {
        constexpr float tolerance = 0.0001f;
        if (std::abs(biasScale - expectedScale) > tolerance)
        {
            // Print the float values with extra precision to see very small differences
            ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
                " for bias quantization scale (product of input and weight scales), but got " <<
                biasScale << ". Using scale provided.";
        }
    };

    if (biasTensor.GetQuantizationOffset() != 0)
    {
        throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
                                       to_string(biasTensor.GetQuantizationOffset()));
    }

    if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
    {
        // Validate per-axis quantization scales
        const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
        const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();

        if (weightScales.size() != biasScales.size())
        {
            std::stringstream msg;
            msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
                << "but got different values. This is currently unsupported: weights=" << weightScales.size()
                << ", biases=" << biasScales.size();
            throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
        }

        for (size_t i = 0ul; i < biasScales.size(); ++i)
        {
            const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
            VerifyBiasQuantizationScale(biasScales[i], expectedScale);
        }
    }
    else
    {
        // Validate per-tensor quantization scale
        const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
        VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
    }
}
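// Note on the bias check above: the expected bias scale is the product of the input
// scale and the weight scale (element-wise when the weights carry per-axis scales),
// i.e. biasScale ~= inputScale * weightScale. A mismatch beyond the small tolerance
// only logs a warning and the provided scale is kept; a non-zero bias quantization
// offset is always rejected.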

//---------------------------------------------------------------
void ValidateTensors(const std::vector<ITensorHandle*>& vec,
                     unsigned int numExpected,
                     const std::string& descName,
                     const std::string& varName)
{
    if (vec.empty() && numExpected > 0)
    {
        throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
    }

    for (unsigned int i = 0; i < numExpected; ++i)
    {
        if (!vec[i])
        {
            throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
        }
    }
}

//---------------------------------------------------------------
void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
                                        const TensorInfo& second,
                                        const TensorInfo& output,
                                        std::string const& descName,
                                        std::string const& firstName,
                                        std::string const& secondName)
{
    // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
    // broadcasted.
    if (first.GetNumDimensions() != second.GetNumDimensions())
    {
        throw InvalidArgumentException(descName + ": Tensors "
            + firstName + " & " + secondName
            + " must have the same number of dimensions in order to be broadcasted");
    }
    uint32_t numDims = first.GetNumDimensions();
    std::vector<uint32_t> outputDims(numDims, 0u);
    for (uint32_t i = 0; i < numDims; i++)
    {
        const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
        const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
        if (dimsNotEqual && dimsNotOne)
        {
            throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
        }
        outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
    }
    TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
    if (broadcastShape != output.GetShape())
    {
        throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
            + firstName + " & " + secondName
            + " does not match the output shape");
    }
}
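// Broadcast example for the check above: inputs of shape { 1, 2, 3, 1 } and
// { 1, 1, 3, 4 } are compatible (each dimension either matches or is 1) and give an
// output shape of { 1, 2, 3, 4 }; shapes { 2, 3 } and { 2, 4 } would be rejected.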

//---------------------------------------------------------------
void ValidateDataTypes(const TensorInfo& info,
                       const std::vector<armnn::DataType>& supportedTypes,
                       std::string const& descName)
{
    auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
    if (iterator == supportedTypes.end())
    {
        throw InvalidArgumentException(descName + ": Tensor type is not supported.");
    }
}

//---------------------------------------------------------------
void ValidateTensorDataTypesMatch(const TensorInfo& first,
                                  const TensorInfo& second,
                                  std::string const& descName,
                                  std::string const& firstName,
                                  std::string const& secondName)
{
    if (first.GetDataType() != second.GetDataType())
    {
        throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
                                       " must have identical data types.");
    }
}

//---------------------------------------------------------------
void ValidateTensorNumElementsMatch(const TensorInfo& first,
                                    const TensorInfo& second,
                                    std::string const& descName,
                                    std::string const& firstName,
                                    std::string const& secondName)
{
    if (first.GetNumElements() != second.GetNumElements())
    {
        throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
                                       " must have the same number of elements.");
    }
}

void ValidateWeightDataType(const TensorInfo& inputInfo,
                            const TensorInfo& weightInfo,
                            const std::string& descName)
{
    const DataType inputType = inputInfo.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        const std::vector<DataType> validTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8
        };

        ValidateDataTypes(weightInfo, validTypes, descName);
    }
    else
    {
        ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
    }
}

void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
                                          const std::string& descName,
                                          const std::string& tensorName)
{
    const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
    if (!quantizationDim.has_value())
    {
        throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
                                                   "not set on tensor {1}.", descName, tensorName));
    }
}

void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
                                       const std::string& descName,
                                       const std::string& tensorName)
{
    int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
    if (quantizationOffset != 0)
    {
        throw InvalidArgumentException(fmt::format(
            "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
            descName, tensorName, quantizationOffset));
    }
}

void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
                                 const TensorInfo& outputInfo,
                                 const TensorInfo& weightInfo,
                                 const Optional<TensorInfo>& optionalBiasInfo,
                                 const std::string& descName)
{
    if (weightInfo.HasPerAxisQuantization())
    {
        const DataType inputDataType = inputInfo.GetDataType();
        const DataType outputDataType = outputInfo.GetDataType();

        const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;

        if (!canHavePerAxisQuantization)
        {
            throw InvalidArgumentException(fmt::format(
                "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
                "per-axis quantization.", descName, "weight"));
        }

        ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
        ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
        ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");

        if (optionalBiasInfo.has_value())
        {
            const TensorInfo& biasInfo = optionalBiasInfo.value();
            if (!biasInfo.HasPerAxisQuantization())
            {
                throw InvalidArgumentException(fmt::format(
                    "{}: Per-axis quantization parameters not set on bias tensor, "
                    "despite being set on weight tensor.", descName));
            }

            ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
            ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
            ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
        }
    }
}
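// Summary of the per-axis rules enforced above: per-axis quantized weights must be
// QSymmS8, have a quantization dimension set and a zero offset, and if a bias is
// present it must also be per-axis quantized, Signed32, with a zero offset.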

} // anonymous namespace

//---------------------------------------------------------------
void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
                                                  std::string const& descName,
                                                  unsigned int numDimensions,
                                                  std::string const& tensorName) const
{
    // If we're allowing expanded dimensions, then numDimensions becomes the minimum number of dimensions we can
    // allow. Throw an exception if the tensor has fewer than numDimensions or if the squeezed dimensions are
    // greater than numDimensions.
    if (m_AllowExpandedDims)
    {
        unsigned int squeezedDims = 0;

        for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
        {
            if (tensor.GetShape()[i] != 1)
            {
                ++squeezedDims;
            }
        }
        if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
        {
            throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
                                           to_string(tensor.GetNumDimensions()) + " dimensions for " +
                                           tensorName + " tensor.");
        }
    }
    else
    {
        if (tensor.GetNumDimensions() != numDimensions)
        {
            throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
                                           to_string(tensor.GetNumDimensions()) + " dimensions for " +
                                           tensorName + " tensor.");
        }
    }
}
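// Example of the expanded-dims path above: with m_AllowExpandedDims set, a tensor of
// shape { 1, 1, 2, 3 } passes a check for numDimensions == 2, because only two of its
// dimensions are larger than 1; without the flag the same tensor would be rejected.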

//---------------------------------------------------------------
void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
                                                  unsigned int numDimension,
                                                  unsigned int numElements,
                                                  std::string const& tensorName) const
{
    const std::string functionName{"ValidateTensorNumDimNumElem"};
    ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
    ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
}

//---------------------------------------------------------------
void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
                                            unsigned int numExpectedIn, unsigned int numExpectedOut) const
{
    ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
    ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
}

//---------------------------------------------------------------
void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"MapQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 0);

    for (unsigned int i = 0; i < m_Inputs.size(); ++i)
    {
        if (!m_Inputs[i])
        {
            throw InvalidArgumentException(
                fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
        }
    }
}

//---------------------------------------------------------------
void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"UnmapQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 0);

    for (unsigned int i = 0; i < m_Inputs.size(); ++i)
    {
        if (!m_Inputs[i])
        {
            throw InvalidArgumentException(
                fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
        }
    }
}

//---------------------------------------------------------------
void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"MemCopyQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");

    if (m_Inputs.size() != m_Outputs.size())
    {
        throw InvalidArgumentException(fmt::format(
            "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
            descriptorName, m_Inputs.size(), m_Outputs.size()));
    }

    for (unsigned int i = 0; i < m_Inputs.size(); ++i)
    {
        if (!m_Inputs[i])
        {
            throw InvalidArgumentException(fmt::format(
                "{0}: Invalid NULL input {1}.", descriptorName, i));
        }

        if (!m_Outputs[i])
        {
            throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
        }
    }
}

//---------------------------------------------------------------
void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
    ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor", 1);

    if (workloadInfo.m_InputTensorInfos.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
                                                   workloadInfo.m_InputTensorInfos.size()));

    }

    if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
    {
        throw InvalidArgumentException(fmt::format(
            "Number of input infos ({0}) does not match the number of output infos ({1})",
            workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
    }

    for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
            workloadInfo.m_OutputTensorInfos[i].GetNumElements())
        {
            throw InvalidArgumentException(fmt::format(
                "Number of elements for tensor input and output {} does not match", i));
        }
    }

    if (m_Inputs.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
    }

    if (m_Inputs.size() != m_Outputs.size())
    {
        throw InvalidArgumentException(fmt::format(
            "Number of inputs ({0}) does not match the number of outputs ({1})",
            m_Inputs.size(), m_Outputs.size()));
    }

    for (unsigned int i = 0; i < m_Inputs.size(); ++i)
    {
        if (!m_Inputs[i])
        {
            throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
        }

        if (!m_Outputs[i])
        {
            throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
        }
    }
}

//---------------------------------------------------------------
void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);

    if (m_Inputs.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
    }

    if (m_Outputs.size() != 0)
    {
        throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
    }

    if (!m_Inputs[0])
    {
        throw InvalidArgumentException(fmt::format("Invalid null input 0"));
    }
}

//---------------------------------------------------------------
void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"ActivationQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
    ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}

void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"ArgMinMaxQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
        outputTensorInfo.GetDataType() != DataType::Signed64)
    {
        throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
    }

    std::vector<DataType> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32,
        DataType::Signed64
    };

    ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);

    auto inputShape = inputTensorInfo.GetShape();
    auto outputShape = outputTensorInfo.GetShape();

    auto inputNumDimensions = inputShape.GetNumDimensions();
    auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);

    const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};

    // 1D input shape results in scalar output shape
    if (inputShape.GetNumDimensions() == 1)
    {
        if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
        {
            throw InvalidArgumentException(descriptorName + outputShapeError);
        }
    }
    else
    {
        for (unsigned int i = 0; i < unsignedAxis; ++i)
        {
            if (outputShape[i] != inputShape[i])
            {
                throw InvalidArgumentException(descriptorName + outputShapeError);
            }
        }

        for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
        {
            if (outputShape[i - 1] != inputShape[i])
            {
                throw InvalidArgumentException(descriptorName + outputShapeError);
            }
        }
    }
}
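// Shape example for the ArgMinMax checks above: an input of shape { 2, 3, 4 } reduced
// over axis 1 is expected to produce an output of shape { 2, 4 } (Signed32 or
// Signed64); a 1-D input is expected to collapse to a scalar-like output.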

void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"CastQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16,
        DataType::Signed32,
        DataType::Signed64
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
    ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}

void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"SoftmaxQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
    ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}

void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"SplitterQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
    {
        const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
        ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);

        const std::string outputName = "output_" + std::to_string(i);
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
    }

    if (workloadInfo.m_OutputTensorInfos.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
    }

    if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
    {
        throw InvalidArgumentException(
            descriptorName + ": Number of split windows "
            "has to match number of workloadInfo.m_OutputTensorInfos. "
            "Number of windows: " +
            to_string(m_ViewOrigins.size()) +
            ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
    }

    // The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
    std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
    for (unsigned int w = 0; w < m_ViewOrigins.size(); ++w)
    {
        // Checks that the dimensionality of the input is the same as the split windows.
        ViewOrigin const& e = m_ViewOrigins[w];
        if (e.m_Origin.size() != inputDims)
        {
            throw InvalidArgumentException(descriptorName + ": Window origin has to "
                                           "have the same dimensionality as the input tensor. "
                                           "Window origin (index: " +
                                           to_string(w) + ") has " + to_string(e.m_Origin.size()) +
                                           " dimensions, the input "
                                           "tensor has " +
                                           to_string(inputDims) + " dimensions.");
        }
        for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
        {
            if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
                workloadInfo.m_InputTensorInfos[0].GetShape()[i])
            {
                throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
                                               "be smaller than or equal to the size of the input in that dimension.");
            }
        }
    }
}
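// Window example for the splitter validation above: splitting a { 4, 6 } input into
// two { 4, 3 } outputs uses view origins { 0, 0 } and { 0, 3 }; each origin has the
// input's dimensionality, and origin + output extent stays within the input shape.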

void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"ConcatQueueDescriptor"};

    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    if (m_Inputs.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
    }
    if (m_Outputs.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
    }

    if (workloadInfo.m_InputTensorInfos.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
    }
    if (workloadInfo.m_OutputTensorInfos.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
    }

    if (m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
    {
        throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
    }

    if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
    {
        return;
    }

    if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
    {
        throw InvalidArgumentException(
            descriptorName + ": Number of split windows "
            "has to match number of workloadInfo.m_InputTensorInfos. "
            "Number of windows: " +
            to_string(m_ViewOrigins.size()) +
            ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
    }

    // The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
    std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
    for (unsigned int w = 0; w < m_ViewOrigins.size(); ++w)
    {
        // Checks that the dimensionality of the output is the same as the split windows.
        ViewOrigin const& e = m_ViewOrigins[w];
        if (e.m_Origin.size() != outputDims)
        {
            throw InvalidArgumentException(descriptorName + ": Window origin has to "
                                           "have the same dimensionality as the output tensor. "
                                           "Window origin (index: " +
                                           to_string(w) + ") has " + to_string(e.m_Origin.size()) +
                                           " dimensions, the output "
                                           "tensor has " +
                                           to_string(outputDims) + " dimensions.");
        }
        // Checks that the merge windows are within the output tensor.
        for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
        {
            if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
                > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
            {
                throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
                                               "be smaller than or equal to the size of the output in that dimension.");
            }
        }
    }

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
    for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
        ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);

        const std::string inputName = "input_" + std::to_string(i);
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
    }
}

void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"StackQueueDescriptor"};

    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
    {
        throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
    }

    // All inputs must have the same shape, which is defined in parameters
    const TensorShape& inputShape = m_Parameters.m_InputShape;
    for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
        {
            throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
        }
    }

    if (inputShape.GetNumDimensions() > 4)
    {
        throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
    }

    // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
    // since the output tensor has an additional dimension.
    if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
    {
        throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
                                       "than the number of input dimensions.");
    }

    // Output shape must be as inferred from the input shape
    const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
    for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
    {
        if (outputShape[i] != inputShape[i])
        {
            throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                           "match shape inferred from input tensor.");
        }
    }

    if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
    {
        throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                       "match shape inferred from input tensor.");
    }

    for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
    {
        if (outputShape[i] != inputShape[i-1])
        {
            throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                           "match shape inferred from input tensor.");
        }
    }

    if (outputShape.GetNumDimensions() > 5)
    {
        throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
    }

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);

    for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
                                     workloadInfo.m_InputTensorInfos[i],
                                     descriptorName,
                                     "input_0",
                                     "input_" + std::to_string(i));
    }

    ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
                                 workloadInfo.m_OutputTensorInfos[0],
                                 descriptorName,
                                 "input_0",
                                 "output");
}
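// Shape example for the stack checks above: stacking m_NumInputs == 2 inputs of shape
// { 3, 4 } along m_Axis == 0 is expected to give an output of shape { 2, 3, 4 }, and
// along m_Axis == 1 an output of shape { 3, 2, 4 }.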

void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"FillQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Signed32
    };

    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
}

void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"FullyConnectedQueueDescriptor"};

    uint32_t numInputs = 2;
    if (m_Parameters.m_BiasEnabled)
    {
        numInputs = 3;
    }

    ValidateNumInputs(workloadInfo, descriptorName, numInputs);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");

    if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
    {
        throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
    }

    TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
    ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");

    if (m_Parameters.m_BiasEnabled)
    {
        TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
        // Validates type and quantization values.
        ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
        ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
        ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
    }

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);

    // For FullyConnected, a BFloat16 input with a Float32 output is allowed as an optimization.
    if (inputTensorInfo.GetDataType() == DataType::BFloat16)
    {
        if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
        {
            throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16 or Float32 "
                                           "for BFloat16 input.");
        }
    }
    else
    {
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
    }
}

void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"NormalizationQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);

    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");

    ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}

void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"AdditionQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 2);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
    ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);

    ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
    ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");

    ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
                                       inputTensorInfo1,
                                       outputTensorInfo,
                                       descriptorName,
                                       "input_0",
                                       "input_1");
}

void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"MultiplicationQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 2);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
    ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);

    ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
    ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");

    ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
                                       inputTensorInfo1,
                                       outputTensorInfo,
                                       descriptorName,
                                       "input_0",
                                       "input_1");
}

void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"BatchNormalizationQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001246
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001248 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001249
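    // Mean, variance, beta and gamma are supplied as constant tensor handles; the checks below require
    // each of them to be present and 1-D with matching shapes.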
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 ValidatePointer(m_Mean, descriptorName, "mean");
1251 ValidatePointer(m_Variance, descriptorName, "variance");
1252 ValidatePointer(m_Beta, descriptorName, "beta");
1253 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001254
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001255 const TensorInfo& mean = m_Mean->GetTensorInfo();
1256 const TensorInfo& variance = m_Variance->GetTensorInfo();
1257 const TensorInfo& beta = m_Beta->GetTensorInfo();
1258 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1261 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1262 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1263 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001265 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1266 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1267 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001268}
1269
1270void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1271{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001273
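    // Weights are passed as a second input tensor and, when bias is enabled, the bias as a third,
    // so the expected number of inputs depends on m_BiasEnabled.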
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001274 uint32_t numInputs = 2;
1275 if (m_Parameters.m_BiasEnabled)
1276 {
1277 numInputs = 3;
1278 }
1279
1280 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001281 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001282
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001283 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1284 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001286 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1287 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001288
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001289 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001291 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001292
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001293 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001294
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001295 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001296 if (m_Parameters.m_BiasEnabled)
1297 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001298 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001299 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001300
1301 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1302 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001303 }
1304
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001305 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1306 {
1307 throw InvalidArgumentException(
1308 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1309 "cannot be either negative or 0.",
1310 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1311 }
1312
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001313 ValidatePerAxisQuantization(inputTensorInfo,
1314 outputTensorInfo,
1315 weightTensorInfo,
1316 optionalBiasTensorInfo,
1317 descriptorName);
1318
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001319 std::vector<DataType> supportedTypes =
1320 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001321 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001322 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001323 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001324 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001325 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001326 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001327 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001328 };
1329
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001330 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001331
1332     // For Convolution2d, BFloat16 input with Float32 output is allowed as an optimization.
1333 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1334 {
1335 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1336 {
1337             throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16 or Float32 "
1338 "for BFloat16 input.");
1339 }
1340 }
1341 else
1342 {
1343 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1344 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001345}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001346
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001347void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1348{
1349 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1350
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001351 uint32_t numInputs = 2;
1352 if (m_Parameters.m_BiasEnabled)
1353 {
1354 numInputs = 3;
1355 }
1356 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001357 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1358
1359 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1360 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1361
1362 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1363 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1364
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001365 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001366 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1367
1368 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1369
1370 Optional<TensorInfo> optionalBiasTensorInfo;
1371 if (m_Parameters.m_BiasEnabled)
1372 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001373 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001374 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1375
1376 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1377 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1378 }
1379
1380 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1381 {
1382 throw InvalidArgumentException(
1383 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1384                         " cannot be either negative or 0.",
1385 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1386 }
1387
1388 ValidatePerAxisQuantization(inputTensorInfo,
1389 outputTensorInfo,
1390 weightTensorInfo,
1391 optionalBiasTensorInfo,
1392 descriptorName);
1393
1394 std::vector<DataType> supportedTypes =
1395 {
1396 DataType::BFloat16,
1397 DataType::Float16,
1398 DataType::Float32,
1399 DataType::QAsymmS8,
1400 DataType::QAsymmU8,
1401 DataType::QSymmS16,
1402 DataType::QSymmS8
1403 };
1404
1405 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1406 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1407}
1408
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001409void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1410{
1411 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1412
Cathal Corbett06902652022-04-14 17:55:11 +01001413 uint32_t numInputs = 2;
1414 if (m_Parameters.m_BiasEnabled)
1415 {
1416 numInputs = 3;
1417 }
1418
1419 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001420 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1421
1422 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1423 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1424
1425 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1426 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1427
Cathal Corbett06902652022-04-14 17:55:11 +01001428 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001429 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1430
1431 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1432 {
1433 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001434 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1435 "cannot be smaller than 1.",
1436                         descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationY));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 }
1438
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001439 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1440 {
1441 throw InvalidArgumentException(
1442 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1443 "cannot be either negative or 0.",
1444 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1445 }
1446
Jan Eilers53ef7952021-06-02 12:01:25 +01001447 if (weightTensorInfo.GetShape()[0] != 1)
1448 {
1449 throw InvalidArgumentException(fmt::format(
1450 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1451             " But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1452 descriptorName,
1453 weightTensorInfo.GetShape()[0],
1454 weightTensorInfo.GetShape()[1],
1455 weightTensorInfo.GetShape()[2],
1456 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457 }
1458
Cathal Corbett4b19d222022-05-11 20:12:17 +01001459 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1460 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1461 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1462 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1463
1464 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1465 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1466 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1467
1468 if (!(validRefFormat || validAclFormat))
1469 {
1470 throw InvalidArgumentException(fmt::format(
1471 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1472             "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) nor the 2nd (CpuAcc/GpuAcc) dimension is equal to Cout. "
1473             "Cout = {1}. Provided weight shape: [{2}, {3}, {4}, {5}]",
1474 descriptorName,
1475 numOutputChannels,
1476 weightTensorInfo.GetShape()[0],
1477 weightTensorInfo.GetShape()[1],
1478 weightTensorInfo.GetShape()[2],
1479 weightTensorInfo.GetShape()[3]));
1480 }
1481
Teresa Charlind8df0262019-11-11 12:28:15 +00001482 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001483
Teresa Charlind8df0262019-11-11 12:28:15 +00001484 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001485 if (m_Parameters.m_BiasEnabled)
1486 {
Cathal Corbett06902652022-04-14 17:55:11 +01001487 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001488 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001489
1490 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1491 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1492 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001493 ValidatePerAxisQuantization(inputTensorInfo,
1494 outputTensorInfo,
1495 weightTensorInfo,
1496 optionalBiasTensorInfo,
1497 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001498
1499 std::vector<DataType> supportedTypes =
1500 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001501 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001502 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001504 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001505 DataType::QAsymmU8,
1506 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001507 };
1508
1509 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1510 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001511}
1512
1513void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1514{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001515 const std::string descriptorName{"PermuteQueueDescriptor"};
1516
1517 ValidateNumInputs(workloadInfo, descriptorName, 1);
1518 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001519
1520 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001522 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1523 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1526 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001527
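    // Each source dimension i must map onto destination dimension mapping[i] with the same extent.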
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001528 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001529 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001530 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001531 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001532 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1533 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1534 "must match dst dimension " + to_string(mapping[i]) +
1535 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001536 }
1537 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001538
1539 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001540}
1541
1542void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1543{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001544 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001545
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001546 ValidateNumInputs(workloadInfo, descriptorName, 1);
1547 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1548
1549 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1550 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1551
1552 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1553 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001554
1555 std::vector<DataType> supportedTypes =
1556 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001557 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001558 DataType::Float32,
1559 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001560 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001561 DataType::QAsymmU8,
1562 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001563 };
1564
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001565 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1566 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001567}
1568
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001569void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1570{
1571 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1572
1573 ValidateNumInputs(workloadInfo, descriptorName, 1);
1574 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1575
1576 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1577 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1578
1579 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1580 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1581
1582 std::vector<DataType> supportedTypes =
1583 {
1584 DataType::BFloat16,
1585 DataType::Float32,
1586 DataType::Float16,
1587 DataType::QAsymmS8,
1588 DataType::QAsymmU8,
1589 DataType::QSymmS16
1590 };
1591
1592 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1593 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1594}
1595
Teresa Charlin970f43b2019-07-01 13:51:07 +01001596void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1597{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001598 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001599
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001600 ValidateNumInputs(workloadInfo, descriptorName, 1);
1601 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1602
1603 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1604 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1605
1606 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1607 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001608
1609 std::vector<DataType> supportedTypes =
1610 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001611 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001612 DataType::Float16,
1613 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001614 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001615 DataType::QAsymmU8,
1616 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001617 };
1618
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001619 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1620 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001622 // Resize only changes width and height: batch and channel count must match.
1623 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1624 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001625 if (inputBatchSize != outputBatchSize)
1626 {
1627 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001628 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1629 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001630 }
1631
1632 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001633 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1634 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001635 if (inputChannelCount != outputChannelCount)
1636 {
1637 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001638 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1639 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001640 }
1641}
1642
1643void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1644{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001645 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001646
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001647 ValidateNumInputs(workloadInfo, descriptorName, 1);
1648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1649
1650 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1651 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1652
1653 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1654 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1655
1656 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1657
telsoa014fcda012018-03-09 14:13:49 +00001658 if (m_Parameters.m_Min > m_Parameters.m_Max)
1659 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001660 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001661 }
telsoa014fcda012018-03-09 14:13:49 +00001662}
1663
Kevin Mayce5045a2019-10-02 14:07:47 +01001664void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1665{
1666 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1667
1668 ValidateNumInputs(workloadInfo, descriptorName, 1);
1669 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1670
1671 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1672 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1673
1674 if (inputTensorInfo.GetNumDimensions() > 4)
1675 {
1676 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1677 }
1678
1679 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1680
1681 // Check the supported data types
1682 std::vector<DataType> supportedTypes =
1683 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001684 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001685 DataType::Float32,
1686 DataType::Float16
1687 };
1688
1689 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001690 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001691}
1692
telsoa014fcda012018-03-09 14:13:49 +00001693void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1694{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001695 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001697 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001698 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1699
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1701 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1702
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001703 if (inputTensorInfo.GetNumDimensions() > 4)
1704 {
1705 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1706 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001707
1708 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001709
1710 // Check the supported data types
1711 std::vector<DataType> supportedTypes =
1712 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001713 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001714 DataType::Float32,
1715 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001716 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001717 DataType::QAsymmU8,
1718 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001719 };
1720
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001721 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001722 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1723}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001724
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001725void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1726{
1727 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1728
1729 ValidateNumInputs(workloadInfo, descriptorName, 1);
1730 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1731
1732 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1733 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1734
1735 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1736
1737 std::vector<DataType> supportedTypes =
1738 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001739 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001740 DataType::Float32,
1741 DataType::Float16,
1742 };
1743
1744 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001745 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001746}
1747
1748void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1749{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001750 const std::string descriptorName{"ConstantQueueDescriptor"};
1751
1752 ValidateNumInputs(workloadInfo, descriptorName, 0);
1753 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001754
1755 if (!m_LayerOutput)
1756 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001757 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001758 }
1759
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001760 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1761 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001762
1763 // Check the supported data types
1764 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001765 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001766 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001767 DataType::Float32,
1768 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001769 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001770 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001771 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001772 DataType::QSymmS16,
1773 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001774 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001775
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001776 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001777}
1778
1779void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1780{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001781 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001782
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001783 ValidateNumInputs(workloadInfo, descriptorName, 1);
1784 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1785
1786 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1787 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1788
1789 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001790
1791 // Check the supported data types
1792 std::vector<DataType> supportedTypes =
1793 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001794 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001795 DataType::Float32,
1796 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001797 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001798 DataType::QAsymmU8,
1799 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001800 DataType::Signed32,
1801 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001802 };
1803
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001804 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1805 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001806}
1807
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001808void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1809{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001810 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001811
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001812 ValidateNumInputs(workloadInfo, descriptorName, 1);
1813 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1814
1815 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1816 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1817
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001818 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1819 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001820 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1821 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001822 }
1823
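    // A 2-D block shape corresponds to 4-D (NCHW/NHWC) tensors, a 1-D block shape to the 3-D case.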
Teresa Charlinf77cab52023-06-01 16:15:13 +01001824 if (m_Parameters.m_BlockShape.size() == 2)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001825 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001826 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1827 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1828 }
1829 else if (m_Parameters.m_BlockShape.size() == 1)
1830 {
1831 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
1832 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
1833 }
1834 else
1835 {
1836         throw InvalidArgumentException(descriptorName + ": Invalid Block Shape size. Block Shape must have 1 or 2 dimensions.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001837 }
1838
Teresa Charlinf77cab52023-06-01 16:15:13 +01001839 // Check input + padding and output have the same number of elements
1840 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1841 const unsigned int inputHeight = inputTensorInfo.GetShape()[dimensionIndices.GetHeightIndex()] +
1842 m_Parameters.m_PadList[0].first + m_Parameters.m_PadList[0].second;
1843 const unsigned int inputWidth = (inputTensorInfo.GetNumDimensions() == 3) ? 1 :
1844 inputTensorInfo.GetShape()[dimensionIndices.GetWidthIndex()] +
1845 m_Parameters.m_PadList[1].first + m_Parameters.m_PadList[1].second;
1846
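    // Channels are at index 1 for NCHW; otherwise they are the last dimension, which the negative
    // index below wraps around to.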
1847 const int channelsIndex_int = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : -1;
1848 const unsigned int channelsIndex = channelsIndex_int < 0 ?
1849 static_cast<unsigned int>(channelsIndex_int) + inputTensorInfo.GetNumDimensions()
1850 : static_cast<unsigned int>(channelsIndex_int);
1851
1852 const unsigned int numInputElements = inputTensorInfo.GetShape()[0] *
1853 inputHeight *
1854 inputWidth *
1855 inputTensorInfo.GetShape()[channelsIndex];
1856
1857 if (outputTensorInfo.GetNumElements() != numInputElements)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001858 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001859 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1860                                        to_string(numInputElements) + " elements after padding but output tensor has " +
1861 to_string(outputTensorInfo.GetNumElements()) + " elements.");
1862 }
1863
1864     // In a 4D tensor there are 2 spatial dimensions (H and W), so the for loop runs twice.
1865     // In a 3D tensor there is 1 spatial dimension, so the for loop runs once.
1866 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
1867 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
1868 {
1869 unsigned int spatialDimension = firstSpatialDimension + i;
1870 auto inputSize = inputTensorInfo.GetShape()[spatialDimension] +
1871 m_Parameters.m_PadList[i].first +
1872 m_Parameters.m_PadList[i].second;
1873 if (inputSize % m_Parameters.m_BlockShape[i] != 0)
1874 {
1875 throw InvalidArgumentException(descriptorName + ": Input dimension size after padding must be "
1876 "divisible by Block Shape in dimension: " + to_string(spatialDimension) + ".");
1877 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001878 }
nikraj01120522a2019-05-31 11:33:07 +01001879
1880 std::vector<DataType> supportedTypes =
1881 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001882 DataType::BFloat16,
1883 DataType::Float16,
1884 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001885 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001886 DataType::QAsymmU8,
1887 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001888 };
1889
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001890 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1891 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001892}
1893
Keith Davisa57eccb2019-06-14 17:33:22 +01001894void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1895{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001896 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 ValidateNumInputs(workloadInfo, descriptorName, 1);
1899 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001900
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001901 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1902 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1903
1904 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1905 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001906
1907 std::vector<DataType> supportedTypes =
1908 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001909 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001910 DataType::Float32,
1911 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001912 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001913 DataType::QAsymmU8,
1914 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001915 };
1916
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001917 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1918 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001919
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001920 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1921
1922 if (m_Parameters.m_BlockSize == 0)
1923 {
1924 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1925 }
1926
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001927 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1928 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1929 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1930 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001932 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001933 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001934 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001935 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1936 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001937 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001938
1939 const TensorShape& outputShape = outputTensorInfo.GetShape();
1940 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1941 {
1942 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1943                                        " must be divisible by the square of block size.");
1944 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001945}
1946
telsoa014fcda012018-03-09 14:13:49 +00001947void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1948{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001949 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001950
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001951 ValidateNumInputs(workloadInfo, descriptorName, 1);
1952 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1953
1954 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1955 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001956
1957 std::vector<DataType> supportedTypes =
1958 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001959 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001960 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001961 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001962 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001963 };
1964
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001965 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001966 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1967 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1968 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001969}
1970
telsoa01c577f2c2018-08-31 09:22:23 +01001971void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1972{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001973 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1974
1975 const std::string descriptorName{"LstmQueueDescriptor"};
1976
1977 // check dimensions of all inputs and outputs
1978 if (workloadInfo.m_InputTensorInfos.size() != 3)
1979 {
1980 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1981 }
1982 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1983 {
1984 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1985 }
1986
1987 std::vector<DataType> supportedTypes =
1988 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001989 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001990 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001991 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001992 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001993 };
1994
Jan Eilers38e05bd2019-06-26 13:10:09 +01001995     // Check that the first input has a supported type and that all other inputs and outputs match it.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1997
Jan Eilers38e05bd2019-06-26 13:10:09 +01001998 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001999 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002000 {
2001 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2002 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002003 descriptorName,
2004 "input_0",
2005 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002006 }
2007 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002008 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002009 {
2010 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2011 workloadInfo.m_OutputTensorInfos[i],
2012                                      descriptorName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002013 "input_0",
2014 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002015 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002016
janeil0117d8d852019-11-15 15:00:16 +00002017 // Making sure clipping parameters have valid values.
2018 // == 0 means no clipping
2019 // > 0 means clipping
2020 if (m_Parameters.m_ClippingThresCell < 0.0f)
2021 {
2022 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2023 }
2024 if (m_Parameters.m_ClippingThresProj < 0.0f)
2025 {
2026 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2027 }
2028
Jan Eilers38e05bd2019-06-26 13:10:09 +01002029 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002030 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2031 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2032 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2033 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2034 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2035 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2036
Jan Eilers38e05bd2019-06-26 13:10:09 +01002037 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002038 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2039 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002040 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002041 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2042 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002043     // cellStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002044 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2045 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002046 // scratchBufferTensor
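    // With CIFG enabled the input gate is removed, so the scratch buffer holds 3 gate buffers instead of 4.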
2047 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002048 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2049 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002050 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002051 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2052 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002053 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002054 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2055 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002056 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002057 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2058 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002059
Jan Eilers38e05bd2019-06-26 13:10:09 +01002060 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2061 if ( m_InputToInputWeights )
2062 {
2063 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2064                                     (n_cell * n_input), "InputToInputWeights");
2065 }
2066
2067 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2068 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2069 (n_cell * n_input), "InputToForgetWeights");
2070
2071 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2072 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2073 (n_cell * n_input), "InputToCellWeights");
2074
2075 if ( m_RecurrentToInputWeights )
2076 {
2077 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2078 (n_cell * n_output), "RecurrentToInputWeights");
2079 }
2080
2081 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2082 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2083 (n_cell * n_output), "RecurrentToForgetWeights");
2084
2085 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2086 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2087 (n_cell * n_output), "RecurrentToCellWeights");
2088
2089 // Make sure the input-gate's parameters are either both present (regular
2090     // LSTM) or not at all (CIFG-LSTM), and that CifgEnabled is set accordingly.
2091 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2092 !m_Parameters.m_CifgEnabled) ||
2093 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2094 m_Parameters.m_CifgEnabled));
2095 if (!cifg_weights_all_or_none)
2096 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002097 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2098 "RecurrentToInputWeights must either both be present (regular LSTM) "
2099 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2100 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002101 }
2102
2103 if ( m_CellToInputWeights )
2104 {
2105 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2106 n_cell, "CellToInputWeights");
2107 }
2108 if ( m_CellToForgetWeights )
2109 {
2110 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2111 n_cell, "CellToForgetWeights");
2112 }
2113 if ( m_CellToOutputWeights )
2114 {
2115 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2116 n_cell, "CellToOutputWeights");
2117 }
2118
2119     // Make sure the peephole weights are all present or all absent, and PeepholeEnabled is set accordingly.
2120 bool peephole_weights_all_or_none =
2121 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2122 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2123 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2124 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2125 if (!peephole_weights_all_or_none)
2126 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002127 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002128 }
2129
2130 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2131 if (m_Parameters.m_CifgEnabled)
2132 {
2133 if (m_InputGateBias)
2134 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002135 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002136 }
2137 }
2138 else
2139 {
2140 if (!m_InputGateBias)
2141 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002142 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2143 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002144 }
2145 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2146 n_cell, "InputGateBias");
2147 }
2148
2149 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2150 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2151
2152 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2153 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2154
2155 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2156 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2157
2158 if (m_ProjectionWeights)
2159 {
2160 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2161 (n_cell * n_output), "ProjectionWeights");
2162 }
2163 if (m_ProjectionBias)
2164 {
2165 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2166 }
2167
2168 // Making sure the projection tensors are consistent:
2169 // 1) If projection weight is not present, then projection bias should not be
2170 // present.
2171 // 2) If projection weight is present, then projection bias is optional.
2172     bool projection_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2173 !m_Parameters.m_ProjectionEnabled)
2174 || (m_ProjectionWeights && !m_ProjectionBias &&
2175 m_Parameters.m_ProjectionEnabled)
2176 || (m_ProjectionWeights && m_ProjectionBias &&
2177 m_Parameters.m_ProjectionEnabled));
2178     if (!projection_tensors_consistent)
2179 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002180 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002181 }
2182
2183 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2184 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2185 // either all have values or none of them have values. Layer normalization is used when the values of all the
2186 // layer normalization weights are present
2187 if (m_InputLayerNormWeights)
2188 {
2189 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2190 }
2191 if (m_ForgetLayerNormWeights)
2192 {
2193 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2194 }
2195 if (m_CellLayerNormWeights)
2196 {
2197 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2198 }
2199 if (m_OutputLayerNormWeights)
2200 {
2201 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2202 }
2203
Jan Eilers38e05bd2019-06-26 13:10:09 +01002204 if (m_Parameters.m_LayerNormEnabled)
2205 {
2206 if (!m_Parameters.m_CifgEnabled)
2207 {
2208 if (!m_InputLayerNormWeights)
2209 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002210 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2211 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002212 }
2213 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2214 1, n_cell, "InputLayerNormWeights");
2215 }
2216 else if (m_InputLayerNormWeights)
2217 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002218 throw InvalidArgumentException(descriptorName + ": InputLayerNormWeights are present while CIFG is "
2219 "enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002220 }
2221
2222 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2223 "ForgetLayerNormWeights");
2224 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2225
2226 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2227 "OutputLayerNormWeights");
2228 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2229
2230 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2231 "CellLayerNormWeights");
2232 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2233 }
2234 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2235 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2237 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002238 }
telsoa01c577f2c2018-08-31 09:22:23 +01002239}
2240
2241void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2242{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002244
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002245 ValidateNumInputs(workloadInfo, descriptorName, 1);
2246 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2247
2248 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2249 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2250
2251 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002252 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002253 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002254 }
2255
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002256 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002257 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002258 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002259 }
2260
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002261 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002262}
2263
2264void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2265{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002266 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002268 ValidateNumInputs(workloadInfo, descriptorName, 1);
2269 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2270
2271 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2272 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2273
2274 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002275 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002276 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002277 }
2278
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002279 if (outputTensorInfo.GetDataType() != DataType::Float32)
2280 {
2281 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2282 }
2283
2284 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002285}
2286
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002287void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2288{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002289 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002291 ValidateNumInputs(workloadInfo, descriptorName, 2);
2292 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2293
2294 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2295 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2296 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2297
2298 std::vector<DataType> supportedTypes =
2299 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002300 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002301 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002302 DataType::Float32,
2303 DataType::QAsymmS8,
2304 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002305 DataType::QSymmS16,
2306 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002307 };
2308
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002309 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2310 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2311 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002312
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002313 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2314 inputTensorInfo1,
2315 outputTensorInfo,
2316 descriptorName,
2317 "input_0",
2318 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002319}
2320
David Beckc2044fe2018-09-05 15:00:38 +01002321void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2322{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002323 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002325 ValidateNumInputs(workloadInfo, descriptorName, 2);
2326 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2327
2328 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2329 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2330 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2331
2332 std::vector<DataType> supportedTypes =
2333 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002334 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002335 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002336 DataType::Float32,
2337 DataType::QAsymmS8,
2338 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002339 DataType::QSymmS16,
2340 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002341 };
2342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002343 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2344 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2345 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002346
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002347 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2348 inputTensorInfo1,
2349 outputTensorInfo,
2350 descriptorName,
2351 "input_0",
2352 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002353}
2354
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002355void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2356{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002357 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002358
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002359 ValidateNumInputs(workloadInfo, descriptorName, 2);
2360 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2361
2362 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2363 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2364 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2365
2366 std::vector<DataType> supportedTypes =
2367 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002368 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002369 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002370 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002371 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002372 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002373 DataType::QSymmS16,
2374 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002375 };
2376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2378 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2379 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002380
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002381 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2382 inputTensorInfo1,
2383 outputTensorInfo,
2384 descriptorName,
2385 "input_0",
2386 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002387}
2388
narpra01a6bf9122018-09-10 09:50:09 +01002389void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2390{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002391 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 ValidateNumInputs(workloadInfo, descriptorName, 1);
2394 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2395
2396 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2397 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002398
2399 std::vector<DataType> supportedTypes =
2400 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002401 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002402 DataType::Float32,
2403 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002404 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002405 DataType::QAsymmU8,
2406 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002407 };
narpra01eb061912018-09-10 17:35:27 +01002408
James Conroy4d1ff582019-06-10 17:06:39 +01002409 // First check if input tensor data type is supported, then
2410 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002411 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2412 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002413
narpra0132b90462018-09-13 11:07:48 +01002414 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002415 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002416 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002417 }
narpra0132b90462018-09-13 11:07:48 +01002418 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002419 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002421 }
2422 else
2423 {
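        // Each reduced axis removes one dimension when KeepDims is false; for example (illustrative),
        // a rank 4 input reduced over two axes is expected to produce a rank 2 output.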
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002424 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002425 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002426 ValidateTensorNumDimensions(outputTensorInfo,
2427 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002428 outputDim > 0 ? outputDim : 1,
2429 "output");
2430 }
narpra01a6bf9122018-09-10 09:50:09 +01002431}
2432
jimfly012c9322a2018-09-19 10:59:49 +01002433void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2434{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002435 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002436
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002437 ValidateNumInputs(workloadInfo, descriptorName, 1);
2438 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2439
2440 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2441 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002442
jimfly012c9322a2018-09-19 10:59:49 +01002443 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002444 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2445
jimfly012c9322a2018-09-19 10:59:49 +01002446 // there should be an entry in the pad list for each dimension in the input tensor
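    // e.g. a 4D input requires a pad list of four (before, after) pairs, one pair per dimension (illustrative)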
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002447 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2448 throw InvalidArgumentException(descriptorName + ": Pad List should contain the same number of entries "
2449 "as there are dimensions in the input tensor, that is " +
2450 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries," +
2451 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002452 }
2453}
2454
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002455void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2456{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002457 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002459 ValidateNumInputs(workloadInfo, descriptorName, 1);
2460 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002461
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002462 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2463 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2464
Sadik Armagan2208b602019-07-31 16:36:27 +01002465 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002466 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002467 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002468 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002469 DataType::Float16,
2470 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002471 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002472 DataType::QAsymmU8,
2473 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002474 };
2475
2476 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002477
Keith Davis0c2eeac2020-02-11 16:51:50 +00002478 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002479 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002480 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002481 }
2482}
2483
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002484void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2485{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002486 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002487
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002488 ValidateNumInputs(workloadInfo, descriptorName, 1);
2489 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002490
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002491 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2492 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002493
Teresa Charlinf77cab52023-06-01 16:15:13 +01002494 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_Crops.size())
2495 {
2496 throw InvalidArgumentException(descriptorName + ": Crops must contain the same number of "
2497 "dimensions as Block Shape.");
2498 }
2499
2500 if (m_Parameters.m_BlockShape.size() == 2)
2501 {
2502 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2503 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2504 }
2505 else if (m_Parameters.m_BlockShape.size() == 1)
2506 {
2507 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
2508 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
2509 }
2510 else
2511 {
2512 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
2513 }
2514
2515 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
2516 // In a 3D tensor, there will be 1 spatialDimension, and the for loop will run once.
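    // Illustrative example: for an input spatial size of 2 and a block shape entry of 2, the uncropped
    // output size is 2 * 2 = 4, so the corresponding crop pair may remove at most 4 elements.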
2517 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
2518 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
2519 {
2520 unsigned int spatialDimension = firstSpatialDimension + i;
2521 unsigned int cropSize = m_Parameters.m_Crops[i].first + m_Parameters.m_Crops[i].second;
2522 unsigned int outputSize = inputTensorInfo.GetShape()[spatialDimension] * m_Parameters.m_BlockShape[i];
2523 if (cropSize > outputSize)
2524 {
2525 throw InvalidArgumentException(descriptorName + ": CropSize must be less than or equal to the uncropped "
2526 "outputSize in dimension: " + to_string(spatialDimension) + ".");
2527 }
2528 }
2529
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002530 std::vector<DataType> supportedTypes =
2531 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002532 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002533 DataType::Float32,
2534 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002535 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002536 DataType::QAsymmU8,
2537 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002538 };
2539
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2541 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002542}
2543
Conor Kennedy430b5d82018-11-14 15:28:28 +00002544void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2545{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002546 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002547
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002548 ValidateNumInputs(workloadInfo, descriptorName, 1);
2549 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2550
2551 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2552 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002553
2554 std::vector<DataType> supportedTypes =
2555 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002556 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002557 DataType::Float16,
2558 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002559 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002560 DataType::QAsymmU8,
2561 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002562 };
2563
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002564 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2565 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002566
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002567 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002568
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002569 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002570 if (rank > 4)
2571 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002573 }
2574
Conor Kennedy430b5d82018-11-14 15:28:28 +00002575 // Begin, End & Stride length must be of rank(input0)
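    // e.g. a rank 3 input requires m_Begin, m_End and m_Stride to each contain exactly three entries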
2576 if (m_Parameters.m_Begin.size() != rank)
2577 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002578 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002579 }
2580
2581 if (m_Parameters.m_End.size() != rank)
2582 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002583 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002584 }
2585
2586 if (m_Parameters.m_Stride.size() != rank)
2587 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002588 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002589 }
2590
2591 // Stride entries must be non-zero
2592 for (auto& stride : m_Parameters.m_Stride)
2593 {
2594 if (stride == 0)
2595 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002596 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002597 }
2598 }
2599}
2600
kevmay0190539692018-11-29 08:40:19 +00002601void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2602{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002603 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002604
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002605 ValidateNumInputs(workloadInfo, descriptorName, 2);
2606 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2607
2608 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2609 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2610 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2611
2612 std::vector<DataType> supportedTypes =
2613 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002614 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002615 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002616 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002617 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002618 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002619 DataType::QSymmS16,
2620 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002621 };
2622
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002623 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2624 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2625 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002626
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002627 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2628 inputTensorInfo1,
2629 outputTensorInfo,
2630 descriptorName,
2631 "input_0",
2632 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002633}
2634
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002635void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2636{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002637 const std::string descriptorName{"DebugQueueDescriptor"};
2638
2639 ValidateNumInputs(workloadInfo, descriptorName, 1);
2640 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002641}
2642
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002643void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2644{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002645 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002646
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002647 ValidateNumInputs(workloadInfo, descriptorName, 2);
2648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002650 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2651 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2652 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2653
2654 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2655 inputTensorInfo1,
2656 outputTensorInfo,
2657 descriptorName,
2658 "input_0",
2659 "input_1");
2660
2661 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002662 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002663 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002664 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002665}
2666
FrancisMurtagh878f0232018-12-19 10:56:15 +00002667void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2668{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002669 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002670
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002671 ValidateNumInputs(workloadInfo, descriptorName, 2);
2672 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002673
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002674 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2675 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2676 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2677
2678 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2679 inputTensorInfo1,
2680 outputTensorInfo,
2681 descriptorName,
2682 "input_0",
2683 "input_1");
2684
2685 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002686 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002687 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002688 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002689}
2690
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002691void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2692{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002693 const std::string descriptorName{"RsqrtQueueDescriptor"};
2694
2695 ValidateNumInputs(workloadInfo, descriptorName, 1);
2696 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2697
2698 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2699 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2700
2701 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002702
2703 std::vector<DataType> supportedTypes =
2704 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002705 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002706 DataType::Float16,
2707 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002708 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002709 DataType::QAsymmU8,
2710 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002711 };
2712
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002713 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2714 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002715}
2716
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002717void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2718{
2719 const std::string descriptorName{"GatherNdQueueDescriptor"};
2720
2721 ValidateNumInputs(workloadInfo, descriptorName, 2);
2722 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2723
2724 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2725 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2726 {
2727 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2728 }
2729
2730 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2731 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2732
2733 std::vector<DataType> supportedTypes =
2734 {
2735 DataType::BFloat16,
2736 DataType::Float16,
2737 DataType::Float32,
2738 DataType::QAsymmS8,
2739 DataType::QAsymmU8,
2740 DataType::QSymmS16,
2741 DataType::Signed32,
2742 };
2743
2744 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2745
2746 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2747
2748 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2749 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2750}
2751
narpra01b89b05f2019-01-16 09:53:09 +00002752void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2753{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002754 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002755
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002756 ValidateNumInputs(workloadInfo, descriptorName, 2);
2757 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002758
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002759 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2760 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002761 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002762 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002763 }
2764
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002765 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2766 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2767
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002768 std::vector<DataType> supportedTypes =
2769 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002770 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002771 DataType::Float16,
2772 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002773 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002774 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002775 DataType::QSymmS16,
2776 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002777 };
2778
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002779 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002780
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002781 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002782
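    // The gathered output rank is inputRank + indicesRank - 1; for example (illustrative, default axis 0),
    // params of shape [5, 4] gathered with indices of shape [3, 2] give an output of shape [3, 2, 4].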
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002783 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2784 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002785}
2786
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002787void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2788{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002789 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2790
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002791 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002792
2793 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2794 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002795 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002796 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2797 }
2798
2799 if (m_Anchors == nullptr)
2800 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002801 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002802 }
2803
2804 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002805 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2806 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2807
2808 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002809 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002810 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2811 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002812
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002813 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2814 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2815 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002816
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002817 const std::vector<DataType> supportedInputTypes =
2818 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002819 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002820 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002821 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002822 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002823 DataType::QAsymmU8,
2824 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002825 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002826
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002827 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2828 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2829 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2830
2831 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2832 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2833 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2834 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2835
2836 // NOTE: Output is always Float32 regardless of input type
2837 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2838 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2839 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2840 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002841
2842 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2843 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002844 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002845 "must be positive and less than or equal to 1.");
2846 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002847
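    // e.g. with 90 classes the last dimension of the scores tensor is expected to be 91,
    // the extra entry accounting for the background class (illustrative).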
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002848 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2849 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002850 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002851 "should be equal to number of classes + 1.");
2852 }
2853}
2854
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002855void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2856{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002857 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002858
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002859 ValidateNumInputs(workloadInfo, descriptorName, 1);
2860 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2861
2862 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2863 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2864
Teresa Charlin07307f32022-05-15 14:07:05 +01002865 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002866 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002867 DataType::QAsymmS8,
2868 DataType::QAsymmU8,
2869 DataType::QSymmS8,
2870 DataType::QSymmS16,
2871 DataType::Float16
2872 };
2873 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002874
Teresa Charlin07307f32022-05-15 14:07:05 +01002875 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002876 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002877 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002878 DataType::Float32,
2879 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002880 };
2881
Teresa Charlin07307f32022-05-15 14:07:05 +01002882 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002883}
2884
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002885void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2886{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002887 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002888
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002889 ValidateNumInputs(workloadInfo, descriptorName, 2);
2890 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002891
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002892 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2893 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2894 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002895
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002896 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2897 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2898
2899 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2900 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002901}
2902
Keith Davis3ae3f972021-05-21 16:33:48 +01002903void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2904{
2905 const std::string& descriptorName{"ShapeQueueDescriptor"};
2906
2907 ValidateNumInputs(workloadInfo, descriptorName, 1);
2908 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2909
2910 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2911 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2912
2913 std::vector<DataType> supportedTypes =
2914 {
2915 DataType::BFloat16,
2916 DataType::Float16,
2917 DataType::Float32,
2918 DataType::QAsymmS8,
2919 DataType::QAsymmU8,
2921 DataType::QSymmS8,
2922 DataType::QSymmS16,
2923 DataType::Signed32
2924 };
2925
2926 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2927 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2928}
2929
Sadik Armaganeff363d2019-04-05 15:25:46 +01002930void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2931{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002932 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002933
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002934 ValidateNumInputs(workloadInfo, descriptorName, 2);
2935 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2936
2937 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2938 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2939
2940 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2941 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2942
2943 std::vector<DataType> supportedTypes =
2944 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002945 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002946 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002947 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002948 DataType::QAsymmU8,
2949 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002950 };
2951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002952 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2953 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002954
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002955 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2956 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002957
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002958 ValidateTensorShapesMatch(inputTensorInfo0,
2959 outputTensorInfo0,
2960 descriptorName,
2961 "input_0",
2962 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002963
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002964 ValidateTensorShapesMatch(inputTensorInfo0,
2965 outputTensorInfo1,
2966 descriptorName,
2967 "input_0",
2968 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002969}
2970
Derek Lamberti901ea112019-12-10 22:07:09 +00002971void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002972{
2973 // This is internally generated so it should not need validation.
2974}
2975
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002976void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2977{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002978 const std::string& descriptorName{"PreluQueueDescriptor"};
2979
2980 ValidateNumInputs(workloadInfo, descriptorName, 2);
2981 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2982
2983 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2984 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2985 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002986
2987 std::vector<DataType> supportedTypes
2988 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002989 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002990 DataType::Float16,
2991 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002992 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002993 DataType::QAsymmU8,
2994 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002995 };
2996
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002997 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2998 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002999
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003000 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003001
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003002 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
3003 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003004
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003005 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3006 alphaTensorInfo,
3007 outputTensorInfo,
3008 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003009 "input",
3010 "alpha");
3011}
3012
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003013void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3014{
3015 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3016
3017 ValidateNumInputs(workloadInfo, descriptorName, 1);
3018 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3019
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003020 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3021 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3022
3023 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3024 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003025
3026 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003027
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003028 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3029 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003030
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003031 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3032
3033 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003034 if (m_Parameters.m_BiasEnabled)
3035 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003036 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003037
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003038 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3039 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003040
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003041 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003042 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003043 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003044
3045 ValidatePerAxisQuantization(inputTensorInfo,
3046 outputTensorInfo,
3047 weightTensorInfo,
3048 optionalBiasTensorInfo,
3049 descriptorName);
3050
3051 std::vector<DataType> supportedTypes =
3052 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003053 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003054 DataType::Float32,
3055 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003056 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003057 DataType::QAsymmU8,
3058 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003059 };
3060
3061 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3062 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003063}
3064
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003065void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3066{
3067 const std::string descriptorName{"TransposeQueueDescriptor"};
3068
3069 ValidateNumInputs(workloadInfo, descriptorName, 1);
3070 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3071
3072 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3073
3074 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3075 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3076
3077 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3078 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3079
3080 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3081 {
3082 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3083 {
3084 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3085 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3086 "must match dst dimension " + to_string(i) +
3087 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3088 }
3089 }
3090
3091 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3092}
3093
Simon Obute51f67772021-09-03 15:50:13 +01003094 void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3095 {
3096 const std::string descriptorName{"ChannelShuffleQueueDescriptor"};
3097
3098 ValidateNumInputs(workloadInfo, descriptorName, 1);
3099 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3100
3101 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3102 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3103
3104 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3105}
3106
James Conroy4f1f8992020-04-29 20:01:10 +01003107void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3108{
3109 const std::string descriptorName{"QLstmQueueDescriptor"};
3110
3111 // Validate number of inputs/outputs
3112 ValidateNumInputs(workloadInfo, descriptorName, 3);
3113 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3114
3115 // Input/output tensor info
3116 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3117 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3118 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3119
3120 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3121 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3122 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3123
3124 // Supported types for various tensors in QLSTM
3125 std::vector<DataType> inputOutputSupportedTypes =
3126 {
3127 DataType::QAsymmS8
3128 };
3129
3130 std::vector<DataType> cellStateSupportedTypes =
3131 {
3132 DataType::QSymmS16
3133 };
3134
3135 std::vector<DataType> weightsSupportedTypes =
3136 {
3137 DataType::QSymmS8
3138 };
3139
3140 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3141 {
3142 DataType::QSymmS16
3143 };
3144
3145 std::vector<DataType> biasSupportedTypes =
3146 {
3147 DataType::Signed32
3148 };
3149
3150 // Validate types of input/output tensors
3151 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3152 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3153 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3154
3155 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3156 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3157 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3158
3159 // Validate matching types of input/output tensors
3160 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3161 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3162 "outputStateIn", "outputStateOut");
3163 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3164
3165 // Infer number of batches, number of units, input size and output size from tensor dimensions
3166 const uint32_t numBatches = inputInfo.GetShape()[0];
3167 const uint32_t inputSize = inputInfo.GetShape()[1];
3168 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3169 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
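    // i.e. input is expected as [numBatches, inputSize], outputStateIn as [numBatches, outputSize] and
    // cellStateIn as [numBatches, numUnits]; the NumDimNumElem checks below verify exactly these 2D shapes.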
3170
3171 // Validate number of dimensions and number of elements for input/output tensors
3172 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3173 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3174 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3175
3176 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3177 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3178 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3179
3180 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3181 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3182 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3183 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3184
3185 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3186 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3187 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3188
3189 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3190 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3191 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3192
3193 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3194 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3195 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3196 " RecurrentToForgetWeights");
3197
3198 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3199 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3200 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3201
3202 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3203 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3204 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToOutputWeights");
3205
3206 // Validate data types for MANDATORY weights tensors (all should match each other)
3207 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3208
3209 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3210 "inputToForgetWeights", "inputToCellWeights");
3211 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3212 "inputToForgetWeights", "inputToOutputWeights");
3213
3214 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3215 "inputToForgetWeights", "recurrentToForgeteights");
3216 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3217 "inputToForgetWeights", "recurrentToCellWeights");
3218 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3219 "inputToForgetWeights", "recurrentToOutputWeights");
3220
3221 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3222 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3223 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3224 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3225
3226 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3227 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3228 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3229
3230 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3231 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3232 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3233
3234 // Validate data types for MANDATORY bias tensors
3235 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3236
3237 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3238 "forgetGateBias", "cellBias");
3239 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3240 "forgetGateBias", "outputGateBias");
3241
3242 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
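    // CIFG (Coupled Input and Forget Gate) drops the separate input gate, so these three tensors must be
    // supplied together when CIFG is disabled and omitted together when it is enabled.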
3243 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3244 !m_Parameters.m_CifgEnabled) ||
3245 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3246 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3247
3248 if (!allCifgParamsPresentOrNot)
3249 {
3250 throw InvalidArgumentException(descriptorName +
3251 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3252 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3253 "set appropriately.");
3254 }
3255
3256 if (!m_Parameters.m_CifgEnabled)
3257 {
3258 // Validate number of dimensions and number of elements
3259 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3260 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3261
3262 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3263 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3264 " RecurrentToInputWeights");
3265
3266 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3267 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3268
3269 // Validate data types
3270 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3271 "inputToForgetWeights", "inputToInputWeights");
3272 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3273 "inputToForgetWeights", "recurrentToInputWeights");
3274 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3275 "forgetGateBias", "inputGateBias");
3276 }
3277
3278 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
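    // Peephole connections let the gates read the cell state directly. CellToInputWeights is additionally
    // allowed to be absent when CIFG is enabled, since the input gate does not exist in that configuration.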
3279 bool allPeepholeWeightsPresentOrNot =
3280 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3281 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3282 || (!m_CellToInputWeights && !m_CellToForgetWeights
3283 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3284
3285 if (!allPeepholeWeightsPresentOrNot)
3286 {
3287 throw InvalidArgumentException(descriptorName +
3288 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3289 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3290 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3291 "appropriately.");
3292 }
3293
3294 if (m_Parameters.m_PeepholeEnabled)
3295 {
3296 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3297 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3298 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3299
3300 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3301 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3302 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3303                                      "cellToForgetWeights", "cellToOutputWeights");
3304
3305 if (!m_Parameters.m_CifgEnabled)
3306 {
3307 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3308 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3309 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3310 "cellToForgetWeights", "cellToInputWeights");
3311 }
3312 }
3313
3314 // Validate OPTIONAL params: Layer Norm Weights
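    // With layer normalisation enabled, each gate has its own 1D weight vector of numUnits elements;
    // InputLayerNormWeights is only expected when CIFG is disabled.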
3315 bool allLayerNormWeightsPresentOrNot =
3316 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3317 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3318 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3319 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3320
3321 if (!allLayerNormWeightsPresentOrNot)
3322 {
3323 throw InvalidArgumentException(descriptorName +
3324             ": InputLayerNormWeights, ForgetLayerNormWeights, OutputLayerNormWeights "
3325 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3326 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3327 "only be present when Layer Norm is enabled and CIFG is disabled. "
3328 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3329 }
3330
3331 if (m_Parameters.m_LayerNormEnabled)
3332 {
3333 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3334 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3335 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3336
3337 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3338 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3339 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3340 "forgetLayerNormWeights", "cellLayerNormWeights");
3341
3342 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3343 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3344 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3345 "forgetLayerNormWeights", "outputLayerNormWeights");
3346
3347 if (!m_Parameters.m_CifgEnabled)
3348 {
3349 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3350 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3351 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3352 "forgetLayerNormWeights", "inputLayerNormWeights");
3353 }
3354 }
3355
3356 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3357 bool correctProjectionTensorsPresent =
3358 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3359 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3360 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3361
3362 if (!correctProjectionTensorsPresent)
3363 {
3364 throw InvalidArgumentException(descriptorName +
3365 ": If projection is enabled, ProjectionWeights should be present and "
3366 "ProjectionBias is optional. If projection is disabled, neither "
3367 "ProjectionWeights nor ProjectionBias should be present.");
3368 }
3369
3370 if (m_Parameters.m_ProjectionEnabled)
3371 {
3372 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3373 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3374 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3375
3376 if (m_ProjectionBias)
3377 {
3378 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003379 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003380 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3381 }
3382
3383 }
3384 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) ||
3385              (outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint)) {
3386 throw InvalidArgumentException(descriptorName +
3387 ": If projection is disabled, output quantization info (scale, offset) "
3388 "should match HiddenStateScale and HiddenStateZeroPoint.");
3389 }
3390
3391}
3392
James Conroy9c3cae82019-08-01 16:01:48 +01003393void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3394{
3395 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3396
3397 // Validate number of inputs/outputs
3398 ValidateNumInputs(workloadInfo, descriptorName, 3);
3399 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3400
3401 // Input/output tensor infos
3402 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3403 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3404 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3405
3406 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3407 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3408
3409 std::vector<DataType> inputOutputSupportedTypes =
3410 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003411 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003412 };
3413
3414 std::vector<DataType> cellStateSupportedTypes =
3415 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003416 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003417 };
3418
3419 std::vector<DataType> weightsSupportedTypes =
3420 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003421 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003422 };
3423
3424 std::vector<DataType> biasSupportedTypes =
3425 {
3426 DataType::Signed32
3427 };
3428
3429 // Validate types of input/output tensors
3430 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3431 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3432 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3433
3434 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3435 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3436
3437 // Validate matching types of input/output tensors
3438 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3439 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3440 "outputStateIn", "outputStateOut");
3441 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3442
3443 // Validate matching quantization info for input/output tensors
3444 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3445 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3446 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003447
James Conroy9c3cae82019-08-01 16:01:48 +01003448 // Infer number of batches, input size and output size from tensor dimensions
3449 const uint32_t numBatches = inputInfo.GetShape()[0];
3450 const uint32_t inputSize = inputInfo.GetShape()[1];
3451 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
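    // For this operator the number of cell units and the output size coincide, so outputSize can be read
    // from the cell state shape, e.g. input [2, 10] and cellStateIn [2, 16] give numBatches = 2,
    // inputSize = 10 and outputSize = 16.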
3452
3453 // Validate number of dimensions and number of elements for input/output tensors
3454 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3455 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3456 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3457 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3458 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3459
3460 // Validate number of dimensions and number of elements for weights tensors
3461 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3462 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3463 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3464
3465 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3466 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3467 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3468
3469 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3470 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3471 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3472
3473 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3474 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3475 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3476
3477 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3478 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3479 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3480
3481 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3482 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3483 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3484 " RecurrentToForgetWeights");
3485
3486 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3487 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3488 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3489
3490 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3491 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3492 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToOutputWeights");
3493
3494 // Validate data types for weights tensors (all should match each other)
3495 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3496
3497 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3498 "inputToInputWeights", "inputToForgetWeights");
3499 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3500 "inputToInputWeights", "inputToCellWeights");
3501 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3502 "inputToInputWeights", "inputToOutputWeights");
3503
3504 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3505 "inputToInputWeights", "recurrentToInputWeights");
3506 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3507                                  "inputToInputWeights", "recurrentToForgetWeights");
3508 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3509 "inputToInputWeights", "recurrentToCellWeights");
3510 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3511 "inputToInputWeights", "recurrentToOutputWeights");
3512
3513 // Validate matching quantization info for weight tensors (all should match each other)
3514 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3515 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3516 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3517 descriptorName, "inputToInputWeights", "inputToCellWeights");
3518 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3519 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3520
3521 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3522 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3523 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3524 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3525 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3526 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3527 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3528 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3529
3530 // Validate number of dimensions and number of elements in bias tensors
3531 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3532 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3533 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3534
3535 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3536 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3537 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3538
3539 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3540 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3541 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3542
3543 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3544 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3545 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3546
3547 // Validate data types for bias tensors (all should match each other)
3548 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3549
3550 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3551 "inputGateBias", "forgetGateBias");
3552 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3553 "inputGateBias", "cellBias");
3554 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3555 "inputGateBias", "outputGateBias");
3556
3557 // Validate bias tensor quantization info
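    // Each bias is expected to be quantized with a zero offset and a scale equal to
    // inputScale * weightsScale, which is the convention ValidateBiasTensorQuantization enforces.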
3558 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3559 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3560 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3561 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3562}
3563
Kevin May868eb142019-09-04 17:29:31 +01003564void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3565{
3566 const std::string descriptorName{"AbsQueueDescriptor"};
3567
3568 ValidateNumInputs(workloadInfo, descriptorName, 1);
3569 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3570
3571 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3572 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3573
3574 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3575
3576 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003577 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003578 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003579 DataType::Float16,
3580 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003581 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003582 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003583 DataType::QSymmS16,
3584 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003585 };
Kevin May868eb142019-09-04 17:29:31 +01003586
3587 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3588 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3589}
3590
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003591void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3592{
3593 const std::string descriptorName{"SliceQueueDescriptor"};
3594
3595 ValidateNumInputs(workloadInfo, descriptorName, 1);
3596 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3597
3598 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3599 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3600
3601 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3602
3603 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3604 if (rank > 4)
3605 {
3606 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3607 }
3608
3609 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3610
3611 // Check if m_Begin and m_Size have the expected length
3612 if (m_Parameters.m_Begin.size() != rank)
3613 {
3614 throw InvalidArgumentException(descriptorName +
3615 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3616 }
3617 if (m_Parameters.m_Size.size() != rank)
3618 {
3619 throw InvalidArgumentException(descriptorName +
3620 ": Length of size descriptor must equal rank " + std::to_string(rank));
3621 }
3622
3623 // Check if the shape of the output tensor matches m_Size
3624 const TensorShape& outputShape = outputTensorInfo.GetShape();
3625 for (unsigned int i = 0u; i < rank; ++i)
3626 {
3627 if (m_Parameters.m_Size[i] != outputShape[i])
3628 {
3629 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3630 }
3631 }
3632
3633 // Check if the sum of begin offset and size in a given dimension
3634 // does not exceed the size of corresponding input
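    // Illustrative example: for an input of shape [4, 8], begin = {1, 2} with size = {3, 6} is valid
    // (1 + 3 <= 4 and 2 + 6 <= 8), whereas size = {3, 7} would fail on dimension 1.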
3635 const TensorShape& inputShape = inputTensorInfo.GetShape();
3636 for(unsigned int i = 0u; i < rank; ++i)
3637 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003638 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003639 {
3640 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3641 std::to_string(i) + " exceeds input size.");
3642 }
3643 }
3644}
3645
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003646void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3647{
3648 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3649
3650 ValidateNumInputs(workloadInfo, descriptorName, 1);
3651 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3652
3653 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3654 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3655
3656 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3657 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3658
3659 std::vector<DataType> supportedTypes =
3660 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003661 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003662 DataType::Float32,
3663 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003664 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003665 DataType::QAsymmU8,
3666 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003667 };
3668
3669 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3670 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3671
3672 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3673
3674 if (m_Parameters.m_BlockSize == 0)
3675 {
3676 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3677 }
3678
3679 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3680 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3681 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3682 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
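    // Illustrative example: an NHWC input of shape [1, 2, 2, 16] with a block size of 2 satisfies both
    // checks below and corresponds to an output of shape [1, 4, 4, 4].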
3683
3684 const TensorShape& outputShape = outputInfo.GetShape();
3685 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3686 {
3687 throw InvalidArgumentException(descriptorName + ": Output width and height "
3688                                        "must be divisible by block size.");
3689 }
3690
3691 const TensorShape& inputShape = inputInfo.GetShape();
3692 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3693 {
3694 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor "
3695                                        "must be divisible by the square of the block size.");
3696 }
3697}
3698
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003699void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3700{
3701 const std::string descriptorName{"ComparisonQueueDescriptor"};
3702
3703 ValidateNumInputs(workloadInfo, descriptorName, 2);
3704 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3705
3706 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3707 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3708 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3709
3710 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3711 inputTensorInfo1,
3712 outputTensorInfo,
3713 descriptorName,
3714 "input_0",
3715 "input_1");
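    // Broadcasting follows the usual element-wise rules: each dimension pair must match or one of them
    // must be 1, e.g. inputs of shape [3, 1] and [1, 4] give a Boolean output of shape [3, 4].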
3716
3717 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3718 {
3719 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3720 }
3721}
3722
Mike Kelly3ec30772023-03-08 13:47:17 +00003723void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3724{
3725 const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3726
3727 ValidateNumInputs(workloadInfo, descriptorName, 2);
3728 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3729
3730 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3731 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3732 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3733
3734 std::vector<DataType> supportedTypes =
3735 {
3736 DataType::BFloat16,
3737 DataType::Float16,
3738 DataType::Float32,
3739 DataType::QAsymmS8,
3740 DataType::QAsymmU8,
3741 DataType::QSymmS16,
3742 DataType::Signed32
3743 };
3744
3745 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3746 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3747
3748 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3749 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3750}
3751
josh minor4a3c6102020-01-06 16:40:46 -06003752void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3753{
3754 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3755
3756 ValidateNumInputs(workloadInfo, descriptorName, 1);
3757 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3758
3759 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3760 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3761
3762 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3763
3764 std::vector<DataType> supportedTypes =
3765 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003766 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003767 DataType::Float16,
3768 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003769 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003770 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003771 DataType::QSymmS16,
3772 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003773 };
3774
James Conroyaba90cd2020-11-06 16:28:18 +00003775 std::vector<DataType> logicalSupportedTypes =
3776 {
3777 DataType::Boolean
3778 };
3779
3780 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3781 {
3782 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3783 }
3784 else
3785 {
3786 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3787 }
3788
3789
josh minor4a3c6102020-01-06 16:40:46 -06003790 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3791}
3792
Finn Williams2605b232020-06-10 15:53:46 +01003793void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3794{
3795 const std::string descriptorName{"RankQueueDescriptor"};
3796
3797 ValidateNumInputs(workloadInfo, descriptorName, 1);
3798 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3799
3800 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3801 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3802
3803 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3804 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
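    // Rank always produces a single Signed32 value (held as a 1D tensor with one element),
    // regardless of the input shape or data type.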
3805
3806 std::vector<DataType> supportedTypes =
3807 {
3808 DataType::BFloat16,
3809 DataType::Float16,
3810 DataType::Float32,
3811 DataType::QAsymmS8,
3812 DataType::QAsymmU8,
3813 DataType::QSymmS8,
3814 DataType::QSymmS16,
3815 DataType::Signed32
3816 };
3817
3818 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3819 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3820}
3821
James Conroyaba90cd2020-11-06 16:28:18 +00003822void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3823{
3824 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3825
3826 ValidateNumInputs(workloadInfo, descriptorName, 2);
3827 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3828
3829 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3830 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3831 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3832
3833 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3834 inputTensorInfo1,
3835 outputTensorInfo,
3836 descriptorName,
3837 "input_0",
3838 "input_1");
3839
3840 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3841 {
3842 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3843 }
3844
3845 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3846 {
3847 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3848 }
3849
3850 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3851 {
3852 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3853 }
3854}
3855
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003856void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3857{
3858 const std::string descriptorName{"ReduceQueueDescriptor"};
3859
3860 ValidateNumInputs(workloadInfo, descriptorName, 1);
3861 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3862
3863 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3864 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3865
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003866 std::vector<DataType> supportedTypes =
3867 {
3868 DataType::BFloat16,
3869 DataType::Float16,
3870 DataType::Float32,
3871 DataType::QAsymmS8,
3872 DataType::QAsymmU8,
3873 DataType::QSymmS16,
3874 DataType::Signed32
3875 };
3876
3877 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3878 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3879}
3880
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003881void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3882{
3883 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3884
3885 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3886
3887 // check dimensions of all inputs and outputs
3888 if (workloadInfo.m_InputTensorInfos.size() != 3)
3889 {
3890 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3891 }
Mike Kelly12994962022-04-21 11:57:09 +01003892 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003893 {
3894 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3895 }
3896
3897 std::vector<DataType> supportedTypes =
3898 {
Mike Kelly12994962022-04-21 11:57:09 +01003899 DataType::Float32,
3900 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003901 };
3902
3903 // Check for a supported type on one input and match it against all the other inputs and outputs
3904 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3905
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003906 // Making sure clipping parameters have valid values.
3907 // == 0 means no clipping
3908 // > 0 means clipping
3909 if (m_Parameters.m_ClippingThresCell < 0.0f)
3910 {
3911 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3912 }
3913 if (m_Parameters.m_ClippingThresProj < 0.0f)
3914 {
3915 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3916 }
3917
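    // The input is laid out as [batch, time, input] by default and [time, batch, input] when
    // m_TimeMajor is set; the indices below select the matching dimensions for either layout.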
3918 unsigned int batchIndx = 0;
3919 unsigned int timeIndx = 1;
3920 unsigned int inputIndx = 2;
3921 uint32_t timeStep = 1;
3923 if (m_Parameters.m_TimeMajor)
3924 {
3925 batchIndx = 1;
3926 timeIndx = 0;
3927
3928 }
3929 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3930
3931 // Inferring batch size, number of outputs and number of cells from the inputs.
3932 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3933 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3934 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3935 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3936 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3937 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
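    // Illustrative example: a batch-major input of shape [2, 5, 10] with InputToOutputWeights of shape
    // [20, 10] and RecurrentToOutputWeights of shape [20, 16] gives n_batch = 2, timeStep = 5,
    // n_input = 10, n_cell = 20 and n_output = 16.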
3938
3939 // input tensor
3940 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3941 descriptorName + " input_0");
3942 // outputStateInTensor
3943 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3944 descriptorName + " input_1");
3945 // cellStateInTensor
3946 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3947 descriptorName + " input_2");
3948
3949 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003950 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003951 descriptorName + " output_0");
3952
3953 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3954 if ( m_InputToInputWeights )
3955 {
3956 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3957                                 (n_cell * n_input), "InputToInputWeights");
3958 }
3959
3960 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3961 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3962 (n_cell * n_input), "InputToForgetWeights");
3963
3964 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3965 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3966 (n_cell * n_input), "InputToCellWeights");
3967
3968 if ( m_RecurrentToInputWeights )
3969 {
3970 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3971 (n_cell * n_output), "RecurrentToInputWeights");
3972 }
3973
3974 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3975 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3976 (n_cell * n_output), "RecurrentToForgetWeights");
3977
3978 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3979 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3980 (n_cell * n_output), "RecurrentToCellWeights");
3981
3982 // Make sure the input-gate's parameters are either both present (regular
3983 // LSTM) or not present at all (CIFG-LSTM), and that CifgEnabled is set accordingly.
3984 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3985 !m_Parameters.m_CifgEnabled) ||
3986 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3987 m_Parameters.m_CifgEnabled));
3988 if (!cifg_weights_all_or_none)
3989 {
3990 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3991 "RecurrentToInputWeights must either both be present (regular LSTM) "
3992                                        "or both not present (CIFG-LSTM). In addition CifgEnabled must be set "
3993 "accordingly.");
3994 }
3995
3996 if ( m_CellToInputWeights )
3997 {
3998 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3999 n_cell, "CellToInputWeights");
4000 }
4001 if ( m_CellToForgetWeights )
4002 {
4003 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
4004 n_cell, "CellToForgetWeights");
4005 }
4006 if ( m_CellToOutputWeights )
4007 {
4008 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
4009 n_cell, "CellToOutputWeights");
4010 }
4011
4012 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
4013 bool peephole_weights_all_or_none =
4014 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
4015 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
4016 || ( !m_CellToInputWeights && !m_CellToForgetWeights
4017 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
4018 if (!peephole_weights_all_or_none)
4019 {
4020 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
4021 }
4022
4023 // Make sure the input gate bias is present only when not a CIFG-LSTM.
4024 if (m_Parameters.m_CifgEnabled)
4025 {
4026 if (m_InputGateBias)
4027 {
4028 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
4029 }
4030 }
4031 else
4032 {
4033 if (!m_InputGateBias)
4034 {
4035 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4036 "must be present.");
4037 }
4038 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4039 n_cell, "InputGateBias");
4040 }
4041
4042 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4043 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4044
4045 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4046 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4047
4048 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4049 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4050
4051 if (m_ProjectionWeights)
4052 {
4053 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4054 (n_cell * n_output), "ProjectionWeights");
4055 }
4056 if (m_ProjectionBias)
4057 {
4058 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4059 }
4060
4061 // Making sure the projection tensors are consistent:
4062 // 1) If projection weight is not present, then projection bias should not be
4063 // present.
4064 // 2) If projection weight is present, then projection bias is optional.
4065 bool projection_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4066 !m_Parameters.m_ProjectionEnabled)
4067 || (m_ProjectionWeights && !m_ProjectionBias &&
4068 m_Parameters.m_ProjectionEnabled)
4069 || (m_ProjectionWeights && m_ProjectionBias &&
4070 m_Parameters.m_ProjectionEnabled));
4071 if (!projection_tensors_consistent)
4072 {
4073 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4074 }
4075
4076 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4077 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4078 // either all have values or none of them have values. Layer normalization is used when the values of all the
4079 // layer normalization weights are present
4080 if (m_InputLayerNormWeights)
4081 {
4082 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4083 }
4084 if (m_ForgetLayerNormWeights)
4085 {
4086 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4087 }
4088 if (m_CellLayerNormWeights)
4089 {
4090 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4091 }
4092 if (m_OutputLayerNormWeights)
4093 {
4094 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4095 }
4096
4097 if (m_Parameters.m_LayerNormEnabled)
4098 {
4099 if (!m_Parameters.m_CifgEnabled)
4100 {
4101 if (!m_InputLayerNormWeights)
4102 {
4103 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4104 "disabled but InputLayerNormWeights are not present");
4105 }
4106 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4107 1, n_cell, "InputLayerNormWeights");
4108 }
4109 else if (m_InputLayerNormWeights)
4110 {
4111 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4112 "enabled");
4113 }
4114
4115 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4116 "ForgetLayerNormWeights");
4117 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4118
4119 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4120 "OutputLayerNormWeights");
4121 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4122
4123 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4124 "CellLayerNormWeights");
4125 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4126 }
4127 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4128 {
4129 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4130 "normalisation weights are present.");
4131 }
4132}
4133
Samuel Yap6b478092022-07-06 15:36:03 +01004134void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4135{
4136 const std::string descriptorName{"BatchMatMulQueueDescriptor"};
4137
4138 ValidateNumInputs(workloadInfo, descriptorName, 2);
4139 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4140
4141 // Both inputs must be at least 2D.
4142 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4143 // axis N and axis I must be the same size.
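    // Illustrative example: with default parameters, X of shape [3, 4, 5] and Y of shape [3, 5, 2]
    // are compatible and correspond to an output of shape [3, 4, 2].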
4144
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004145 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4146 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4147 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4148 // Output info has already been inferred
Samuel Yap6b478092022-07-06 15:36:03 +01004149
4150 std::vector<DataType> supportedTypes =
4151 {
4152 DataType::BFloat16,
4153 DataType::Float16,
4154 DataType::Float32,
4155 DataType::QAsymmS8,
4156 DataType::QAsymmU8,
4157 DataType::QSymmS16
4158 };
4159
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004160 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4161 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4162 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
Samuel Yap6b478092022-07-06 15:36:03 +01004163
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004164 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4165 (inputYInfoBeforeParams.GetNumDimensions() < 2))
Samuel Yap6b478092022-07-06 15:36:03 +01004166 {
4167 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4168 }
4169
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004170 TensorInfo inputXInfoAfterParams;
4171 TensorInfo inputYInfoAfterParams;
4172
4173 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4174 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
Samuel Yap6b478092022-07-06 15:36:03 +01004175 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004176 throw InvalidArgumentException(descriptorName +
4177 ": Invalid descriptor parameters - Transpose and Adjoint "
4178 "cannot both be true for a given input tensor.");
4179 }
4180 if(m_Parameters.m_TransposeX)
4181 {
4182 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4183 BatchMatMulDescriptor::GetPermuteVec(
4184 m_Parameters.m_DataLayoutX,
4185 inputXInfoBeforeParams.GetShape()));
4186 }
4187 else if(m_Parameters.m_AdjointX)
4188 {
4189 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4190 inputXInfoBeforeParams.GetShape());
4191 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4192 inputXInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004193 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004194 throw InvalidArgumentException(descriptorName +
4195 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004196 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004197 // Shape remains the same as it's square
4198 inputXInfoAfterParams = inputXInfoBeforeParams;
4199 }
4200 else
4201 {
4202 inputXInfoAfterParams = inputXInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004203 }
4204
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004205 if(m_Parameters.m_TransposeY)
Samuel Yap6b478092022-07-06 15:36:03 +01004206 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004207 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4208 BatchMatMulDescriptor::GetPermuteVec(
4209 m_Parameters.m_DataLayoutY,
4210 inputYInfoBeforeParams.GetShape()));
4211 }
4212 else if(m_Parameters.m_AdjointY)
4213 {
4214 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4215 inputYInfoBeforeParams.GetShape());
4216 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4217 inputYInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004218 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004219 throw InvalidArgumentException(descriptorName +
4220 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004221 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004222 // Shape remains the same as it's square
4223 inputYInfoAfterParams = inputYInfoBeforeParams;
4224 }
4225 else
4226 {
4227 inputYInfoAfterParams = inputYInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004228 }
4229
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004230 switch(m_Parameters.m_DataLayoutX)
4231 {
4232 case DataLayout::NCDHW:
4233 case DataLayout::NDHWC:
4234 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4235 {
4236 throw InvalidArgumentException(descriptorName +
4237 ": Input tensor X does not have the correct "
4238 "number of dimensions for the Data Layout that it has been assigned.");
4239 }
4240 break;
4241 case DataLayout::NCHW:
4242 case DataLayout::NHWC:
4243 default:
4244 break;
4245 }
Samuel Yap6b478092022-07-06 15:36:03 +01004246
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004247 switch(m_Parameters.m_DataLayoutY)
4248 {
4249 case DataLayout::NCDHW:
4250 case DataLayout::NDHWC:
4251 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4252 {
4253 throw InvalidArgumentException(descriptorName +
4254 ": Input tensor Y does not have the correct "
4255 "number of dimensions for the Data Layout that it has been assigned.");
4256 }
4257 break;
4258 case DataLayout::NCHW:
4259 case DataLayout::NHWC:
4260 default:
4261 break;
4262 }
4263
4264 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4265 inputXInfoAfterParams.GetShape());
4266 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4267                                                          inputYInfoAfterParams.GetShape());
4268
4269 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4270 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
Samuel Yap6b478092022-07-06 15:36:03 +01004271 {
4272 throw InvalidArgumentException(descriptorName +
4273 ": The final axis of input tensor X must be the same size as "
4274 "the second last axis of input tensor Y.");
4275 }
4276
Samuel Yap6b478092022-07-06 15:36:03 +01004277 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4278        // e.g. NHWC isn't compatible with NCHW as of now
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004279 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4280 DataLayout yLayout = m_Parameters.m_DataLayoutY;
Samuel Yap6b478092022-07-06 15:36:03 +01004281
4282 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4283 {
4284 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4285 {
4286 throw InvalidArgumentException(descriptorName +
4287 ": Invalid input tensor data layout combination.");
4288 }
4289 }
4290 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4291 {
4292 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4293 {
4294 throw InvalidArgumentException(descriptorName +
4295 ": Invalid input tensor data layout combination.");
4296 }
4297 }
4298 }
4299
4300 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
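    // Illustrative example: leading (non-multiplied) dimensions [2, 1, 4] and [3, 1] are aligned as
    // [2, 1, 4] against [1, 3, 1]; each pair must then be equal or contain a 1 for the broadcast
    // check below to succeed.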
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004301 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4302 inputYInfoAfterParams.GetNumDimensions());
Samuel Yap6b478092022-07-06 15:36:03 +01004303 if(outputTensorDimSize-2 > 0)
4304 {
4305 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4306 DataType::Float32);
4307 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4308 DataType::Float32);
4309 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4310 DataType::Float32);
4311
4312 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4313 {
4314 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4315
4316 for(unsigned int i = 0; i < sizeDiff; i++)
4317 {
4318 axisIndices.insert(axisIndices.begin(), 1);
4319 }
4320
4321 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4322 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004323 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
Samuel Yap6b478092022-07-06 15:36:03 +01004324 }
4325 };
4326
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004327 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4328 inputXInfoAfterParams.GetShape());
4329 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4330 inputYInfoAfterParams.GetShape());
4331
4332 doAxisExtension(axesXNotMul, tiXNotMul);
4333 doAxisExtension(axesYNotMul, tiYNotMul);
Samuel Yap6b478092022-07-06 15:36:03 +01004334
4335 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4336 {
4337 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4338 tiYNotMul.GetShape()[i]);
4339 }
4340
4341 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4342 tiYNotMul,
4343 tiOutNotMul,
4344 descriptorName,
4345 "input_X",
4346 "input_Y");
4347 }
Samuel Yap6b478092022-07-06 15:36:03 +01004348}
4349
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01004350
mathad01df9a3222021-04-28 11:42:57 +01004351} // namespace armnn