blob: 753fe06edbd9fedd1cb021e068cd349739454687 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010011#include <armnnUtils/Permute.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010013#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000014
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000017#include <string>
18#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000019
James Ward47fce872020-09-10 11:57:28 +010020#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000021
Matteo Martincigh21350152018-11-28 16:22:22 +000022using namespace armnnUtils;
23
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
27//---------------------------------------------------------------
28DataType GetBiasDataType(DataType inputDataType)
29{
30 switch (inputDataType)
31 {
telsoa01c577f2c2018-08-31 09:22:23 +010032 case DataType::Float16:
33 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000034 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000035 case DataType::Float32:
36 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000037 case DataType::QAsymmS8:
38 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000039 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000040 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000041 case DataType::QSymmS8:
42 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000043 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010044 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000045 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010046 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000047 return DataType::Float32;
48 }
49}
50
51namespace
52{
53
//---------------------------------------------------------------
// android ndk does not support std::to_string function.
// Local replacement: renders any stream-insertable value through operator<<.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream buffer;
    buffer << value;
    return buffer.str();
}
63
64//---------------------------------------------------------------
65void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
66{
67 if (!ptr)
68 {
69 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
70 paramName + " parameter must be set.");
71 }
72}
73
74//---------------------------------------------------------------
75void ValidateTensorShapesMatch(const TensorInfo& first,
76 const TensorInfo& second,
77 std::string const& descName,
78 std::string const& firstName,
79 std::string const& secondName)
80{
81 if (first.GetShape() != second.GetShape())
82 {
83 throw InvalidArgumentException(descName + ": "
84 + firstName + " & " + secondName + " must have identical shapes");
85 }
86}
87
88//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010089void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090{
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000092 {
93 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000095 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
96 }
97}
98
99//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000103 {
104 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100105 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000106 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
107 }
108}
109
110//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000111
112//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100113void ValidateTensorNumElements(const TensorInfo& tensor,
114 std::string const& descName,
115 unsigned int numElements,
116 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100117{
118 if (tensor.GetNumElements() != numElements)
119 {
120 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100121 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100122 tensorName + " tensor.");
123 }
124}
125
126//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000127void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
128 const std::string& descName, std::string const& tensorName)
129{
130 if (tensor.GetDataType() != dataType)
131 {
132 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
133 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
134 }
135}
136
Derek Lambertid466a542020-01-22 15:37:29 +0000137void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
138{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100139 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000140 {
141 throw InvalidArgumentException(descName +
142 ": Expected data type which supports per-axis quantization scheme but got " +
143 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
144 }
Derek Lambertid466a542020-01-22 15:37:29 +0000145}
146
telsoa014fcda012018-03-09 14:13:49 +0000147//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100148void ValidateTensorQuantizationSpace(const TensorInfo& first,
149 const TensorInfo& second,
150 const std::string& descName,
151 std::string const& firstName,
152 std::string const& secondName)
153{
154 if (!first.IsQuantized() ||
155 !second.IsQuantized())
156 {
157 // Not a quantized type, ignore the validation
158 return;
159 }
160
161 DataType firstDataType = first.GetDataType();
162 DataType secondDataType = second.GetDataType();
163
164 if (firstDataType != secondDataType)
165 {
166 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
167 " must be of the same quantized type, " +
168 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
169 secondName + " is " + GetDataTypeName(secondDataType));
170 }
171
172 if (!first.IsTypeSpaceMatch(second))
173 {
174 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
175 " must have the same quantization space, " +
176 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
177 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
178 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
179 " and scale " + to_string(second.GetQuantizationScale()));
180 }
181}
182
183//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100184void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
185 const TensorInfo& inputTensorInfo,
186 const TensorInfo& weightsTensorInfo,
187 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000188{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000189 // Helper lambda function to validate a single bias quantization scale value
190 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
191 {
mathad01df9a3222021-04-28 11:42:57 +0100192 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000193 if (std::abs(biasScale - expectedScale) > tolerance)
194 {
195 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100196 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
197 " for bias quantization scale (product of input and weight scales), but got " <<
198 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000199 }
200 };
201
telsoa014fcda012018-03-09 14:13:49 +0000202 if (biasTensor.GetQuantizationOffset() != 0)
203 {
204 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
205 to_string(biasTensor.GetQuantizationOffset()));
206 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000207
James Conroy8502ade2020-11-12 19:26:29 +0000208 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000209 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000210 // Validate per-axis quantization scales
211 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
212 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
213
214 if (weightScales.size() != biasScales.size())
215 {
216 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000217 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
218 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
219 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000220 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
221 }
222
223 for (size_t i = 0ul; i < biasScales.size(); ++i)
224 {
225 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
226 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
227 }
228 }
229 else
230 {
231 // Validate per-tensor quantization scale
232 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
233 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000234 }
235}
236
237//---------------------------------------------------------------
238void ValidateTensors(const std::vector<ITensorHandle*>& vec,
239 unsigned int numExpected,
240 const std::string& descName,
241 const std::string& varName)
242{
243 if (vec.empty() && numExpected > 0)
244 {
245 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
246 }
247
248 for (unsigned int i = 0; i < numExpected; ++i)
249 {
250 if (!vec[i])
251 {
252 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
253 }
254 }
255}
256
257//---------------------------------------------------------------
258void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
259 const TensorInfo& second,
260 const TensorInfo& output,
261 std::string const& descName,
262 std::string const& firstName,
263 std::string const& secondName)
264{
265 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
266 // broadcasted.
267 if (first.GetNumDimensions() != second.GetNumDimensions())
268 {
269 throw InvalidArgumentException(descName + ": Tensors "
270 + firstName + " & " + secondName
271 + " must have the same number of dimensions in order to be broadcasted");
272 }
273 uint32_t numDims = first.GetNumDimensions();
274 std::vector<uint32_t> outputDims(numDims, 0u);
275 for (uint32_t i = 0; i < numDims; i++)
276 {
277 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
278 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
279 if (dimsNotEqual && dimsNotOne)
280 {
281 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
282 }
283 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
284 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100285 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000286 if (broadcastShape != output.GetShape())
287 {
288 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
289 + firstName + " & " + secondName
290 + " does not match the output shape");
291 }
292}
293
294//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100295void ValidateDataTypes(const TensorInfo& info,
296 const std::vector<armnn::DataType>& supportedTypes,
297 std::string const& descName)
298{
299 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
300 if (iterator == supportedTypes.end())
301 {
302 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
303 }
304}
305
James Conroy4d1ff582019-06-10 17:06:39 +0100306//---------------------------------------------------------------
307void ValidateTensorDataTypesMatch(const TensorInfo& first,
308 const TensorInfo& second,
309 std::string const& descName,
310 std::string const& firstName,
311 std::string const& secondName)
312{
313 if (first.GetDataType() != second.GetDataType())
314 {
315 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
316 " must have identical data types.");
317 }
318}
319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100320//---------------------------------------------------------------
321void ValidateTensorNumElementsMatch(const TensorInfo& first,
322 const TensorInfo& second,
323 std::string const& descName,
324 std::string const& firstName,
325 std::string const& secondName)
326{
327 if (first.GetNumElements() != second.GetNumElements())
328 {
329 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
330 " must have the same number of elements.");
331 }
332}
333
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000334void ValidateWeightDataType(const TensorInfo& inputInfo,
335 const TensorInfo& weightInfo,
336 const std::string& descName)
337{
338 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000339 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000340 {
341 const std::vector<DataType> validTypes =
342 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000343 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100344 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100345 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000346 };
347
348 ValidateDataTypes(weightInfo, validTypes, descName);
349 }
350 else
351 {
352 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
353 }
354}
355
356void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
357 const std::string& descName,
358 const std::string& tensorName)
359{
360 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
361 if (!quantizationDim.has_value())
362 {
James Ward47fce872020-09-10 11:57:28 +0100363 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
364 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366}
367
368void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
369 const std::string& descName,
370 const std::string& tensorName)
371{
372 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
373 if (quantizationOffset != 0)
374 {
James Ward47fce872020-09-10 11:57:28 +0100375 throw InvalidArgumentException(fmt::format(
376 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
377 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000378 }
379}
380
381void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
382 const TensorInfo& outputInfo,
383 const TensorInfo& weightInfo,
384 const Optional<TensorInfo>& optionalBiasInfo,
385 const std::string& descName)
386{
387 if (weightInfo.HasPerAxisQuantization())
388 {
389 const DataType inputDataType = inputInfo.GetDataType();
390 const DataType outputDataType = outputInfo.GetDataType();
391
Keith Davis0c2eeac2020-02-11 16:51:50 +0000392 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000393
394 if (!canHavePerAxisQuantization)
395 {
James Ward47fce872020-09-10 11:57:28 +0100396 throw InvalidArgumentException(fmt::format(
397 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
398 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000399 }
400
Derek Lambertid466a542020-01-22 15:37:29 +0000401
402 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000403 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
404 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
405
406 if (optionalBiasInfo.has_value())
407 {
408 const TensorInfo& biasInfo = optionalBiasInfo.value();
409 if (!biasInfo.HasPerAxisQuantization())
410 {
James Ward47fce872020-09-10 11:57:28 +0100411 throw InvalidArgumentException(fmt::format(
412 "{}: Per-axis quantization parameters not set on bias tensor, "
413 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000414 }
415
416 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
417 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
418 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
419 }
420 }
421}
422
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100423} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000424
Mike Kelly80512b02022-05-16 23:10:42 +0100425//---------------------------------------------------------------
426void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
427 std::string const& descName,
428 unsigned int numDimensions,
429 std::string const& tensorName) const
430{
431 // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
432 // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
433 // numDimensions.
434 if (m_AllowExpandedDims)
435 {
436 unsigned int squeezedDims = 0;
437
438 for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
439 {
440 if (tensor.GetShape()[i] != 1)
441 {
442 ++squeezedDims;
443 }
444 }
445 if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
446 {
447 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
448 to_string(tensor.GetNumDimensions()) + " dimensions for " +
449 tensorName + " tensor.");
450 }
451 }
452 else
453 {
454 if (tensor.GetNumDimensions() != numDimensions)
455 {
456 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
457 to_string(tensor.GetNumDimensions()) + " dimensions for " +
458 tensorName + " tensor.");
459 }
460 }
461}
462
463//---------------------------------------------------------------
464void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
465 unsigned int numDimension,
466 unsigned int numElements,
467 std::string const& tensorName) const
468{
469 const std::string functionName{"ValidateTensorNumDimNumElem"};
470 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
471 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
472}
473
474//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000475void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
476 unsigned int numExpectedIn, unsigned int numExpectedOut) const
477{
478 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
479 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
480}
481
482//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100483void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
484{
485 const std::string descriptorName{"MapQueueDescriptor"};
486
487 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100488 ValidateNumOutputs(workloadInfo, descriptorName, 0);
489
490 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
491 {
492 if (!m_Inputs[i])
493 {
494 throw InvalidArgumentException(
495 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
496 }
497 }
498}
499
500//---------------------------------------------------------------
501void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
502{
503 const std::string descriptorName{"UnmapQueueDescriptor"};
504
505 ValidateNumInputs(workloadInfo, descriptorName, 1);
506 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100507
508 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
509 {
510 if (!m_Inputs[i])
511 {
512 throw InvalidArgumentException(
513 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
514 }
515 }
516}
517
518//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000519void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
520{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100521 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000522
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100523 ValidateNumInputs(workloadInfo, descriptorName, 1);
524 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000525
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100526 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
527 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
528
529 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
530 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000531
532 if (m_Inputs.size() != m_Outputs.size())
533 {
James Ward47fce872020-09-10 11:57:28 +0100534 throw InvalidArgumentException(fmt::format(
535 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
536 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000537 }
538
539 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
540 {
541 if (!m_Inputs[i])
542 {
James Ward47fce872020-09-10 11:57:28 +0100543 throw InvalidArgumentException(fmt::format(
544 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000545 }
546
547 if (!m_Outputs[i])
548 {
James Ward47fce872020-09-10 11:57:28 +0100549 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000550 }
551 }
552}
553
Derek Lambertif674aa02019-08-01 15:56:25 +0100554//---------------------------------------------------------------
555void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
556{
557 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
558 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
559
560 if (workloadInfo.m_InputTensorInfos.size() != 1)
561 {
James Ward47fce872020-09-10 11:57:28 +0100562 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
563 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100564
565 }
566
567 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
568 {
James Ward47fce872020-09-10 11:57:28 +0100569 throw InvalidArgumentException(fmt::format(
570 "Number of input infos ({0}) does not match the number of output infos ({1})",
571 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573
574 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
575 {
576 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
577 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
578 {
James Ward47fce872020-09-10 11:57:28 +0100579 throw InvalidArgumentException(fmt::format(
580 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100581 }
582 }
583
584 if (m_Inputs.size() != 1)
585 {
James Ward47fce872020-09-10 11:57:28 +0100586 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100587 }
588
589 if (m_Inputs.size() != m_Outputs.size())
590 {
James Ward47fce872020-09-10 11:57:28 +0100591 throw InvalidArgumentException(fmt::format(
592 "Number of inputs ({0}) does not match the number of outputs ({1})",
593 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100594 }
595
596 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
597 {
598 if (!m_Inputs[i])
599 {
James Ward47fce872020-09-10 11:57:28 +0100600 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100601 }
602
603 if (!m_Outputs[i])
604 {
James Ward47fce872020-09-10 11:57:28 +0100605 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100606 }
607 }
608}
609
610//---------------------------------------------------------------
611void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
612{
613 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100614
Derek Lambertif674aa02019-08-01 15:56:25 +0100615 if (m_Inputs.size() != 1)
616 {
James Ward47fce872020-09-10 11:57:28 +0100617 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100618 }
619
620 if (m_Outputs.size() != 0)
621 {
James Ward47fce872020-09-10 11:57:28 +0100622 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100623 }
624
625 if (!m_Inputs[0])
626 {
James Ward47fce872020-09-10 11:57:28 +0100627 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100628 }
629}
630
631//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000632void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
633{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100634 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100635
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100636 ValidateNumInputs(workloadInfo, descriptorName, 1);
637 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100639 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
640 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100641
642 std::vector<DataType> supportedTypes =
643 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000644 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100645 DataType::Float16,
646 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000647 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000648 DataType::QAsymmU8,
649 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100650 };
651
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100652 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
653 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
654 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000655}
656
Nikhil Rajee391d52019-09-05 17:50:44 +0100657void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
658{
659 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
660
661 ValidateNumInputs(workloadInfo, descriptorName, 1);
662 ValidateNumOutputs(workloadInfo, descriptorName, 1);
663
664 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
665 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
666
Inki Daed4619e22020-09-10 15:33:54 +0900667 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
668 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100669 {
Inki Daed4619e22020-09-10 15:33:54 +0900670 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100671 }
672
James Conroyd47a0642019-09-17 14:22:06 +0100673 std::vector<DataType> supportedInputTypes =
674 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000675 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100676 DataType::Float16,
677 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100678 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000679 DataType::QAsymmU8,
680 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900681 DataType::Signed32,
682 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100683 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100684
James Conroyd47a0642019-09-17 14:22:06 +0100685 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100686
687 auto inputShape = inputTensorInfo.GetShape();
688 auto outputShape = outputTensorInfo.GetShape();
689
690 auto inputNumDimensions = inputShape.GetNumDimensions();
691 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
692
693 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
694
695 // 1D input shape results in scalar output shape
696 if (inputShape.GetNumDimensions() == 1)
697 {
698 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
699 {
700 throw InvalidArgumentException(descriptorName + outputShapeError);
701 }
702 }
703 else
704 {
705 for (unsigned int i = 0; i < unsignedAxis; ++i)
706 {
707 if (outputShape[i] != inputShape[i])
708 {
709 throw InvalidArgumentException(descriptorName + outputShapeError);
710 }
711 }
712
713 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
714 {
715 if (outputShape[i - 1] != inputShape[i])
716 {
717 throw InvalidArgumentException(descriptorName + outputShapeError);
718 }
719 }
720 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100721}
722
mathad01b392e982021-04-07 12:07:30 +0100723void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
724{
725 const std::string descriptorName{"CastQueueDescriptor"};
726
727 ValidateNumInputs(workloadInfo, descriptorName, 1);
728 ValidateNumOutputs(workloadInfo, descriptorName, 1);
729
730 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
731 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
732
733 std::vector<DataType> supportedTypes =
734 {
735 DataType::BFloat16,
736 DataType::Float16,
737 DataType::Float32,
738 DataType::QAsymmS8,
739 DataType::QAsymmU8,
740 DataType::QSymmS8,
741 DataType::QSymmS16,
742 DataType::Signed32,
743 DataType::Signed64
744 };
745
746 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
747 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
748}
749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100750void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
751{
752 const std::string descriptorName{"SoftmaxQueueDescriptor"};
753
754 ValidateNumInputs(workloadInfo, descriptorName, 1);
755 ValidateNumOutputs(workloadInfo, descriptorName, 1);
756
757 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
758 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
759
760 std::vector<DataType> supportedTypes =
761 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000762 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100763 DataType::Float16,
764 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000765 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000766 DataType::QAsymmU8,
767 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 };
769
770 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
771 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
772 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
773}
774
telsoa014fcda012018-03-09 14:13:49 +0000775void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
776{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100777 const std::string descriptorName{"SplitterQueueDescriptor"};
778
779 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000780
Ruomei Yan25339c32019-05-28 16:48:20 +0100781 // Check the supported data types
782 std::vector<DataType> supportedTypes =
783 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000784 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100785 DataType::Float32,
786 DataType::Float16,
787 DataType::Boolean,
788 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100789 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000790 DataType::QAsymmU8,
791 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100792 };
793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100794 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
795 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100796 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100797 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
798 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
799
800 const std::string outputName = "output_" + std::to_string(i);
801 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100802 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100803
telsoa014fcda012018-03-09 14:13:49 +0000804 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
805 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100806 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000807 }
808
809 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
810 {
811 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100812 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000813 "has to match number of workloadInfo.m_OutputTensorInfos. "
814 "Number of windows: " +
815 to_string(m_ViewOrigins.size()) +
816 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
817 }
818
telsoa01c577f2c2018-08-31 09:22:23 +0100819 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000820 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
821 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
822 {
telsoa01c577f2c2018-08-31 09:22:23 +0100823 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000824 ViewOrigin const& e = m_ViewOrigins[w];
825 if (e.m_Origin.size() != inputDims)
826 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000828 "have the same dimensionality as the input tensor. "
829 "Window origin (index: " +
830 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
831 " dimensions, the input "
832 "tensor has " +
833 to_string(inputDims) + " dimensions.");
834 }
835 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
836 {
837 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
838 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
839 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100840 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000841 "be smaller or equal than the size of the input in that coord.");
842 }
843 }
844 }
845}
846
Jim Flynne242f2d2019-05-22 14:24:13 +0100847void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000848{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100849 const std::string descriptorName{"ConcatQueueDescriptor"};
850
851 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000852
853 if (m_Inputs.size() <= 0)
854 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100855 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000856 }
857 if (m_Outputs.size() <= 0)
858 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100859 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000860 }
861
862 if (workloadInfo.m_InputTensorInfos.size() <= 0)
863 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100864 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000865 }
866 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
867 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100868 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000869 }
870
Nikhil Raj8599a412018-11-19 14:51:07 +0000871 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
872 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100873 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000874 }
875
876 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
877 {
878 return;
879 }
880
telsoa014fcda012018-03-09 14:13:49 +0000881 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
882 {
883 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100884 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000885 "has to match number of workloadInfo.m_InputTensorInfos. "
886 "Number of windows: " +
887 to_string(m_ViewOrigins.size()) +
888 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
889 }
890
telsoa01c577f2c2018-08-31 09:22:23 +0100891 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000892 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
893 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
894 {
telsoa01c577f2c2018-08-31 09:22:23 +0100895 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000896 ViewOrigin const& e = m_ViewOrigins[w];
897 if (e.m_Origin.size() != outputDims)
898 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100899 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000900 "have the same dimensionality as the output tensor. "
901 "Window origin (index: " +
902 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
903 " dimensions, the output "
904 "tensor has " +
905 to_string(outputDims) + " dimensions.");
906 }
telsoa01c577f2c2018-08-31 09:22:23 +0100907 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000908 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
909 {
910 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
911 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000914 "be smaller or equal than the size of the output in that coord.");
915 }
916 }
917 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100918
919 // Check the supported data types
920 std::vector<DataType> supportedTypes =
921 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000922 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100923 DataType::Float32,
924 DataType::Float16,
925 DataType::Boolean,
926 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100927 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000928 DataType::QAsymmU8,
929 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100930 };
931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100932 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
933 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100934 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100935 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
936 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
937
938 const std::string inputName = "input_" + std::to_string(i);
939 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100940 }
telsoa014fcda012018-03-09 14:13:49 +0000941}
942
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100943void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
944{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100945 const std::string descriptorName{"StackQueueDescriptor"};
946
947 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100948
949 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
950 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100951 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100952 }
953
954 // All inputs must have the same shape, which is defined in parameters
955 const TensorShape& inputShape = m_Parameters.m_InputShape;
956 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
957 {
958 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
959 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100960 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100961 }
962 }
963
Matthew Jacksondba634f2019-08-15 15:14:18 +0100964 if (inputShape.GetNumDimensions() > 4)
965 {
966 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
967 }
968
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100969 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
970 // since the output tensor has an additional dimension.
971 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
972 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100973 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100974 "than the number of input dimensions.");
975 }
976
977 // Output shape must be as inferred from the input shape
978 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
979 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
980 {
981 if (outputShape[i] != inputShape[i])
982 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100983 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100984 "match shape inferred from input tensor.");
985 }
986 }
987
988 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
989 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100990 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100991 "match shape inferred from input tensor.");
992 }
993
994 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
995 {
996 if (outputShape[i] != inputShape[i-1])
997 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100998 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100999 "match shape inferred from input tensor.");
1000 }
1001 }
1002
Matthew Jacksondba634f2019-08-15 15:14:18 +01001003 if (outputShape.GetNumDimensions() > 5)
1004 {
1005 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
1006 }
1007
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001008 // Check the supported data types
1009 std::vector<DataType> supportedTypes =
1010 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001011 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001012 DataType::Float32,
1013 DataType::Float16,
1014 DataType::Boolean,
1015 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001016 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001017 DataType::QAsymmU8,
1018 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001019 };
1020
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001021 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001022
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001023 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001024 {
1025 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1026 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001027 descriptorName,
1028 "input_0",
1029 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001030 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001031
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001032 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1033 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 descriptorName,
1035 "input_0",
1036 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001037}
1038
Ryan OSheaec6c6802020-06-05 17:17:06 +01001039void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1040{
1041 const std::string descriptorName{"FillQueueDescriptor"};
1042
1043 ValidateNumInputs(workloadInfo, descriptorName, 1);
1044 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1045
1046 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1047 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1048
1049 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1050
1051 std::vector<DataType> supportedTypes =
1052 {
1053 DataType::BFloat16,
1054 DataType::Float32,
1055 DataType::Float16,
1056 DataType::Signed32
1057 };
1058
1059 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1060}
1061
telsoa014fcda012018-03-09 14:13:49 +00001062void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1063{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001065
Matthew Sloyan81beae32021-07-13 19:46:11 +01001066 uint32_t numInputs = 2;
1067 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001068 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001069 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001070 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001071
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001072 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001073 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1074
1075 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1076 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1077
1078 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1079
1080 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001081 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001082 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001083 }
1084
Matthew Sloyan81beae32021-07-13 19:46:11 +01001085 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001087
1088 if (m_Parameters.m_BiasEnabled)
1089 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001090 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001091 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001092 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001093 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1094 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001095 }
1096
Francis Murtagh46c09d02019-05-28 08:15:28 +01001097 // Check the supported data types
1098 std::vector<DataType> supportedTypes =
1099 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001100 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001101 DataType::Float32,
1102 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001103 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001104 DataType::QAsymmU8,
1105 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001106 };
1107
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001108 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001109
1110 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1111 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1112 {
1113 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1114 {
1115 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1116 "for BFloat16 input.");
1117 }
1118 }
1119 else
1120 {
1121 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1122 }
telsoa014fcda012018-03-09 14:13:49 +00001123}
1124
telsoa014fcda012018-03-09 14:13:49 +00001125void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1126{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001127 const std::string descriptorName{"NormalizationQueueDescriptor"};
1128
1129 ValidateNumInputs(workloadInfo, descriptorName, 1);
1130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1131
1132 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1133 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001134
1135 // Check the supported data types
1136 std::vector<DataType> supportedTypes =
1137 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001138 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001139 DataType::Float16,
1140 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001141 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001142 DataType::QAsymmU8,
1143 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001144 };
1145
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001146 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001147
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001151}
1152
1153void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1154{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 ValidateNumInputs(workloadInfo, descriptorName, 2);
1158 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1159
1160 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1161 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1162 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1163
1164 std::vector<DataType> supportedTypes =
1165 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001166 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001167 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001168 DataType::Float16,
1169 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001170 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001171 DataType::QSymmS16,
1172 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001173 };
1174
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001175 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1176 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1177 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001178
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001179 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1180 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001181
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001182 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1183 inputTensorInfo1,
1184 outputTensorInfo,
1185 descriptorName,
1186 "input_0",
1187 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001188}
1189
telsoa014fcda012018-03-09 14:13:49 +00001190void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1191{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateNumInputs(workloadInfo, descriptorName, 2);
1195 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1196
1197 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1198 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1199 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1200
1201 std::vector<DataType> supportedTypes =
1202 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001203 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001204 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001205 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001206 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001207 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001208 DataType::QSymmS16,
1209 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001210 };
1211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001212 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1213 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1214 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001215
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001216 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1217 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1220 inputTensorInfo1,
1221 outputTensorInfo,
1222 descriptorName,
1223 "input_0",
1224 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001225}
1226
1227void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1228{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001230
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001231 ValidateNumInputs(workloadInfo, descriptorName, 1);
1232 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1233
1234 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1235 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001236
1237 std::vector<DataType> supportedTypes =
1238 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001239 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001240 DataType::Float16,
1241 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001242 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001243 DataType::QAsymmU8,
1244 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001245 };
1246
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1248 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001249
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001251 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 ValidatePointer(m_Mean, descriptorName, "mean");
1254 ValidatePointer(m_Variance, descriptorName, "variance");
1255 ValidatePointer(m_Beta, descriptorName, "beta");
1256 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001257
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001258 const TensorInfo& mean = m_Mean->GetTensorInfo();
1259 const TensorInfo& variance = m_Variance->GetTensorInfo();
1260 const TensorInfo& beta = m_Beta->GetTensorInfo();
1261 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001262
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001263 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1264 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1265 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1266 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1269 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1270 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001271}
1272
1273void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1274{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001275 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001276
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001277 uint32_t numInputs = 2;
1278 if (m_Parameters.m_BiasEnabled)
1279 {
1280 numInputs = 3;
1281 }
1282
1283 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001284 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001286 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1287 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001288
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001289 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1290 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001291
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001292 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001294 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001295
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001296 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001297
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001298 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001299 if (m_Parameters.m_BiasEnabled)
1300 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001301 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001302 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001303
1304 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1305 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001306 }
1307
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001308 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1309 {
1310 throw InvalidArgumentException(
1311 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1312 "cannot be either negative or 0.",
1313 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1314 }
1315
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001316 ValidatePerAxisQuantization(inputTensorInfo,
1317 outputTensorInfo,
1318 weightTensorInfo,
1319 optionalBiasTensorInfo,
1320 descriptorName);
1321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001322 std::vector<DataType> supportedTypes =
1323 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001324 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001325 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001326 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001327 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001328 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001329 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001330 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001331 };
1332
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001333 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001334
1335 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1336 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1337 {
1338 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1339 {
1340 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1341 "for BFloat16 input.");
1342 }
1343 }
1344 else
1345 {
1346 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1347 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001348}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001349
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001350void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1351{
1352 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1353
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001354 uint32_t numInputs = 2;
1355 if (m_Parameters.m_BiasEnabled)
1356 {
1357 numInputs = 3;
1358 }
1359 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001360 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1361
1362 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1363 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1364
1365 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1366 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1367
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001368 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001369 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1370
1371 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1372
1373 Optional<TensorInfo> optionalBiasTensorInfo;
1374 if (m_Parameters.m_BiasEnabled)
1375 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001376 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001377 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1378
1379 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1380 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1381 }
1382
1383 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1384 {
1385 throw InvalidArgumentException(
1386 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1387 "cannot be either negative or 0.",
1388 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1389 }
1390
1391 ValidatePerAxisQuantization(inputTensorInfo,
1392 outputTensorInfo,
1393 weightTensorInfo,
1394 optionalBiasTensorInfo,
1395 descriptorName);
1396
1397 std::vector<DataType> supportedTypes =
1398 {
1399 DataType::BFloat16,
1400 DataType::Float16,
1401 DataType::Float32,
1402 DataType::QAsymmS8,
1403 DataType::QAsymmU8,
1404 DataType::QSymmS16,
1405 DataType::QSymmS8
1406 };
1407
1408 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1409 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1410}
1411
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001412void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1413{
1414 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1415
Cathal Corbett06902652022-04-14 17:55:11 +01001416 uint32_t numInputs = 2;
1417 if (m_Parameters.m_BiasEnabled)
1418 {
1419 numInputs = 3;
1420 }
1421
1422 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001423 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1424
1425 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1426 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1427
1428 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1429 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1430
Cathal Corbett06902652022-04-14 17:55:11 +01001431 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001432 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1433
1434 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1435 {
1436 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001437 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1438 "cannot be smaller than 1.",
1439 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001440 }
1441
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001442 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1443 {
1444 throw InvalidArgumentException(
1445 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1446 "cannot be either negative or 0.",
1447 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1448 }
1449
Jan Eilers53ef7952021-06-02 12:01:25 +01001450 if (weightTensorInfo.GetShape()[0] != 1)
1451 {
1452 throw InvalidArgumentException(fmt::format(
1453 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1454 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1455 descriptorName,
1456 weightTensorInfo.GetShape()[0],
1457 weightTensorInfo.GetShape()[1],
1458 weightTensorInfo.GetShape()[2],
1459 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001460 }
1461
Cathal Corbett4b19d222022-05-11 20:12:17 +01001462 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1463 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1464 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1465 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1466
1467 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1468 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1469 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1470
1471 if (!(validRefFormat || validAclFormat))
1472 {
1473 throw InvalidArgumentException(fmt::format(
1474 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1475 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1476 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1477 descriptorName,
1478 numOutputChannels,
1479 weightTensorInfo.GetShape()[0],
1480 weightTensorInfo.GetShape()[1],
1481 weightTensorInfo.GetShape()[2],
1482 weightTensorInfo.GetShape()[3]));
1483 }
1484
Teresa Charlind8df0262019-11-11 12:28:15 +00001485 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001486
Teresa Charlind8df0262019-11-11 12:28:15 +00001487 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001488 if (m_Parameters.m_BiasEnabled)
1489 {
Cathal Corbett06902652022-04-14 17:55:11 +01001490 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001491 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001492
1493 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1494 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1495 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001496 ValidatePerAxisQuantization(inputTensorInfo,
1497 outputTensorInfo,
1498 weightTensorInfo,
1499 optionalBiasTensorInfo,
1500 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001501
1502 std::vector<DataType> supportedTypes =
1503 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001504 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001505 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001507 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001508 DataType::QAsymmU8,
1509 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510 };
1511
1512 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1513 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001514}
1515
1516void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1517{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 const std::string descriptorName{"PermuteQueueDescriptor"};
1519
1520 ValidateNumInputs(workloadInfo, descriptorName, 1);
1521 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001522
1523 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1526 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001527
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001528 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1529 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001530
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001531 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001532 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001533 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001534 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001535 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1536 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1537 "must match dst dimension " + to_string(mapping[i]) +
1538 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001539 }
1540 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001541
1542 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001543}
1544
1545void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1546{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001547 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001548
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001549 ValidateNumInputs(workloadInfo, descriptorName, 1);
1550 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1551
1552 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1553 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1554
1555 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1556 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001557
1558 std::vector<DataType> supportedTypes =
1559 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001560 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001561 DataType::Float32,
1562 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001563 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001564 DataType::QAsymmU8,
1565 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001566 };
1567
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001568 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1569 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001570}
1571
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001572void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1573{
1574 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1575
1576 ValidateNumInputs(workloadInfo, descriptorName, 1);
1577 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1578
1579 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1580 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1581
1582 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1583 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1584
1585 std::vector<DataType> supportedTypes =
1586 {
1587 DataType::BFloat16,
1588 DataType::Float32,
1589 DataType::Float16,
1590 DataType::QAsymmS8,
1591 DataType::QAsymmU8,
1592 DataType::QSymmS16
1593 };
1594
1595 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1596 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1597}
1598
Teresa Charlin970f43b2019-07-01 13:51:07 +01001599void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1600{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001603 ValidateNumInputs(workloadInfo, descriptorName, 1);
1604 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1605
1606 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1607 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1608
1609 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1610 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001611
1612 std::vector<DataType> supportedTypes =
1613 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001614 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001615 DataType::Float16,
1616 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001617 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001618 DataType::QAsymmU8,
1619 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001620 };
1621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001622 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1623 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001625 // Resize only changes width and height: batch and channel count must match.
1626 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1627 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001628 if (inputBatchSize != outputBatchSize)
1629 {
1630 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001631 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1632 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001633 }
1634
1635 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001636 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1637 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001638 if (inputChannelCount != outputChannelCount)
1639 {
1640 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001641 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1642 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001643 }
1644}
1645
1646void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1647{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001648 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001650 ValidateNumInputs(workloadInfo, descriptorName, 1);
1651 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1652
1653 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1654 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1655
1656 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1657 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1658
1659 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1660
telsoa014fcda012018-03-09 14:13:49 +00001661 if (m_Parameters.m_Min > m_Parameters.m_Max)
1662 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001663 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001664 }
telsoa014fcda012018-03-09 14:13:49 +00001665}
1666
Kevin Mayce5045a2019-10-02 14:07:47 +01001667void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1668{
1669 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1670
1671 ValidateNumInputs(workloadInfo, descriptorName, 1);
1672 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1673
1674 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1675 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1676
1677 if (inputTensorInfo.GetNumDimensions() > 4)
1678 {
1679 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1680 }
1681
1682 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1683
1684 // Check the supported data types
1685 std::vector<DataType> supportedTypes =
1686 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001687 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001688 DataType::Float32,
1689 DataType::Float16
1690 };
1691
1692 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001693 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001694}
1695
telsoa014fcda012018-03-09 14:13:49 +00001696void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1697{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001698 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001699
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001701 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1702
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001703 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1704 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1705
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001706 if (inputTensorInfo.GetNumDimensions() > 4)
1707 {
1708 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1709 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001710
1711 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001712
1713 // Check the supported data types
1714 std::vector<DataType> supportedTypes =
1715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001716 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001717 DataType::Float32,
1718 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001719 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001720 DataType::QAsymmU8,
1721 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001722 };
1723
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001724 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001725 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1726}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001727
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001728void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1729{
1730 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1731
1732 ValidateNumInputs(workloadInfo, descriptorName, 1);
1733 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1734
1735 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1736 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1737
1738 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1739
1740 std::vector<DataType> supportedTypes =
1741 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001742 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001743 DataType::Float32,
1744 DataType::Float16,
1745 };
1746
1747 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001748 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001749}
1750
1751void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1752{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001753 const std::string descriptorName{"ConstantQueueDescriptor"};
1754
1755 ValidateNumInputs(workloadInfo, descriptorName, 0);
1756 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001757
1758 if (!m_LayerOutput)
1759 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001760 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001761 }
1762
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001763 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1764 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001765
1766 // Check the supported data types
1767 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001768 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001769 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001770 DataType::Float32,
1771 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001772 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001773 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001774 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001775 DataType::QSymmS16,
1776 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001777 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001778
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001780}
1781
1782void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1783{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001785
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001786 ValidateNumInputs(workloadInfo, descriptorName, 1);
1787 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1788
1789 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1790 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1791
1792 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001793
1794 // Check the supported data types
1795 std::vector<DataType> supportedTypes =
1796 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001797 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001798 DataType::Float32,
1799 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001800 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001801 DataType::QAsymmU8,
1802 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001803 DataType::Signed32,
1804 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001805 };
1806
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001807 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1808 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001809}
1810
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001811void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1812{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001813 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001814
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 ValidateNumInputs(workloadInfo, descriptorName, 1);
1816 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1817
1818 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1819 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1820
1821 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1822 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001823
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001824 if (m_Parameters.m_BlockShape.size() != 2)
1825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001826 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001827 }
1828
1829 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1830 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001831 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1832 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001833 }
1834
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001835 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001836
1837 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001838 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001839
Matthew Bentham8800c002018-11-19 13:19:28 +00001840 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001841
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001842 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1843 widthPad.first + widthPad.second;
1844 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1845 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001846
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001847 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1848 inputShape[dimensionIndices.GetChannelsIndex()];
1849 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001850
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001851 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001852 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001853 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001854 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001855 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001856 }
1857
1858 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001859 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1861 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001862 }
nikraj01120522a2019-05-31 11:33:07 +01001863
1864 std::vector<DataType> supportedTypes =
1865 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001866 DataType::BFloat16,
1867 DataType::Float16,
1868 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001869 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001870 DataType::QAsymmU8,
1871 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001872 };
1873
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001874 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1875 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001876}
1877
Keith Davisa57eccb2019-06-14 17:33:22 +01001878void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1879{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001880 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001882 ValidateNumInputs(workloadInfo, descriptorName, 1);
1883 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001884
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001885 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1886 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1887
1888 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1889 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001890
1891 std::vector<DataType> supportedTypes =
1892 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001893 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001894 DataType::Float32,
1895 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001896 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001897 DataType::QAsymmU8,
1898 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001899 };
1900
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001901 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1902 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001903
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001904 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1905
1906 if (m_Parameters.m_BlockSize == 0)
1907 {
1908 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1909 }
1910
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001911 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1912 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1913 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1914 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001915
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001916 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001917 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001918 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001919 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1920 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001921 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001922
1923 const TensorShape& outputShape = outputTensorInfo.GetShape();
1924 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1925 {
1926 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1927 "must be divisible by the square of block size." );
1928 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001929}
1930
telsoa014fcda012018-03-09 14:13:49 +00001931void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1932{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001933 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001934
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001935 ValidateNumInputs(workloadInfo, descriptorName, 1);
1936 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1937
1938 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1939 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001940
1941 std::vector<DataType> supportedTypes =
1942 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001943 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001944 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001945 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001946 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001947 };
1948
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001949 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001950 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1951 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1952 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001953}
1954
telsoa01c577f2c2018-08-31 09:22:23 +01001955void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1956{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001957 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1958
1959 const std::string descriptorName{"LstmQueueDescriptor"};
1960
1961 // check dimensions of all inputs and outputs
1962 if (workloadInfo.m_InputTensorInfos.size() != 3)
1963 {
1964 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1965 }
1966 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1967 {
1968 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1969 }
1970
1971 std::vector<DataType> supportedTypes =
1972 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001973 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001974 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001975 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001976 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001977 };
1978
Jan Eilers38e05bd2019-06-26 13:10:09 +01001979 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001980 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1981
Jan Eilers38e05bd2019-06-26 13:10:09 +01001982 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001983 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001984 {
1985 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1986 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001987 descriptorName,
1988 "input_0",
1989 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001990 }
1991 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001992 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001993 {
1994 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1995 workloadInfo.m_OutputTensorInfos[i],
1996 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001997 "input_0",
1998 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001999 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002000
janeil0117d8d852019-11-15 15:00:16 +00002001 // Making sure clipping parameters have valid values.
2002 // == 0 means no clipping
2003 // > 0 means clipping
2004 if (m_Parameters.m_ClippingThresCell < 0.0f)
2005 {
2006 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2007 }
2008 if (m_Parameters.m_ClippingThresProj < 0.0f)
2009 {
2010 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2011 }
2012
Jan Eilers38e05bd2019-06-26 13:10:09 +01002013 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002014 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2015 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2016 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2017 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2018 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2019 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2020
Jan Eilers38e05bd2019-06-26 13:10:09 +01002021 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002022 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2023 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002024 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002025 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2026 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002027 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002028 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2029 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002030 // scratchBufferTensor
2031 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002032 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2033 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002034 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002035 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2036 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002037 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002038 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2039 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002040 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002041 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2042 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002043
Jan Eilers38e05bd2019-06-26 13:10:09 +01002044 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2045 if ( m_InputToInputWeights )
2046 {
2047 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2048 (n_cell * n_input), "InputLayerNormWeights");
2049 }
2050
2051 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2052 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2053 (n_cell * n_input), "InputToForgetWeights");
2054
2055 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2056 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2057 (n_cell * n_input), "InputToCellWeights");
2058
2059 if ( m_RecurrentToInputWeights )
2060 {
2061 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2062 (n_cell * n_output), "RecurrentToInputWeights");
2063 }
2064
2065 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2066 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2067 (n_cell * n_output), "RecurrentToForgetWeights");
2068
2069 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2070 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2071 (n_cell * n_output), "RecurrentToCellWeights");
2072
2073 // Make sure the input-gate's parameters are either both present (regular
2074 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2075 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2076 !m_Parameters.m_CifgEnabled) ||
2077 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2078 m_Parameters.m_CifgEnabled));
2079 if (!cifg_weights_all_or_none)
2080 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002081 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2082 "RecurrentToInputWeights must either both be present (regular LSTM) "
2083 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2084 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002085 }
2086
2087 if ( m_CellToInputWeights )
2088 {
2089 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2090 n_cell, "CellToInputWeights");
2091 }
2092 if ( m_CellToForgetWeights )
2093 {
2094 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2095 n_cell, "CellToForgetWeights");
2096 }
2097 if ( m_CellToOutputWeights )
2098 {
2099 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2100 n_cell, "CellToOutputWeights");
2101 }
2102
2103 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2104 bool peephole_weights_all_or_none =
2105 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2106 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2107 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2108 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2109 if (!peephole_weights_all_or_none)
2110 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002111 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002112 }
2113
2114 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2115 if (m_Parameters.m_CifgEnabled)
2116 {
2117 if (m_InputGateBias)
2118 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002119 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002120 }
2121 }
2122 else
2123 {
2124 if (!m_InputGateBias)
2125 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002126 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2127 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002128 }
2129 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2130 n_cell, "InputGateBias");
2131 }
2132
2133 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2134 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2135
2136 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2137 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2138
2139 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2140 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2141
2142 if (m_ProjectionWeights)
2143 {
2144 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2145 (n_cell * n_output), "ProjectionWeights");
2146 }
2147 if (m_ProjectionBias)
2148 {
2149 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2150 }
2151
2152 // Making sure the projection tensors are consistent:
2153 // 1) If projection weight is not present, then projection bias should not be
2154 // present.
2155 // 2) If projection weight is present, then projection bias is optional.
2156 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2157 !m_Parameters.m_ProjectionEnabled)
2158 || (m_ProjectionWeights && !m_ProjectionBias &&
2159 m_Parameters.m_ProjectionEnabled)
2160 || (m_ProjectionWeights && m_ProjectionBias &&
2161 m_Parameters.m_ProjectionEnabled));
2162 if (!projecton_tensors_consistent)
2163 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002164 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002165 }
2166
2167 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2168 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2169 // either all have values or none of them have values. Layer normalization is used when the values of all the
2170 // layer normalization weights are present
2171 if (m_InputLayerNormWeights)
2172 {
2173 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2174 }
2175 if (m_ForgetLayerNormWeights)
2176 {
2177 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2178 }
2179 if (m_CellLayerNormWeights)
2180 {
2181 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2182 }
2183 if (m_OutputLayerNormWeights)
2184 {
2185 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2186 }
2187
Jan Eilers38e05bd2019-06-26 13:10:09 +01002188 if (m_Parameters.m_LayerNormEnabled)
2189 {
2190 if (!m_Parameters.m_CifgEnabled)
2191 {
2192 if (!m_InputLayerNormWeights)
2193 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002194 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2195 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002196 }
2197 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2198 1, n_cell, "InputLayerNormWeights");
2199 }
2200 else if (m_InputLayerNormWeights)
2201 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002202 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2203 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002204 }
2205
2206 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2207 "ForgetLayerNormWeights");
2208 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2209
2210 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2211 "OutputLayerNormWeights");
2212 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2213
2214 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2215 "CellLayerNormWeights");
2216 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2217 }
2218 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2219 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002220 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2221 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002222 }
telsoa01c577f2c2018-08-31 09:22:23 +01002223}
2224
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002225void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2226{
2227 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2228
2229 ValidateNumInputs(workloadInfo, descriptorName, 1);
2230 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2231
2232 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2233 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2234
2235 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2236 {
2237 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2238 }
2239
2240 if (outputTensorInfo.GetDataType() != DataType::Float32)
2241 {
2242 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2243 }
2244
2245 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2246}
2247
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002248void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2249{
2250 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2251
2252 ValidateNumInputs(workloadInfo, descriptorName, 1);
2253 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2254
2255 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2256 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2257
2258 if (inputTensorInfo.GetDataType() != DataType::Float32)
2259 {
2260 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2261 }
2262
2263 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2264 {
2265 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2266 }
2267
2268 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2269}
2270
telsoa01c577f2c2018-08-31 09:22:23 +01002271void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2272{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002273 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002274
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002275 ValidateNumInputs(workloadInfo, descriptorName, 1);
2276 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2277
2278 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2279 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2280
2281 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002282 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002283 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002284 }
2285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002286 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002287 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002288 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002289 }
2290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002291 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002292}
2293
2294void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2295{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002296 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002297
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 ValidateNumInputs(workloadInfo, descriptorName, 1);
2299 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2300
2301 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2302 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2303
2304 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002305 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002307 }
2308
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002309 if (outputTensorInfo.GetDataType() != DataType::Float32)
2310 {
2311 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2312 }
2313
2314 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002315}
2316
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002317void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2318{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002319 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002321 ValidateNumInputs(workloadInfo, descriptorName, 2);
2322 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2323
2324 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2325 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2326 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2327
2328 std::vector<DataType> supportedTypes =
2329 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002330 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002331 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002332 DataType::Float32,
2333 DataType::QAsymmS8,
2334 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002335 DataType::QSymmS16,
2336 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002337 };
2338
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002339 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2340 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2341 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002343 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2344 inputTensorInfo1,
2345 outputTensorInfo,
2346 descriptorName,
2347 "input_0",
2348 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002349}
2350
David Beckc2044fe2018-09-05 15:00:38 +01002351void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2352{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002353 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002355 ValidateNumInputs(workloadInfo, descriptorName, 2);
2356 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2357
2358 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2359 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2360 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2361
2362 std::vector<DataType> supportedTypes =
2363 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002364 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002365 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002366 DataType::Float32,
2367 DataType::QAsymmS8,
2368 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002369 DataType::QSymmS16,
2370 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002371 };
2372
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002373 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2374 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2375 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2378 inputTensorInfo1,
2379 outputTensorInfo,
2380 descriptorName,
2381 "input_0",
2382 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002383}
2384
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002385void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2386{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002387 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002388
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002389 ValidateNumInputs(workloadInfo, descriptorName, 2);
2390 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2391
2392 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2393 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2394 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2395
2396 std::vector<DataType> supportedTypes =
2397 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002398 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002399 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002400 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002401 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002402 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002403 DataType::QSymmS16,
2404 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002405 };
2406
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002407 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2408 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2409 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002410
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002411 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2412 inputTensorInfo1,
2413 outputTensorInfo,
2414 descriptorName,
2415 "input_0",
2416 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002417}
2418
narpra01a6bf9122018-09-10 09:50:09 +01002419void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2420{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002421 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002422
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002423 ValidateNumInputs(workloadInfo, descriptorName, 1);
2424 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2425
2426 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2427 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002428
2429 std::vector<DataType> supportedTypes =
2430 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002431 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002432 DataType::Float32,
2433 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002434 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002435 DataType::QAsymmU8,
2436 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002437 };
narpra01eb061912018-09-10 17:35:27 +01002438
James Conroy4d1ff582019-06-10 17:06:39 +01002439 // First check if input tensor data type is supported, then
2440 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002441 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2442 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002443
narpra0132b90462018-09-13 11:07:48 +01002444 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002445 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002446 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002447 }
narpra0132b90462018-09-13 11:07:48 +01002448 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002449 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002450 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002451 }
2452 else
2453 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002454 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002455 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002456 ValidateTensorNumDimensions(outputTensorInfo,
2457 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002458 outputDim > 0 ? outputDim : 1,
2459 "output");
2460 }
narpra01a6bf9122018-09-10 09:50:09 +01002461}
2462
jimfly012c9322a2018-09-19 10:59:49 +01002463void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2464{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002465 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002466
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 ValidateNumInputs(workloadInfo, descriptorName, 1);
2468 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2469
2470 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2471 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002472
jimfly012c9322a2018-09-19 10:59:49 +01002473 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002474 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2475
jimfly012c9322a2018-09-19 10:59:49 +01002476 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002477 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2478 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2479 "as there are dimensions in the input tensor that is " +
2480 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2481 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002482 }
2483}
2484
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002485void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2486{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002487 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002488
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002489 ValidateNumInputs(workloadInfo, descriptorName, 1);
2490 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002491
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002492 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2493 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2494
Sadik Armagan2208b602019-07-31 16:36:27 +01002495 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002496 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002497 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002498 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002499 DataType::Float16,
2500 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002501 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002502 DataType::QAsymmU8,
2503 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002504 };
2505
2506 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002507
Keith Davis0c2eeac2020-02-11 16:51:50 +00002508 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002509 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002510 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002511 }
2512}
2513
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002514void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2515{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002517
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002518 ValidateNumInputs(workloadInfo, descriptorName, 1);
2519 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002520
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002521 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2522 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002523
2524 std::vector<DataType> supportedTypes =
2525 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002526 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002527 DataType::Float32,
2528 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002529 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002530 DataType::QAsymmU8,
2531 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002532 };
2533
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002534 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2535 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002536}
2537
Conor Kennedy430b5d82018-11-14 15:28:28 +00002538void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2539{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002541
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002542 ValidateNumInputs(workloadInfo, descriptorName, 1);
2543 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2544
2545 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2546 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002547
2548 std::vector<DataType> supportedTypes =
2549 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002550 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002551 DataType::Float16,
2552 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002553 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002554 DataType::QAsymmU8,
2555 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002556 };
2557
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002558 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2559 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002560
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002561 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002562
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002563 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002564 if (rank > 4)
2565 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002566 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002567 }
2568
Conor Kennedy430b5d82018-11-14 15:28:28 +00002569 // Begin, End & Stride length must be of rank(input0)
2570 if (m_Parameters.m_Begin.size() != rank)
2571 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002573 }
2574
2575 if (m_Parameters.m_End.size() != rank)
2576 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002577 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002578 }
2579
2580 if (m_Parameters.m_Stride.size() != rank)
2581 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002582 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002583 }
2584
2585 // Stride entries must be non-zero
2586 for (auto& stride : m_Parameters.m_Stride)
2587 {
2588 if (stride == 0)
2589 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002590 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002591 }
2592 }
2593}
2594
kevmay0190539692018-11-29 08:40:19 +00002595void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2596{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002597 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002599 ValidateNumInputs(workloadInfo, descriptorName, 2);
2600 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2601
2602 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2603 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2604 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2605
2606 std::vector<DataType> supportedTypes =
2607 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002608 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002609 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002610 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002611 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002612 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002613 DataType::QSymmS16,
2614 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002615 };
2616
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002617 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2618 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2619 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002620
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002621 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2622 inputTensorInfo1,
2623 outputTensorInfo,
2624 descriptorName,
2625 "input_0",
2626 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002627}
2628
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002629void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2630{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002631 const std::string descriptorName{"DebugQueueDescriptor"};
2632
2633 ValidateNumInputs(workloadInfo, descriptorName, 1);
2634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002635}
2636
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002637void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2638{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002639 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002640
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002641 ValidateNumInputs(workloadInfo, descriptorName, 2);
2642 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002643
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002644 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2645 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2646 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2647
2648 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2649 inputTensorInfo1,
2650 outputTensorInfo,
2651 descriptorName,
2652 "input_0",
2653 "input_1");
2654
2655 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002656 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002657 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002658 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002659}
2660
FrancisMurtagh878f0232018-12-19 10:56:15 +00002661void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2662{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002663 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002664
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002665 ValidateNumInputs(workloadInfo, descriptorName, 2);
2666 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002668 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2669 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2670 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2671
2672 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2673 inputTensorInfo1,
2674 outputTensorInfo,
2675 descriptorName,
2676 "input_0",
2677 "input_1");
2678
2679 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002680 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002681 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002682 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002683}
2684
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002685void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2686{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002687 const std::string descriptorName{"RsqrtQueueDescriptor"};
2688
2689 ValidateNumInputs(workloadInfo, descriptorName, 1);
2690 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2691
2692 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2693 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2694
2695 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002696
2697 std::vector<DataType> supportedTypes =
2698 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002699 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002700 DataType::Float16,
2701 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002702 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002703 DataType::QAsymmU8,
2704 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002705 };
2706
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002707 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2708 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002709}
2710
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002711void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2712{
2713 const std::string descriptorName{"GatherNdQueueDescriptor"};
2714
2715 ValidateNumInputs(workloadInfo, descriptorName, 2);
2716 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2717
2718 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2719 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2720 {
2721 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2722 }
2723
2724 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2725 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2726
2727 std::vector<DataType> supportedTypes =
2728 {
2729 DataType::BFloat16,
2730 DataType::Float16,
2731 DataType::Float32,
2732 DataType::QAsymmS8,
2733 DataType::QAsymmU8,
2734 DataType::QSymmS16,
2735 DataType::Signed32,
2736 };
2737
2738 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2739
2740 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2741
2742 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2743 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2744}
2745
narpra01b89b05f2019-01-16 09:53:09 +00002746void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2747{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002748 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002750 ValidateNumInputs(workloadInfo, descriptorName, 2);
2751 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002752
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002753 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2754 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002755 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002756 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002757 }
2758
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002759 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2760 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2761
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002762 std::vector<DataType> supportedTypes =
2763 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002764 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002765 DataType::Float16,
2766 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002767 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002768 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002769 DataType::QSymmS16,
2770 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002771 };
2772
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002773 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002774
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002775 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002776
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002777 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2778 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002779}
2780
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002781void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2782{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002783 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2784
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002785 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002786
2787 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2788 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002789 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002790 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2791 }
2792
2793 if (m_Anchors == nullptr)
2794 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002795 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002796 }
2797
2798 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002799 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2800 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2801
2802 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002803 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002804 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2805 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002806
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002807 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2808 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2809 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002810
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002811 const std::vector<DataType> supportedInputTypes =
2812 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002813 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002814 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002815 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002816 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002817 DataType::QAsymmU8,
2818 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002819 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002820
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002821 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2822 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2823 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2824
2825 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2826 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2827 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2828 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2829
2830 // NOTE: Output is always Float32 regardless of input type
2831 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2832 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2833 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2834 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002835
2836 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2837 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002838 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002839 "must be positive and less than or equal to 1.");
2840 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002841
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002842 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2843 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002844 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002845 "should be equal to number of classes + 1.");
2846 }
2847}
2848
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002849void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2850{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002851 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002852
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002853 ValidateNumInputs(workloadInfo, descriptorName, 1);
2854 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2855
2856 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2857 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2858
Teresa Charlin07307f32022-05-15 14:07:05 +01002859 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002860 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002861 DataType::QAsymmS8,
2862 DataType::QAsymmU8,
2863 DataType::QSymmS8,
2864 DataType::QSymmS16,
2865 DataType::Float16
2866 };
2867 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002868
Teresa Charlin07307f32022-05-15 14:07:05 +01002869 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002870 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002871 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002872 DataType::Float32,
2873 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002874 };
2875
Teresa Charlin07307f32022-05-15 14:07:05 +01002876 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002877}
2878
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002879void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2880{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002881 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002882
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002883 ValidateNumInputs(workloadInfo, descriptorName, 2);
2884 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002885
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002886 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2887 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2888 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002889
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002890 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2891 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2892
2893 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2894 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002895}
2896
Keith Davis3ae3f972021-05-21 16:33:48 +01002897void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2898{
2899 const std::string& descriptorName{"ShapeQueueDescriptor"};
2900
2901 ValidateNumInputs(workloadInfo, descriptorName, 1);
2902 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2903
2904 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2905 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2906
2907 std::vector<DataType> supportedTypes =
2908 {
2909 DataType::BFloat16,
2910 DataType::Float16,
2911 DataType::Float32,
2912 DataType::QAsymmS8,
2913 DataType::QAsymmU8,
2914 DataType::QAsymmS8,
2915 DataType::QSymmS8,
2916 DataType::QSymmS16,
2917 DataType::Signed32
2918 };
2919
2920 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2921 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2922}
2923
Sadik Armaganeff363d2019-04-05 15:25:46 +01002924void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2925{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002926 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002927
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002928 ValidateNumInputs(workloadInfo, descriptorName, 2);
2929 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2930
2931 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2932 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2933
2934 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2935 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2936
2937 std::vector<DataType> supportedTypes =
2938 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002939 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002940 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002941 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002942 DataType::QAsymmU8,
2943 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002944 };
2945
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002946 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2947 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002948
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002949 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2950 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002952 ValidateTensorShapesMatch(inputTensorInfo0,
2953 outputTensorInfo0,
2954 descriptorName,
2955 "input_0",
2956 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002957
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002958 ValidateTensorShapesMatch(inputTensorInfo0,
2959 outputTensorInfo1,
2960 descriptorName,
2961 "input_0",
2962 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002963}
2964
Derek Lamberti901ea112019-12-10 22:07:09 +00002965void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002966{
2967 // This is internally generated so it should not need validation.
2968}
2969
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002970void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2971{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002972 const std::string& descriptorName{"PreluQueueDescriptor"};
2973
2974 ValidateNumInputs(workloadInfo, descriptorName, 2);
2975 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2976
2977 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2978 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2979 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002980
2981 std::vector<DataType> supportedTypes
2982 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002983 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002984 DataType::Float16,
2985 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002986 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002987 DataType::QAsymmU8,
2988 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002989 };
2990
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002991 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2992 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002994 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002996 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2997 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002998
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002999 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3000 alphaTensorInfo,
3001 outputTensorInfo,
3002 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003003 "input",
3004 "alpha");
3005}
3006
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003007void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3008{
3009 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3010
3011 ValidateNumInputs(workloadInfo, descriptorName, 1);
3012 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3013
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003014 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3015 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3016
3017 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3018 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003019
3020 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003021
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003022 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3023 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003024
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003025 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3026
3027 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003028 if (m_Parameters.m_BiasEnabled)
3029 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003030 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003031
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003032 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3033 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003034
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003035 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003036 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003037 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003038
3039 ValidatePerAxisQuantization(inputTensorInfo,
3040 outputTensorInfo,
3041 weightTensorInfo,
3042 optionalBiasTensorInfo,
3043 descriptorName);
3044
3045 std::vector<DataType> supportedTypes =
3046 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003047 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003048 DataType::Float32,
3049 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003050 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003051 DataType::QAsymmU8,
3052 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003053 };
3054
3055 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3056 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003057}
3058
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003059void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3060{
3061 const std::string descriptorName{"TransposeQueueDescriptor"};
3062
3063 ValidateNumInputs(workloadInfo, descriptorName, 1);
3064 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3065
3066 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3067
3068 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3069 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3070
3071 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3072 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3073
3074 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3075 {
3076 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3077 {
3078 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3079 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3080 "must match dst dimension " + to_string(i) +
3081 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3082 }
3083 }
3084
3085 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3086}
3087
Simon Obute51f67772021-09-03 15:50:13 +01003088void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3089{
3090 const std::string descriptorName{"TransposeQueueDescriptor"};
3091
3092 ValidateNumInputs(workloadInfo, descriptorName, 1);
3093 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3094
3095 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3096 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3097
3098 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3099}
3100
James Conroy4f1f8992020-04-29 20:01:10 +01003101void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3102{
3103 const std::string descriptorName{"QLstmQueueDescriptor"};
3104
3105 // Validate number of inputs/outputs
3106 ValidateNumInputs(workloadInfo, descriptorName, 3);
3107 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3108
3109 // Input/output tensor info
3110 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3111 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3112 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3113
3114 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3115 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3116 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3117
3118 // Supported types for various tensors in QLSTM
3119 std::vector<DataType> inputOutputSupportedTypes =
3120 {
3121 DataType::QAsymmS8
3122 };
3123
3124 std::vector<DataType> cellStateSupportedTypes =
3125 {
3126 DataType::QSymmS16
3127 };
3128
3129 std::vector<DataType> weightsSupportedTypes =
3130 {
3131 DataType::QSymmS8
3132 };
3133
3134 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3135 {
3136 DataType::QSymmS16
3137 };
3138
3139 std::vector<DataType> biasSupportedTypes =
3140 {
3141 DataType::Signed32
3142 };
3143
3144 // Validate types of input/output tensors
3145 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3146 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3147 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3148
3149 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3150 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3151 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3152
3153 // Validate matching types of input/output tensors
3154 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3155 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3156 "outputStateIn", "outputStateOut");
3157 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3158
3159 // Infer number of batches, number of units, input size and output size from tensor dimensions
3160 const uint32_t numBatches = inputInfo.GetShape()[0];
3161 const uint32_t inputSize = inputInfo.GetShape()[1];
3162 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3163 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3164
3165 // Validate number of dimensions and number of elements for input/output tensors
3166 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3167 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3168 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3169
3170 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3171 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3172 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3173
3174 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3175 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3176 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3177 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3178
3179 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3180 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3181 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3182
3183 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3184 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3185 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3186
3187 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3188 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3189 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3190 " RecurrentToForgetWeights");
3191
3192 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3193 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3194 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3195
3196 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3197 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3198 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3199
3200 // Validate data types for MANDATORY weights tensors (all should match each other)
3201 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3202
3203 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3204 "inputToForgetWeights", "inputToCellWeights");
3205 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3206 "inputToForgetWeights", "inputToOutputWeights");
3207
3208 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3209 "inputToForgetWeights", "recurrentToForgeteights");
3210 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3211 "inputToForgetWeights", "recurrentToCellWeights");
3212 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3213 "inputToForgetWeights", "recurrentToOutputWeights");
3214
3215 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3216 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3217 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3218 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3219
3220 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3221 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3222 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3223
3224 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3225 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3226 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3227
3228 // Validate data types for MANDATORY bias tensors
3229 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3230
3231 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3232 "forgetGateBias", "cellBias");
3233 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3234 "forgetGateBias", "outputGateBias");
3235
3236 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3237 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3238 !m_Parameters.m_CifgEnabled) ||
3239 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3240 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3241
3242 if (!allCifgParamsPresentOrNot)
3243 {
3244 throw InvalidArgumentException(descriptorName +
3245 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3246 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3247 "set appropriately.");
3248 }
3249
3250 if (!m_Parameters.m_CifgEnabled)
3251 {
3252 // Validate number of dimensions and number of elements
3253 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3254 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3255
3256 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3257 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3258 " RecurrentToInputWeights");
3259
3260 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3261 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3262
3263 // Validate data types
3264 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3265 "inputToForgetWeights", "inputToInputWeights");
3266 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3267 "inputToForgetWeights", "recurrentToInputWeights");
3268 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3269 "forgetGateBias", "inputGateBias");
3270 }
3271
3272 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3273 bool allPeepholeWeightsPresentOrNot =
3274 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3275 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3276 || (!m_CellToInputWeights && !m_CellToForgetWeights
3277 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3278
3279 if (!allPeepholeWeightsPresentOrNot)
3280 {
3281 throw InvalidArgumentException(descriptorName +
3282 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3283 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3284 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3285 "appropriately.");
3286 }
3287
3288 if (m_Parameters.m_PeepholeEnabled)
3289 {
3290 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3291 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3292 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3293
3294 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3295 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3296 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3297 "cellToForgetWeight", "cellToOutputWeights");
3298
3299 if (!m_Parameters.m_CifgEnabled)
3300 {
3301 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3302 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3303 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3304 "cellToForgetWeights", "cellToInputWeights");
3305 }
3306 }
3307
3308 // Validate OPTIONAL params: Layer Norm Weights
3309 bool allLayerNormWeightsPresentOrNot =
3310 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3311 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3312 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3313 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3314
3315 if (!allLayerNormWeightsPresentOrNot)
3316 {
3317 throw InvalidArgumentException(descriptorName +
3318 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3319 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3320 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3321 "only be present when Layer Norm is enabled and CIFG is disabled. "
3322 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3323 }
3324
3325 if (m_Parameters.m_LayerNormEnabled)
3326 {
3327 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3328 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3329 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3330
3331 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3332 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3333 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3334 "forgetLayerNormWeights", "cellLayerNormWeights");
3335
3336 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3337 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3338 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3339 "forgetLayerNormWeights", "outputLayerNormWeights");
3340
3341 if (!m_Parameters.m_CifgEnabled)
3342 {
3343 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3344 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3345 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3346 "forgetLayerNormWeights", "inputLayerNormWeights");
3347 }
3348 }
3349
3350 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3351 bool correctProjectionTensorsPresent =
3352 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3353 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3354 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3355
3356 if (!correctProjectionTensorsPresent)
3357 {
3358 throw InvalidArgumentException(descriptorName +
3359 ": If projection is enabled, ProjectionWeights should be present and "
3360 "ProjectionBias is optional. If projection is disabled, neither "
3361 "ProjectionWeights nor ProjectionBias should be present.");
3362 }
3363
3364 if (m_Parameters.m_ProjectionEnabled)
3365 {
3366 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3367 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3368 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3369
3370 if (m_ProjectionBias)
3371 {
3372 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003373 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003374 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3375 }
3376
3377 }
3378 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3379 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3380 throw InvalidArgumentException(descriptorName +
3381 ": If projection is disabled, output quantization info (scale, offset) "
3382 "should match HiddenStateScale and HiddenStateZeroPoint.");
3383 }
3384
3385}
3386
James Conroy9c3cae82019-08-01 16:01:48 +01003387void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3388{
3389 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3390
3391 // Validate number of inputs/outputs
3392 ValidateNumInputs(workloadInfo, descriptorName, 3);
3393 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3394
3395 // Input/output tensor infos
3396 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3397 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3398 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3399
3400 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3401 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3402
3403 std::vector<DataType> inputOutputSupportedTypes =
3404 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003405 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003406 };
3407
3408 std::vector<DataType> cellStateSupportedTypes =
3409 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003410 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003411 };
3412
3413 std::vector<DataType> weightsSupportedTypes =
3414 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003415 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003416 };
3417
3418 std::vector<DataType> biasSupportedTypes =
3419 {
3420 DataType::Signed32
3421 };
3422
3423 // Validate types of input/output tensors
3424 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3425 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3426 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3427
3428 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3429 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3430
3431 // Validate matching types of input/output tensors
3432 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3433 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3434 "outputStateIn", "outputStateOut");
3435 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3436
3437 // Validate matching quantization info for input/output tensors
3438 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3439 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3440 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003441
James Conroy9c3cae82019-08-01 16:01:48 +01003442 // Infer number of batches, input size and output size from tensor dimensions
3443 const uint32_t numBatches = inputInfo.GetShape()[0];
3444 const uint32_t inputSize = inputInfo.GetShape()[1];
3445 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3446
3447 // Validate number of dimensions and number of elements for input/output tensors
3448 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3449 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3450 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3451 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3452 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3453
3454 // Validate number of dimensions and number of elements for weights tensors
3455 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3456 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3457 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3458
3459 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3460 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3461 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3462
3463 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3464 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3465 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3466
3467 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3468 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3469 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3470
3471 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3472 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3473 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3474
3475 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3476 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3477 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3478 " RecurrentToForgetWeights");
3479
3480 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3481 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3482 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3483
3484 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3485 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3486 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3487
3488 // Validate data types for weights tensors (all should match each other)
3489 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3490
3491 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3492 "inputToInputWeights", "inputToForgetWeights");
3493 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3494 "inputToInputWeights", "inputToCellWeights");
3495 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3496 "inputToInputWeights", "inputToOutputWeights");
3497
3498 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3499 "inputToInputWeights", "recurrentToInputWeights");
3500 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3501 "inputToInputWeights", "recurrentToForgeteights");
3502 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3503 "inputToInputWeights", "recurrentToCellWeights");
3504 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3505 "inputToInputWeights", "recurrentToOutputWeights");
3506
3507 // Validate matching quantization info for weight tensors (all should match each other)
3508 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3509 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3510 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3511 descriptorName, "inputToInputWeights", "inputToCellWeights");
3512 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3513 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3514
3515 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3516 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3517 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3518 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3519 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3520 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3521 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3522 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3523
3524 // Validate number of dimensions and number of elements in bias tensors
3525 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3526 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3527 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3528
3529 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3530 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3531 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3532
3533 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3534 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3535 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3536
3537 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3538 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3539 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3540
3541 // Validate data types for bias tensors (all should match each other)
3542 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3543
3544 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3545 "inputGateBias", "forgetGateBias");
3546 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3547 "inputGateBias", "cellBias");
3548 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3549 "inputGateBias", "outputGateBias");
3550
3551 // Validate bias tensor quantization info
3552 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3553 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3554 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3555 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3556}
3557
Kevin May868eb142019-09-04 17:29:31 +01003558void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3559{
3560 const std::string descriptorName{"AbsQueueDescriptor"};
3561
3562 ValidateNumInputs(workloadInfo, descriptorName, 1);
3563 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3564
3565 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3566 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3567
3568 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3569
3570 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003571 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003572 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003573 DataType::Float16,
3574 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003575 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003576 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003577 DataType::QSymmS16,
3578 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003579 };
Kevin May868eb142019-09-04 17:29:31 +01003580
3581 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3582 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3583}
3584
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003585void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3586{
3587 const std::string descriptorName{"SliceQueueDescriptor"};
3588
3589 ValidateNumInputs(workloadInfo, descriptorName, 1);
3590 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3591
3592 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3593 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3594
3595 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3596
3597 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3598 if (rank > 4)
3599 {
3600 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3601 }
3602
3603 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3604
3605 // Check if m_Begin and m_Size have the expected length
3606 if (m_Parameters.m_Begin.size() != rank)
3607 {
3608 throw InvalidArgumentException(descriptorName +
3609 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3610 }
3611 if (m_Parameters.m_Size.size() != rank)
3612 {
3613 throw InvalidArgumentException(descriptorName +
3614 ": Length of size descriptor must equal rank " + std::to_string(rank));
3615 }
3616
3617 // Check if the shape of the output tensor matches m_Size
3618 const TensorShape& outputShape = outputTensorInfo.GetShape();
3619 for (unsigned int i = 0u; i < rank; ++i)
3620 {
3621 if (m_Parameters.m_Size[i] != outputShape[i])
3622 {
3623 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3624 }
3625 }
3626
3627 // Check if the sum of begin offset and size in a given dimension
3628 // does not exceed the size of corresponding input
3629 const TensorShape& inputShape = inputTensorInfo.GetShape();
3630 for(unsigned int i = 0u; i < rank; ++i)
3631 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003632 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003633 {
3634 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3635 std::to_string(i) + " exceeds input size.");
3636 }
3637 }
3638}
3639
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003640void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3641{
3642 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3643
3644 ValidateNumInputs(workloadInfo, descriptorName, 1);
3645 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3646
3647 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3648 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3649
3650 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3651 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3652
3653 std::vector<DataType> supportedTypes =
3654 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003655 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003656 DataType::Float32,
3657 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003658 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003659 DataType::QAsymmU8,
3660 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003661 };
3662
3663 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3664 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3665
3666 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3667
3668 if (m_Parameters.m_BlockSize == 0)
3669 {
3670 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3671 }
3672
3673 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3674 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3675 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3676 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3677
3678 const TensorShape& outputShape = outputInfo.GetShape();
3679 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3680 {
3681 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3682 "must be divisible by block size.");
3683 }
3684
3685 const TensorShape& inputShape = inputInfo.GetShape();
3686 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3687 {
3688 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3689 "must be divisible by the square of block size." );
3690 }
3691}
3692
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003693void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3694{
3695 const std::string descriptorName{"ComparisonQueueDescriptor"};
3696
3697 ValidateNumInputs(workloadInfo, descriptorName, 2);
3698 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3699
3700 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3701 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3703
3704 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3705 inputTensorInfo1,
3706 outputTensorInfo,
3707 descriptorName,
3708 "input_0",
3709 "input_1");
3710
3711 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3712 {
3713 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3714 }
3715}
3716
josh minor4a3c6102020-01-06 16:40:46 -06003717void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3718{
3719 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3720
3721 ValidateNumInputs(workloadInfo, descriptorName, 1);
3722 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3723
3724 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3725 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3726
3727 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3728
3729 std::vector<DataType> supportedTypes =
3730 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003731 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003732 DataType::Float16,
3733 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003734 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003735 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003736 DataType::QSymmS16,
3737 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003738 };
3739
James Conroyaba90cd2020-11-06 16:28:18 +00003740 std::vector<DataType> logicalSupportedTypes =
3741 {
3742 DataType::Boolean
3743 };
3744
3745 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3746 {
3747 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3748 }
3749 else
3750 {
3751 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3752 }
3753
3754
josh minor4a3c6102020-01-06 16:40:46 -06003755 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3756}
3757
Finn Williams2605b232020-06-10 15:53:46 +01003758void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3759{
3760 const std::string descriptorName{"RankQueueDescriptor"};
3761
3762 ValidateNumInputs(workloadInfo, descriptorName, 1);
3763 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3764
3765 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3766 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3767
3768 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3769 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3770
3771 std::vector<DataType> supportedTypes =
3772 {
3773 DataType::BFloat16,
3774 DataType::Float16,
3775 DataType::Float32,
3776 DataType::QAsymmS8,
3777 DataType::QAsymmU8,
3778 DataType::QSymmS8,
3779 DataType::QSymmS16,
3780 DataType::Signed32
3781 };
3782
3783 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3784 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3785}
3786
James Conroyaba90cd2020-11-06 16:28:18 +00003787void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3788{
3789 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3790
3791 ValidateNumInputs(workloadInfo, descriptorName, 2);
3792 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3793
3794 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3795 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3796 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3797
3798 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3799 inputTensorInfo1,
3800 outputTensorInfo,
3801 descriptorName,
3802 "input_0",
3803 "input_1");
3804
3805 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3806 {
3807 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3808 }
3809
3810 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3811 {
3812 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3813 }
3814
3815 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3816 {
3817 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3818 }
3819}
3820
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003821void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3822{
3823 const std::string descriptorName{"ReduceQueueDescriptor"};
3824
3825 ValidateNumInputs(workloadInfo, descriptorName, 1);
3826 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3827
3828 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3829 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3830
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003831 std::vector<DataType> supportedTypes =
3832 {
3833 DataType::BFloat16,
3834 DataType::Float16,
3835 DataType::Float32,
3836 DataType::QAsymmS8,
3837 DataType::QAsymmU8,
3838 DataType::QSymmS16,
3839 DataType::Signed32
3840 };
3841
3842 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3843 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3844}
3845
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003846void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3847{
3848 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3849
3850 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3851
3852 // check dimensions of all inputs and outputs
3853 if (workloadInfo.m_InputTensorInfos.size() != 3)
3854 {
3855 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3856 }
Mike Kelly12994962022-04-21 11:57:09 +01003857 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003858 {
3859 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3860 }
3861
3862 std::vector<DataType> supportedTypes =
3863 {
Mike Kelly12994962022-04-21 11:57:09 +01003864 DataType::Float32,
3865 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003866 };
3867
3868 // check for supported type of one input and match them with all the other input and output
3869 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3870
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003871 // Making sure clipping parameters have valid values.
3872 // == 0 means no clipping
3873 // > 0 means clipping
3874 if (m_Parameters.m_ClippingThresCell < 0.0f)
3875 {
3876 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3877 }
3878 if (m_Parameters.m_ClippingThresProj < 0.0f)
3879 {
3880 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3881 }
3882
3883 unsigned int batchIndx = 0;
3884 unsigned int inputIndx = 1;
3885 uint32_t timeStep = 1;
3886 unsigned int timeIndx = 1;
3887 inputIndx = 2;
3888 if (m_Parameters.m_TimeMajor)
3889 {
3890 batchIndx = 1;
3891 timeIndx = 0;
3892
3893 }
3894 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3895
3896 // Inferring batch size, number of outputs and number of cells from the inputs.
3897 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3898 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3899 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3900 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3901 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3902 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3903
3904 // input tensor
3905 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3906 descriptorName + " input_0");
3907 // outputStateInTensor
3908 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3909 descriptorName + " input_1");
3910 // outputStateInTensor
3911 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3912 descriptorName + " input_2");
3913
3914 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003915 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003916 descriptorName + " output_0");
3917
3918 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3919 if ( m_InputToInputWeights )
3920 {
3921 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3922 (n_cell * n_input), "InputLayerNormWeights");
3923 }
3924
3925 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3926 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3927 (n_cell * n_input), "InputToForgetWeights");
3928
3929 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3930 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3931 (n_cell * n_input), "InputToCellWeights");
3932
3933 if ( m_RecurrentToInputWeights )
3934 {
3935 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3936 (n_cell * n_output), "RecurrentToInputWeights");
3937 }
3938
3939 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3940 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3941 (n_cell * n_output), "RecurrentToForgetWeights");
3942
3943 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3944 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3945 (n_cell * n_output), "RecurrentToCellWeights");
3946
3947 // Make sure the input-gate's parameters are either both present (regular
3948 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3949 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3950 !m_Parameters.m_CifgEnabled) ||
3951 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3952 m_Parameters.m_CifgEnabled));
3953 if (!cifg_weights_all_or_none)
3954 {
3955 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3956 "RecurrentToInputWeights must either both be present (regular LSTM) "
3957 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3958 "accordingly.");
3959 }
3960
3961 if ( m_CellToInputWeights )
3962 {
3963 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3964 n_cell, "CellToInputWeights");
3965 }
3966 if ( m_CellToForgetWeights )
3967 {
3968 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3969 n_cell, "CellToForgetWeights");
3970 }
3971 if ( m_CellToOutputWeights )
3972 {
3973 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3974 n_cell, "CellToOutputWeights");
3975 }
3976
3977 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3978 bool peephole_weights_all_or_none =
3979 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3980 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3981 || ( !m_CellToInputWeights && !m_CellToForgetWeights
3982 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3983 if (!peephole_weights_all_or_none)
3984 {
3985 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3986 }
3987
3988 // Make sure the input gate bias is present only when not a CIFG-LSTM.
3989 if (m_Parameters.m_CifgEnabled)
3990 {
3991 if (m_InputGateBias)
3992 {
3993 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3994 }
3995 }
3996 else
3997 {
3998 if (!m_InputGateBias)
3999 {
4000 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4001 "must be present.");
4002 }
4003 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4004 n_cell, "InputGateBias");
4005 }
4006
4007 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4008 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4009
4010 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4011 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4012
4013 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4014 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4015
4016 if (m_ProjectionWeights)
4017 {
4018 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4019 (n_cell * n_output), "ProjectionWeights");
4020 }
4021 if (m_ProjectionBias)
4022 {
4023 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4024 }
4025
4026 // Making sure the projection tensors are consistent:
4027 // 1) If projection weight is not present, then projection bias should not be
4028 // present.
4029 // 2) If projection weight is present, then projection bias is optional.
4030 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4031 !m_Parameters.m_ProjectionEnabled)
4032 || (m_ProjectionWeights && !m_ProjectionBias &&
4033 m_Parameters.m_ProjectionEnabled)
4034 || (m_ProjectionWeights && m_ProjectionBias &&
4035 m_Parameters.m_ProjectionEnabled));
4036 if (!projecton_tensors_consistent)
4037 {
4038 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4039 }
4040
4041 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4042 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4043 // either all have values or none of them have values. Layer normalization is used when the values of all the
4044 // layer normalization weights are present
4045 if (m_InputLayerNormWeights)
4046 {
4047 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4048 }
4049 if (m_ForgetLayerNormWeights)
4050 {
4051 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4052 }
4053 if (m_CellLayerNormWeights)
4054 {
4055 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4056 }
4057 if (m_OutputLayerNormWeights)
4058 {
4059 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4060 }
4061
4062 if (m_Parameters.m_LayerNormEnabled)
4063 {
4064 if (!m_Parameters.m_CifgEnabled)
4065 {
4066 if (!m_InputLayerNormWeights)
4067 {
4068 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4069 "disabled but InputLayerNormWeights are not present");
4070 }
4071 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4072 1, n_cell, "InputLayerNormWeights");
4073 }
4074 else if (m_InputLayerNormWeights)
4075 {
4076 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4077 "enabled");
4078 }
4079
4080 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4081 "ForgetLayerNormWeights");
4082 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4083
4084 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4085 "OutputLayerNormWeights");
4086 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4087
4088 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4089 "CellLayerNormWeights");
4090 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4091 }
4092 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4093 {
4094 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4095 "normalisation weights are present.");
4096 }
4097}
4098
Samuel Yap6b478092022-07-06 15:36:03 +01004099void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4100{
4101 const std::string descriptorName{"BatchMatMulDescriptor"};
4102
4103 ValidateNumInputs(workloadInfo, descriptorName, 2);
4104 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4105
4106 // Inputs must be: both 2D+
4107 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4108 // axes N and I must be the same size
4109
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004110 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4111 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4112 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4113 // Output info has already been inferred
Samuel Yap6b478092022-07-06 15:36:03 +01004114
4115 std::vector<DataType> supportedTypes =
4116 {
4117 DataType::BFloat16,
4118 DataType::Float16,
4119 DataType::Float32,
4120 DataType::QAsymmS8,
4121 DataType::QAsymmU8,
4122 DataType::QSymmS16
4123 };
4124
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004125 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4126 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4127 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
Samuel Yap6b478092022-07-06 15:36:03 +01004128
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004129 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4130 (inputYInfoBeforeParams.GetNumDimensions() < 2))
Samuel Yap6b478092022-07-06 15:36:03 +01004131 {
4132 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4133 }
4134
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004135 TensorInfo inputXInfoAfterParams;
4136 TensorInfo inputYInfoAfterParams;
4137
4138 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4139 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
Samuel Yap6b478092022-07-06 15:36:03 +01004140 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004141 throw InvalidArgumentException(descriptorName +
4142 ": Invalid descriptor parameters - Transpose and Adjoint "
4143 "cannot both be true for a given input tensor.");
4144 }
4145 if(m_Parameters.m_TransposeX)
4146 {
4147 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4148 BatchMatMulDescriptor::GetPermuteVec(
4149 m_Parameters.m_DataLayoutX,
4150 inputXInfoBeforeParams.GetShape()));
4151 }
4152 else if(m_Parameters.m_AdjointX)
4153 {
4154 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4155 inputXInfoBeforeParams.GetShape());
4156 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4157 inputXInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004158 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004159 throw InvalidArgumentException(descriptorName +
4160 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004161 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004162 // Shape remains the same as it's square
4163 inputXInfoAfterParams = inputXInfoBeforeParams;
4164 }
4165 else
4166 {
4167 inputXInfoAfterParams = inputXInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004168 }
4169
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004170 if(m_Parameters.m_TransposeY)
Samuel Yap6b478092022-07-06 15:36:03 +01004171 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004172 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4173 BatchMatMulDescriptor::GetPermuteVec(
4174 m_Parameters.m_DataLayoutY,
4175 inputYInfoBeforeParams.GetShape()));
4176 }
4177 else if(m_Parameters.m_AdjointY)
4178 {
4179 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4180 inputYInfoBeforeParams.GetShape());
4181 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4182 inputYInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004183 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004184 throw InvalidArgumentException(descriptorName +
4185 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004186 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004187 // Shape remains the same as it's square
4188 inputYInfoAfterParams = inputYInfoBeforeParams;
4189 }
4190 else
4191 {
4192 inputYInfoAfterParams = inputYInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004193 }
4194
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004195 switch(m_Parameters.m_DataLayoutX)
4196 {
4197 case DataLayout::NCDHW:
4198 case DataLayout::NDHWC:
4199 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4200 {
4201 throw InvalidArgumentException(descriptorName +
4202 ": Input tensor X does not have the correct "
4203 "number of dimensions for the Data Layout that it has been assigned.");
4204 }
4205 break;
4206 case DataLayout::NCHW:
4207 case DataLayout::NHWC:
4208 default:
4209 break;
4210 }
Samuel Yap6b478092022-07-06 15:36:03 +01004211
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004212 switch(m_Parameters.m_DataLayoutY)
4213 {
4214 case DataLayout::NCDHW:
4215 case DataLayout::NDHWC:
4216 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4217 {
4218 throw InvalidArgumentException(descriptorName +
4219 ": Input tensor Y does not have the correct "
4220 "number of dimensions for the Data Layout that it has been assigned.");
4221 }
4222 break;
4223 case DataLayout::NCHW:
4224 case DataLayout::NHWC:
4225 default:
4226 break;
4227 }
4228
4229 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4230 inputXInfoAfterParams.GetShape());
4231 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4232 inputXInfoBeforeParams.GetShape());
4233
4234 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4235 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
Samuel Yap6b478092022-07-06 15:36:03 +01004236 {
4237 throw InvalidArgumentException(descriptorName +
4238 ": The final axis of input tensor X must be the same size as "
4239 "the second last axis of input tensor Y.");
4240 }
4241
Samuel Yap6b478092022-07-06 15:36:03 +01004242 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4243 // e.g. NHWC isnt compatible with NCHW as of now
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004244 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4245 DataLayout yLayout = m_Parameters.m_DataLayoutY;
Samuel Yap6b478092022-07-06 15:36:03 +01004246
4247 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4248 {
4249 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4250 {
4251 throw InvalidArgumentException(descriptorName +
4252 ": Invalid input tensor data layout combination.");
4253 }
4254 }
4255 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4256 {
4257 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4258 {
4259 throw InvalidArgumentException(descriptorName +
4260 ": Invalid input tensor data layout combination.");
4261 }
4262 }
4263 }
4264
4265 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004266 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4267 inputYInfoAfterParams.GetNumDimensions());
Samuel Yap6b478092022-07-06 15:36:03 +01004268 if(outputTensorDimSize-2 > 0)
4269 {
4270 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4271 DataType::Float32);
4272 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4273 DataType::Float32);
4274 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4275 DataType::Float32);
4276
4277 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4278 {
4279 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4280
4281 for(unsigned int i = 0; i < sizeDiff; i++)
4282 {
4283 axisIndices.insert(axisIndices.begin(), 1);
4284 }
4285
4286 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4287 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004288 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
Samuel Yap6b478092022-07-06 15:36:03 +01004289 }
4290 };
4291
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004292 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4293 inputXInfoAfterParams.GetShape());
4294 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4295 inputYInfoAfterParams.GetShape());
4296
4297 doAxisExtension(axesXNotMul, tiXNotMul);
4298 doAxisExtension(axesYNotMul, tiYNotMul);
Samuel Yap6b478092022-07-06 15:36:03 +01004299
4300 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4301 {
4302 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4303 tiYNotMul.GetShape()[i]);
4304 }
4305
4306 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4307 tiYNotMul,
4308 tiOutNotMul,
4309 descriptorName,
4310 "input_X",
4311 "input_Y");
4312 }
Samuel Yap6b478092022-07-06 15:36:03 +01004313}
4314
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01004315
mathad01df9a3222021-04-28 11:42:57 +01004316} // namespace armnn