blob: 62dfc6a38b33e7e836579d88dae7190e063cb4bc [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01002// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010011#include <armnnUtils/Permute.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010013#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000014
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000017#include <string>
18#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000019
James Ward47fce872020-09-10 11:57:28 +010020#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000021
Matteo Martincigh21350152018-11-28 16:22:22 +000022using namespace armnnUtils;
23
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
27//---------------------------------------------------------------
28DataType GetBiasDataType(DataType inputDataType)
29{
30 switch (inputDataType)
31 {
telsoa01c577f2c2018-08-31 09:22:23 +010032 case DataType::Float16:
33 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000034 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000035 case DataType::Float32:
36 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000037 case DataType::QAsymmS8:
38 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000039 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000040 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000041 case DataType::QSymmS8:
42 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000043 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010044 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000045 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010046 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000047 return DataType::Float32;
48 }
49}
50
51namespace
52{
53
54//---------------------------------------------------------------
// Local stand-in for std::to_string, which the Android NDK does not provide.
// Formats any streamable value through a string stream using the stream's default formatting.
template <typename T>
std::string to_string(T value)
{
    std::stringstream formatted;
    formatted << value;
    return formatted.str();
}
63
64//---------------------------------------------------------------
65void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
66{
67 if (!ptr)
68 {
69 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
70 paramName + " parameter must be set.");
71 }
72}
73
74//---------------------------------------------------------------
75void ValidateTensorShapesMatch(const TensorInfo& first,
76 const TensorInfo& second,
77 std::string const& descName,
78 std::string const& firstName,
79 std::string const& secondName)
80{
81 if (first.GetShape() != second.GetShape())
82 {
83 throw InvalidArgumentException(descName + ": "
84 + firstName + " & " + secondName + " must have identical shapes");
85 }
86}
87
88//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010089void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090{
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000092 {
93 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000095 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
96 }
97}
98
99//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000103 {
104 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100105 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000106 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
107 }
108}
109
110//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000111
112//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100113void ValidateTensorNumElements(const TensorInfo& tensor,
114 std::string const& descName,
115 unsigned int numElements,
116 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100117{
118 if (tensor.GetNumElements() != numElements)
119 {
120 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100121 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100122 tensorName + " tensor.");
123 }
124}
125
126//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000127void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
128 const std::string& descName, std::string const& tensorName)
129{
130 if (tensor.GetDataType() != dataType)
131 {
132 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
133 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
134 }
135}
136
Derek Lambertid466a542020-01-22 15:37:29 +0000137void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
138{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100139 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000140 {
141 throw InvalidArgumentException(descName +
142 ": Expected data type which supports per-axis quantization scheme but got " +
143 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
144 }
Derek Lambertid466a542020-01-22 15:37:29 +0000145}
146
telsoa014fcda012018-03-09 14:13:49 +0000147//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100148void ValidateTensorQuantizationSpace(const TensorInfo& first,
149 const TensorInfo& second,
150 const std::string& descName,
151 std::string const& firstName,
152 std::string const& secondName)
153{
154 if (!first.IsQuantized() ||
155 !second.IsQuantized())
156 {
157 // Not a quantized type, ignore the validation
158 return;
159 }
160
161 DataType firstDataType = first.GetDataType();
162 DataType secondDataType = second.GetDataType();
163
164 if (firstDataType != secondDataType)
165 {
166 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
167 " must be of the same quantized type, " +
168 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
169 secondName + " is " + GetDataTypeName(secondDataType));
170 }
171
172 if (!first.IsTypeSpaceMatch(second))
173 {
174 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
175 " must have the same quantization space, " +
176 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
177 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
178 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
179 " and scale " + to_string(second.GetQuantizationScale()));
180 }
181}
182
183//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100184void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
185 const TensorInfo& inputTensorInfo,
186 const TensorInfo& weightsTensorInfo,
187 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000188{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000189 // Helper lambda function to validate a single bias quantization scale value
190 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
191 {
mathad01df9a3222021-04-28 11:42:57 +0100192 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000193 if (std::abs(biasScale - expectedScale) > tolerance)
194 {
195 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100196 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
197 " for bias quantization scale (product of input and weight scales), but got " <<
198 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000199 }
200 };
201
telsoa014fcda012018-03-09 14:13:49 +0000202 if (biasTensor.GetQuantizationOffset() != 0)
203 {
204 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
205 to_string(biasTensor.GetQuantizationOffset()));
206 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000207
James Conroy8502ade2020-11-12 19:26:29 +0000208 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000209 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000210 // Validate per-axis quantization scales
211 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
212 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
213
214 if (weightScales.size() != biasScales.size())
215 {
216 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000217 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
218 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
219 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000220 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
221 }
222
223 for (size_t i = 0ul; i < biasScales.size(); ++i)
224 {
225 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
226 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
227 }
228 }
229 else
230 {
231 // Validate per-tensor quantization scale
232 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
233 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000234 }
235}
236
237//---------------------------------------------------------------
238void ValidateTensors(const std::vector<ITensorHandle*>& vec,
239 unsigned int numExpected,
240 const std::string& descName,
241 const std::string& varName)
242{
243 if (vec.empty() && numExpected > 0)
244 {
245 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
246 }
247
248 for (unsigned int i = 0; i < numExpected; ++i)
249 {
250 if (!vec[i])
251 {
252 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
253 }
254 }
255}
256
257//---------------------------------------------------------------
258void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
259 const TensorInfo& second,
260 const TensorInfo& output,
261 std::string const& descName,
262 std::string const& firstName,
263 std::string const& secondName)
264{
265 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
266 // broadcasted.
267 if (first.GetNumDimensions() != second.GetNumDimensions())
268 {
269 throw InvalidArgumentException(descName + ": Tensors "
270 + firstName + " & " + secondName
271 + " must have the same number of dimensions in order to be broadcasted");
272 }
273 uint32_t numDims = first.GetNumDimensions();
274 std::vector<uint32_t> outputDims(numDims, 0u);
275 for (uint32_t i = 0; i < numDims; i++)
276 {
277 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
278 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
279 if (dimsNotEqual && dimsNotOne)
280 {
281 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
282 }
283 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
284 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100285 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000286 if (broadcastShape != output.GetShape())
287 {
288 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
289 + firstName + " & " + secondName
290 + " does not match the output shape");
291 }
292}
293
294//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100295void ValidateDataTypes(const TensorInfo& info,
296 const std::vector<armnn::DataType>& supportedTypes,
297 std::string const& descName)
298{
299 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
300 if (iterator == supportedTypes.end())
301 {
302 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
303 }
304}
305
James Conroy4d1ff582019-06-10 17:06:39 +0100306//---------------------------------------------------------------
307void ValidateTensorDataTypesMatch(const TensorInfo& first,
308 const TensorInfo& second,
309 std::string const& descName,
310 std::string const& firstName,
311 std::string const& secondName)
312{
313 if (first.GetDataType() != second.GetDataType())
314 {
315 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
316 " must have identical data types.");
317 }
318}
319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100320//---------------------------------------------------------------
321void ValidateTensorNumElementsMatch(const TensorInfo& first,
322 const TensorInfo& second,
323 std::string const& descName,
324 std::string const& firstName,
325 std::string const& secondName)
326{
327 if (first.GetNumElements() != second.GetNumElements())
328 {
329 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
330 " must have the same number of elements.");
331 }
332}
333
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000334void ValidateWeightDataType(const TensorInfo& inputInfo,
335 const TensorInfo& weightInfo,
336 const std::string& descName)
337{
338 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000339 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000340 {
341 const std::vector<DataType> validTypes =
342 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000343 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100344 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100345 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000346 };
347
348 ValidateDataTypes(weightInfo, validTypes, descName);
349 }
350 else
351 {
352 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
353 }
354}
355
356void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
357 const std::string& descName,
358 const std::string& tensorName)
359{
360 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
361 if (!quantizationDim.has_value())
362 {
James Ward47fce872020-09-10 11:57:28 +0100363 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
364 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366}
367
368void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
369 const std::string& descName,
370 const std::string& tensorName)
371{
372 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
373 if (quantizationOffset != 0)
374 {
James Ward47fce872020-09-10 11:57:28 +0100375 throw InvalidArgumentException(fmt::format(
376 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
377 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000378 }
379}
380
381void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
382 const TensorInfo& outputInfo,
383 const TensorInfo& weightInfo,
384 const Optional<TensorInfo>& optionalBiasInfo,
385 const std::string& descName)
386{
387 if (weightInfo.HasPerAxisQuantization())
388 {
389 const DataType inputDataType = inputInfo.GetDataType();
390 const DataType outputDataType = outputInfo.GetDataType();
391
Keith Davis0c2eeac2020-02-11 16:51:50 +0000392 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000393
394 if (!canHavePerAxisQuantization)
395 {
James Ward47fce872020-09-10 11:57:28 +0100396 throw InvalidArgumentException(fmt::format(
397 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
398 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000399 }
400
Derek Lambertid466a542020-01-22 15:37:29 +0000401
402 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000403 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
404 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
405
406 if (optionalBiasInfo.has_value())
407 {
408 const TensorInfo& biasInfo = optionalBiasInfo.value();
409 if (!biasInfo.HasPerAxisQuantization())
410 {
James Ward47fce872020-09-10 11:57:28 +0100411 throw InvalidArgumentException(fmt::format(
412 "{}: Per-axis quantization parameters not set on bias tensor, "
413 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000414 }
415
416 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
417 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
418 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
419 }
420 }
421}
422
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100423} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000424
Mike Kelly80512b02022-05-16 23:10:42 +0100425//---------------------------------------------------------------
426void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
427 std::string const& descName,
428 unsigned int numDimensions,
429 std::string const& tensorName) const
430{
431 // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
432 // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
433 // numDimensions.
434 if (m_AllowExpandedDims)
435 {
436 unsigned int squeezedDims = 0;
437
438 for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
439 {
440 if (tensor.GetShape()[i] != 1)
441 {
442 ++squeezedDims;
443 }
444 }
445 if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
446 {
447 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
448 to_string(tensor.GetNumDimensions()) + " dimensions for " +
449 tensorName + " tensor.");
450 }
451 }
452 else
453 {
454 if (tensor.GetNumDimensions() != numDimensions)
455 {
456 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
457 to_string(tensor.GetNumDimensions()) + " dimensions for " +
458 tensorName + " tensor.");
459 }
460 }
461}
462
463//---------------------------------------------------------------
464void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
465 unsigned int numDimension,
466 unsigned int numElements,
467 std::string const& tensorName) const
468{
469 const std::string functionName{"ValidateTensorNumDimNumElem"};
470 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
471 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
472}
473
474//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000475void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
476 unsigned int numExpectedIn, unsigned int numExpectedOut) const
477{
478 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
479 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
480}
481
482//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100483void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
484{
485 const std::string descriptorName{"MapQueueDescriptor"};
486
487 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100488 ValidateNumOutputs(workloadInfo, descriptorName, 0);
489
490 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
491 {
492 if (!m_Inputs[i])
493 {
494 throw InvalidArgumentException(
495 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
496 }
497 }
498}
499
500//---------------------------------------------------------------
501void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
502{
503 const std::string descriptorName{"UnmapQueueDescriptor"};
504
505 ValidateNumInputs(workloadInfo, descriptorName, 1);
506 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100507
508 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
509 {
510 if (!m_Inputs[i])
511 {
512 throw InvalidArgumentException(
513 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
514 }
515 }
516}
517
518//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000519void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
520{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100521 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000522
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100523 ValidateNumInputs(workloadInfo, descriptorName, 1);
524 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000525
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100526 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
527 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
528
529 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
530 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000531
532 if (m_Inputs.size() != m_Outputs.size())
533 {
James Ward47fce872020-09-10 11:57:28 +0100534 throw InvalidArgumentException(fmt::format(
535 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
536 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000537 }
538
539 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
540 {
541 if (!m_Inputs[i])
542 {
James Ward47fce872020-09-10 11:57:28 +0100543 throw InvalidArgumentException(fmt::format(
544 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000545 }
546
547 if (!m_Outputs[i])
548 {
James Ward47fce872020-09-10 11:57:28 +0100549 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000550 }
551 }
552}
553
Derek Lambertif674aa02019-08-01 15:56:25 +0100554//---------------------------------------------------------------
555void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
556{
557 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
558 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
559
560 if (workloadInfo.m_InputTensorInfos.size() != 1)
561 {
James Ward47fce872020-09-10 11:57:28 +0100562 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
563 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100564
565 }
566
567 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
568 {
James Ward47fce872020-09-10 11:57:28 +0100569 throw InvalidArgumentException(fmt::format(
570 "Number of input infos ({0}) does not match the number of output infos ({1})",
571 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573
574 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
575 {
576 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
577 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
578 {
James Ward47fce872020-09-10 11:57:28 +0100579 throw InvalidArgumentException(fmt::format(
580 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100581 }
582 }
583
584 if (m_Inputs.size() != 1)
585 {
James Ward47fce872020-09-10 11:57:28 +0100586 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100587 }
588
589 if (m_Inputs.size() != m_Outputs.size())
590 {
James Ward47fce872020-09-10 11:57:28 +0100591 throw InvalidArgumentException(fmt::format(
592 "Number of inputs ({0}) does not match the number of outputs ({1})",
593 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100594 }
595
596 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
597 {
598 if (!m_Inputs[i])
599 {
James Ward47fce872020-09-10 11:57:28 +0100600 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100601 }
602
603 if (!m_Outputs[i])
604 {
James Ward47fce872020-09-10 11:57:28 +0100605 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100606 }
607 }
608}
609
610//---------------------------------------------------------------
611void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
612{
613 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100614
Derek Lambertif674aa02019-08-01 15:56:25 +0100615 if (m_Inputs.size() != 1)
616 {
James Ward47fce872020-09-10 11:57:28 +0100617 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100618 }
619
620 if (m_Outputs.size() != 0)
621 {
James Ward47fce872020-09-10 11:57:28 +0100622 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100623 }
624
625 if (!m_Inputs[0])
626 {
James Ward47fce872020-09-10 11:57:28 +0100627 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100628 }
629}
630
631//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000632void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
633{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100634 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100635
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100636 ValidateNumInputs(workloadInfo, descriptorName, 1);
637 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100639 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
640 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100641
642 std::vector<DataType> supportedTypes =
643 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000644 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100645 DataType::Float16,
646 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000647 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000648 DataType::QAsymmU8,
649 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100650 };
651
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100652 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
653 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
654 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000655}
656
Nikhil Rajee391d52019-09-05 17:50:44 +0100657void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
658{
659 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
660
661 ValidateNumInputs(workloadInfo, descriptorName, 1);
662 ValidateNumOutputs(workloadInfo, descriptorName, 1);
663
664 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
665 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
666
Inki Daed4619e22020-09-10 15:33:54 +0900667 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
668 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100669 {
Inki Daed4619e22020-09-10 15:33:54 +0900670 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100671 }
672
James Conroyd47a0642019-09-17 14:22:06 +0100673 std::vector<DataType> supportedInputTypes =
674 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000675 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100676 DataType::Float16,
677 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100678 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000679 DataType::QAsymmU8,
680 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900681 DataType::Signed32,
682 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100683 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100684
James Conroyd47a0642019-09-17 14:22:06 +0100685 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100686
687 auto inputShape = inputTensorInfo.GetShape();
688 auto outputShape = outputTensorInfo.GetShape();
689
690 auto inputNumDimensions = inputShape.GetNumDimensions();
691 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
692
693 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
694
695 // 1D input shape results in scalar output shape
696 if (inputShape.GetNumDimensions() == 1)
697 {
698 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
699 {
700 throw InvalidArgumentException(descriptorName + outputShapeError);
701 }
702 }
703 else
704 {
705 for (unsigned int i = 0; i < unsignedAxis; ++i)
706 {
707 if (outputShape[i] != inputShape[i])
708 {
709 throw InvalidArgumentException(descriptorName + outputShapeError);
710 }
711 }
712
713 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
714 {
715 if (outputShape[i - 1] != inputShape[i])
716 {
717 throw InvalidArgumentException(descriptorName + outputShapeError);
718 }
719 }
720 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100721}
722
mathad01b392e982021-04-07 12:07:30 +0100723void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
724{
725 const std::string descriptorName{"CastQueueDescriptor"};
726
727 ValidateNumInputs(workloadInfo, descriptorName, 1);
728 ValidateNumOutputs(workloadInfo, descriptorName, 1);
729
730 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
731 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
732
733 std::vector<DataType> supportedTypes =
734 {
735 DataType::BFloat16,
736 DataType::Float16,
737 DataType::Float32,
738 DataType::QAsymmS8,
739 DataType::QAsymmU8,
740 DataType::QSymmS8,
741 DataType::QSymmS16,
742 DataType::Signed32,
743 DataType::Signed64
744 };
745
746 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
747 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
748}
749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100750void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
751{
752 const std::string descriptorName{"SoftmaxQueueDescriptor"};
753
754 ValidateNumInputs(workloadInfo, descriptorName, 1);
755 ValidateNumOutputs(workloadInfo, descriptorName, 1);
756
757 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
758 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
759
760 std::vector<DataType> supportedTypes =
761 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000762 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100763 DataType::Float16,
764 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000765 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000766 DataType::QAsymmU8,
767 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 };
769
770 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
771 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
772 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
773}
774
telsoa014fcda012018-03-09 14:13:49 +0000775void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
776{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100777 const std::string descriptorName{"SplitterQueueDescriptor"};
778
779 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000780
Ruomei Yan25339c32019-05-28 16:48:20 +0100781 // Check the supported data types
782 std::vector<DataType> supportedTypes =
783 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000784 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100785 DataType::Float32,
786 DataType::Float16,
787 DataType::Boolean,
788 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100789 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000790 DataType::QAsymmU8,
791 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100792 };
793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100794 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
795 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100796 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100797 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
798 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
799
800 const std::string outputName = "output_" + std::to_string(i);
801 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100802 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100803
telsoa014fcda012018-03-09 14:13:49 +0000804 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
805 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100806 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000807 }
808
809 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
810 {
811 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100812 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000813 "has to match number of workloadInfo.m_OutputTensorInfos. "
814 "Number of windows: " +
815 to_string(m_ViewOrigins.size()) +
816 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
817 }
818
telsoa01c577f2c2018-08-31 09:22:23 +0100819 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000820 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
821 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
822 {
telsoa01c577f2c2018-08-31 09:22:23 +0100823 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000824 ViewOrigin const& e = m_ViewOrigins[w];
825 if (e.m_Origin.size() != inputDims)
826 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000828 "have the same dimensionality as the input tensor. "
829 "Window origin (index: " +
830 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
831 " dimensions, the input "
832 "tensor has " +
833 to_string(inputDims) + " dimensions.");
834 }
835 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
836 {
837 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
838 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
839 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100840 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000841 "be smaller or equal than the size of the input in that coord.");
842 }
843 }
844 }
845}
846
Jim Flynne242f2d2019-05-22 14:24:13 +0100847void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000848{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100849 const std::string descriptorName{"ConcatQueueDescriptor"};
850
851 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000852
853 if (m_Inputs.size() <= 0)
854 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100855 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000856 }
857 if (m_Outputs.size() <= 0)
858 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100859 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000860 }
861
862 if (workloadInfo.m_InputTensorInfos.size() <= 0)
863 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100864 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000865 }
866 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
867 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100868 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000869 }
870
Nikhil Raj8599a412018-11-19 14:51:07 +0000871 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
872 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100873 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000874 }
875
876 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
877 {
878 return;
879 }
880
telsoa014fcda012018-03-09 14:13:49 +0000881 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
882 {
883 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100884 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000885 "has to match number of workloadInfo.m_InputTensorInfos. "
886 "Number of windows: " +
887 to_string(m_ViewOrigins.size()) +
888 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
889 }
890
telsoa01c577f2c2018-08-31 09:22:23 +0100891 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000892 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
893 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
894 {
telsoa01c577f2c2018-08-31 09:22:23 +0100895 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000896 ViewOrigin const& e = m_ViewOrigins[w];
897 if (e.m_Origin.size() != outputDims)
898 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100899 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000900 "have the same dimensionality as the output tensor. "
901 "Window origin (index: " +
902 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
903 " dimensions, the output "
904 "tensor has " +
905 to_string(outputDims) + " dimensions.");
906 }
telsoa01c577f2c2018-08-31 09:22:23 +0100907 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000908 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
909 {
910 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
911 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000914 "be smaller or equal than the size of the output in that coord.");
915 }
916 }
917 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100918
919 // Check the supported data types
920 std::vector<DataType> supportedTypes =
921 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000922 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100923 DataType::Float32,
924 DataType::Float16,
925 DataType::Boolean,
926 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100927 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000928 DataType::QAsymmU8,
929 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100930 };
931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100932 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
933 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100934 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100935 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
936 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
937
938 const std::string inputName = "input_" + std::to_string(i);
939 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100940 }
telsoa014fcda012018-03-09 14:13:49 +0000941}
942
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100943void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
944{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100945 const std::string descriptorName{"StackQueueDescriptor"};
946
947 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100948
949 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
950 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100951 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100952 }
953
954 // All inputs must have the same shape, which is defined in parameters
955 const TensorShape& inputShape = m_Parameters.m_InputShape;
956 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
957 {
958 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
959 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100960 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100961 }
962 }
963
Matthew Jacksondba634f2019-08-15 15:14:18 +0100964 if (inputShape.GetNumDimensions() > 4)
965 {
966 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
967 }
968
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100969 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
970 // since the output tensor has an additional dimension.
971 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
972 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100973 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100974 "than the number of input dimensions.");
975 }
976
977 // Output shape must be as inferred from the input shape
978 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
979 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
980 {
981 if (outputShape[i] != inputShape[i])
982 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100983 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100984 "match shape inferred from input tensor.");
985 }
986 }
987
988 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
989 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100990 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100991 "match shape inferred from input tensor.");
992 }
993
994 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
995 {
996 if (outputShape[i] != inputShape[i-1])
997 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100998 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100999 "match shape inferred from input tensor.");
1000 }
1001 }
1002
Matthew Jacksondba634f2019-08-15 15:14:18 +01001003 if (outputShape.GetNumDimensions() > 5)
1004 {
1005 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
1006 }
1007
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001008 // Check the supported data types
1009 std::vector<DataType> supportedTypes =
1010 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001011 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001012 DataType::Float32,
1013 DataType::Float16,
1014 DataType::Boolean,
1015 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001016 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001017 DataType::QAsymmU8,
1018 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001019 };
1020
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001021 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001022
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001023 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001024 {
1025 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1026 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001027 descriptorName,
1028 "input_0",
1029 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001030 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001031
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001032 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1033 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 descriptorName,
1035 "input_0",
1036 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001037}
1038
Ryan OSheaec6c6802020-06-05 17:17:06 +01001039void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1040{
1041 const std::string descriptorName{"FillQueueDescriptor"};
1042
1043 ValidateNumInputs(workloadInfo, descriptorName, 1);
1044 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1045
1046 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1047 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1048
1049 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1050
1051 std::vector<DataType> supportedTypes =
1052 {
1053 DataType::BFloat16,
1054 DataType::Float32,
1055 DataType::Float16,
1056 DataType::Signed32
1057 };
1058
1059 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1060}
1061
telsoa014fcda012018-03-09 14:13:49 +00001062void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1063{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001065
Matthew Sloyan81beae32021-07-13 19:46:11 +01001066 uint32_t numInputs = 2;
1067 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001068 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001069 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001070 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001071
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001072 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001073 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1074
1075 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1076 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1077
1078 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1079
1080 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001081 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001082 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001083 }
1084
Matthew Sloyan81beae32021-07-13 19:46:11 +01001085 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001087
1088 if (m_Parameters.m_BiasEnabled)
1089 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001090 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001091 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001092 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001093 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1094 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001095 }
1096
Francis Murtagh46c09d02019-05-28 08:15:28 +01001097 // Check the supported data types
1098 std::vector<DataType> supportedTypes =
1099 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001100 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001101 DataType::Float32,
1102 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001103 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001104 DataType::QAsymmU8,
1105 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001106 };
1107
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001108 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001109
1110 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1111 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1112 {
1113 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1114 {
1115 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1116 "for BFloat16 input.");
1117 }
1118 }
1119 else
1120 {
1121 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1122 }
telsoa014fcda012018-03-09 14:13:49 +00001123}
1124
telsoa014fcda012018-03-09 14:13:49 +00001125void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1126{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001127 const std::string descriptorName{"NormalizationQueueDescriptor"};
1128
1129 ValidateNumInputs(workloadInfo, descriptorName, 1);
1130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1131
1132 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1133 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001134
1135 // Check the supported data types
1136 std::vector<DataType> supportedTypes =
1137 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001138 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001139 DataType::Float16,
1140 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001141 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001142 DataType::QAsymmU8,
1143 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001144 };
1145
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001146 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001147
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001151}
1152
1153void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1154{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 ValidateNumInputs(workloadInfo, descriptorName, 2);
1158 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1159
1160 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1161 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1162 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1163
1164 std::vector<DataType> supportedTypes =
1165 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001166 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001167 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001168 DataType::Float16,
1169 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001170 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001171 DataType::QSymmS16,
1172 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001173 };
1174
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001175 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1176 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1177 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001178
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001179 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1180 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001181
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001182 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1183 inputTensorInfo1,
1184 outputTensorInfo,
1185 descriptorName,
1186 "input_0",
1187 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001188}
1189
telsoa014fcda012018-03-09 14:13:49 +00001190void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1191{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateNumInputs(workloadInfo, descriptorName, 2);
1195 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1196
1197 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1198 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1199 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1200
1201 std::vector<DataType> supportedTypes =
1202 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001203 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001204 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001205 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001206 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001207 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001208 DataType::QSymmS16,
1209 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001210 };
1211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001212 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1213 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1214 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001215
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001216 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1217 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1220 inputTensorInfo1,
1221 outputTensorInfo,
1222 descriptorName,
1223 "input_0",
1224 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001225}
1226
1227void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1228{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001230
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001231 ValidateNumInputs(workloadInfo, descriptorName, 1);
1232 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1233
1234 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1235 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001236
1237 std::vector<DataType> supportedTypes =
1238 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001239 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001240 DataType::Float16,
1241 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001242 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001243 DataType::QAsymmU8,
1244 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001245 };
1246
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1248 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001249
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001251 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 ValidatePointer(m_Mean, descriptorName, "mean");
1254 ValidatePointer(m_Variance, descriptorName, "variance");
1255 ValidatePointer(m_Beta, descriptorName, "beta");
1256 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001257
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001258 const TensorInfo& mean = m_Mean->GetTensorInfo();
1259 const TensorInfo& variance = m_Variance->GetTensorInfo();
1260 const TensorInfo& beta = m_Beta->GetTensorInfo();
1261 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001262
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001263 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1264 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1265 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1266 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1269 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1270 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001271}
1272
1273void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1274{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001275 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001276
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001277 uint32_t numInputs = 2;
1278 if (m_Parameters.m_BiasEnabled)
1279 {
1280 numInputs = 3;
1281 }
1282
1283 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001284 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001286 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1287 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001288
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001289 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1290 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001291
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001292 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001294 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001295
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001296 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001297
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001298 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001299 if (m_Parameters.m_BiasEnabled)
1300 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001301 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001302 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001303
1304 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1305 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001306 }
1307
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001308 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1309 {
1310 throw InvalidArgumentException(
1311 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1312 "cannot be either negative or 0.",
1313 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1314 }
1315
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001316 ValidatePerAxisQuantization(inputTensorInfo,
1317 outputTensorInfo,
1318 weightTensorInfo,
1319 optionalBiasTensorInfo,
1320 descriptorName);
1321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001322 std::vector<DataType> supportedTypes =
1323 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001324 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001325 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001326 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001327 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001328 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001329 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001330 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001331 };
1332
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001333 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001334
1335 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1336 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1337 {
1338 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1339 {
1340 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1341 "for BFloat16 input.");
1342 }
1343 }
1344 else
1345 {
1346 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1347 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001348}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001349
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001350void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1351{
1352 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1353
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001354 uint32_t numInputs = 2;
1355 if (m_Parameters.m_BiasEnabled)
1356 {
1357 numInputs = 3;
1358 }
1359 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001360 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1361
1362 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1363 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1364
1365 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1366 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1367
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001368 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001369 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1370
1371 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1372
1373 Optional<TensorInfo> optionalBiasTensorInfo;
1374 if (m_Parameters.m_BiasEnabled)
1375 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001376 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001377 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1378
1379 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1380 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1381 }
1382
1383 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1384 {
1385 throw InvalidArgumentException(
1386 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1387 "cannot be either negative or 0.",
1388 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1389 }
1390
1391 ValidatePerAxisQuantization(inputTensorInfo,
1392 outputTensorInfo,
1393 weightTensorInfo,
1394 optionalBiasTensorInfo,
1395 descriptorName);
1396
1397 std::vector<DataType> supportedTypes =
1398 {
1399 DataType::BFloat16,
1400 DataType::Float16,
1401 DataType::Float32,
1402 DataType::QAsymmS8,
1403 DataType::QAsymmU8,
1404 DataType::QSymmS16,
1405 DataType::QSymmS8
1406 };
1407
1408 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1409 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1410}
1411
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001412void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1413{
1414 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1415
Cathal Corbett06902652022-04-14 17:55:11 +01001416 uint32_t numInputs = 2;
1417 if (m_Parameters.m_BiasEnabled)
1418 {
1419 numInputs = 3;
1420 }
1421
1422 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001423 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1424
1425 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1426 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1427
1428 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1429 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1430
Cathal Corbett06902652022-04-14 17:55:11 +01001431 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001432 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1433
1434 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1435 {
1436 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001437 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1438 "cannot be smaller than 1.",
1439 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001440 }
1441
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001442 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1443 {
1444 throw InvalidArgumentException(
1445 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1446 "cannot be either negative or 0.",
1447 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1448 }
1449
Jan Eilers53ef7952021-06-02 12:01:25 +01001450 if (weightTensorInfo.GetShape()[0] != 1)
1451 {
1452 throw InvalidArgumentException(fmt::format(
1453 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1454 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1455 descriptorName,
1456 weightTensorInfo.GetShape()[0],
1457 weightTensorInfo.GetShape()[1],
1458 weightTensorInfo.GetShape()[2],
1459 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001460 }
1461
Cathal Corbett4b19d222022-05-11 20:12:17 +01001462 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1463 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1464 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1465 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1466
1467 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1468 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1469 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1470
1471 if (!(validRefFormat || validAclFormat))
1472 {
1473 throw InvalidArgumentException(fmt::format(
1474 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1475 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1476 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1477 descriptorName,
1478 numOutputChannels,
1479 weightTensorInfo.GetShape()[0],
1480 weightTensorInfo.GetShape()[1],
1481 weightTensorInfo.GetShape()[2],
1482 weightTensorInfo.GetShape()[3]));
1483 }
1484
Teresa Charlind8df0262019-11-11 12:28:15 +00001485 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001486
Teresa Charlind8df0262019-11-11 12:28:15 +00001487 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001488 if (m_Parameters.m_BiasEnabled)
1489 {
Cathal Corbett06902652022-04-14 17:55:11 +01001490 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001491 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001492
1493 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1494 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1495 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001496 ValidatePerAxisQuantization(inputTensorInfo,
1497 outputTensorInfo,
1498 weightTensorInfo,
1499 optionalBiasTensorInfo,
1500 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001501
1502 std::vector<DataType> supportedTypes =
1503 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001504 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001505 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001507 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001508 DataType::QAsymmU8,
1509 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510 };
1511
1512 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1513 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001514}
1515
1516void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1517{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 const std::string descriptorName{"PermuteQueueDescriptor"};
1519
1520 ValidateNumInputs(workloadInfo, descriptorName, 1);
1521 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001522
1523 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1526 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001527
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001528 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1529 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001530
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001531 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001532 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001533 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001534 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001535 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1536 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1537 "must match dst dimension " + to_string(mapping[i]) +
1538 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001539 }
1540 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001541
1542 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001543}
1544
1545void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1546{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001547 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001548
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001549 ValidateNumInputs(workloadInfo, descriptorName, 1);
1550 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1551
1552 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1553 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1554
1555 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1556 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001557
1558 std::vector<DataType> supportedTypes =
1559 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001560 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001561 DataType::Float32,
1562 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001563 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001564 DataType::QAsymmU8,
1565 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001566 };
1567
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001568 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1569 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001570}
1571
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001572void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1573{
1574 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1575
1576 ValidateNumInputs(workloadInfo, descriptorName, 1);
1577 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1578
1579 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1580 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1581
1582 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1583 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1584
1585 std::vector<DataType> supportedTypes =
1586 {
1587 DataType::BFloat16,
1588 DataType::Float32,
1589 DataType::Float16,
1590 DataType::QAsymmS8,
1591 DataType::QAsymmU8,
1592 DataType::QSymmS16
1593 };
1594
1595 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1596 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1597}
1598
Teresa Charlin970f43b2019-07-01 13:51:07 +01001599void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1600{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001603 ValidateNumInputs(workloadInfo, descriptorName, 1);
1604 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1605
1606 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1607 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1608
1609 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1610 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001611
1612 std::vector<DataType> supportedTypes =
1613 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001614 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001615 DataType::Float16,
1616 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001617 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001618 DataType::QAsymmU8,
1619 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001620 };
1621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001622 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1623 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001625 // Resize only changes width and height: batch and channel count must match.
1626 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1627 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001628 if (inputBatchSize != outputBatchSize)
1629 {
1630 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001631 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1632 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001633 }
1634
1635 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001636 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1637 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001638 if (inputChannelCount != outputChannelCount)
1639 {
1640 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001641 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1642 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001643 }
1644}
1645
1646void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1647{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001648 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001650 ValidateNumInputs(workloadInfo, descriptorName, 1);
1651 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1652
1653 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1654 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1655
1656 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1657 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1658
1659 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1660
telsoa014fcda012018-03-09 14:13:49 +00001661 if (m_Parameters.m_Min > m_Parameters.m_Max)
1662 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001663 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001664 }
telsoa014fcda012018-03-09 14:13:49 +00001665}
1666
Kevin Mayce5045a2019-10-02 14:07:47 +01001667void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1668{
1669 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1670
1671 ValidateNumInputs(workloadInfo, descriptorName, 1);
1672 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1673
1674 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1675 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1676
1677 if (inputTensorInfo.GetNumDimensions() > 4)
1678 {
1679 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1680 }
1681
1682 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1683
1684 // Check the supported data types
1685 std::vector<DataType> supportedTypes =
1686 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001687 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001688 DataType::Float32,
1689 DataType::Float16
1690 };
1691
1692 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001693 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001694}
1695
telsoa014fcda012018-03-09 14:13:49 +00001696void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1697{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001698 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001699
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001701 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1702
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001703 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1704 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1705
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001706 if (inputTensorInfo.GetNumDimensions() > 4)
1707 {
1708 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1709 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001710
1711 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001712
1713 // Check the supported data types
1714 std::vector<DataType> supportedTypes =
1715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001716 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001717 DataType::Float32,
1718 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001719 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001720 DataType::QAsymmU8,
1721 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001722 };
1723
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001724 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001725 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1726}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001727
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001728void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1729{
1730 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1731
1732 ValidateNumInputs(workloadInfo, descriptorName, 1);
1733 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1734
1735 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1736 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1737
1738 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1739
1740 std::vector<DataType> supportedTypes =
1741 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001742 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001743 DataType::Float32,
1744 DataType::Float16,
1745 };
1746
1747 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001748 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001749}
1750
1751void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1752{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001753 const std::string descriptorName{"ConstantQueueDescriptor"};
1754
1755 ValidateNumInputs(workloadInfo, descriptorName, 0);
1756 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001757
1758 if (!m_LayerOutput)
1759 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001760 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001761 }
1762
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001763 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1764 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001765
1766 // Check the supported data types
1767 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001768 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001769 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001770 DataType::Float32,
1771 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001772 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001773 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001774 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001775 DataType::QSymmS16,
1776 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001777 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001778
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001780}
1781
1782void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1783{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001785
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001786 ValidateNumInputs(workloadInfo, descriptorName, 1);
1787 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1788
1789 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1790 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1791
1792 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001793
1794 // Check the supported data types
1795 std::vector<DataType> supportedTypes =
1796 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001797 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001798 DataType::Float32,
1799 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001800 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001801 DataType::QAsymmU8,
1802 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001803 DataType::Signed32,
1804 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001805 };
1806
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001807 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1808 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001809}
1810
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001811void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1812{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001813 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001814
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 ValidateNumInputs(workloadInfo, descriptorName, 1);
1816 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1817
1818 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1819 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1820
1821 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1822 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001823
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001824 if (m_Parameters.m_BlockShape.size() != 2)
1825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001826 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001827 }
1828
1829 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1830 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001831 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1832 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001833 }
1834
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001835 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001836
1837 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001838 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001839
Matthew Bentham8800c002018-11-19 13:19:28 +00001840 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001841
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001842 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1843 widthPad.first + widthPad.second;
1844 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1845 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001846
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001847 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1848 inputShape[dimensionIndices.GetChannelsIndex()];
1849 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001850
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001851 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001852 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001853 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001854 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001855 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001856 }
1857
1858 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001859 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1861 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001862 }
nikraj01120522a2019-05-31 11:33:07 +01001863
1864 std::vector<DataType> supportedTypes =
1865 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001866 DataType::BFloat16,
1867 DataType::Float16,
1868 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001869 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001870 DataType::QAsymmU8,
1871 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001872 };
1873
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001874 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1875 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001876}
1877
Keith Davisa57eccb2019-06-14 17:33:22 +01001878void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1879{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001880 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001882 ValidateNumInputs(workloadInfo, descriptorName, 1);
1883 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001884
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001885 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1886 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1887
1888 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1889 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001890
1891 std::vector<DataType> supportedTypes =
1892 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001893 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001894 DataType::Float32,
1895 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001896 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001897 DataType::QAsymmU8,
1898 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001899 };
1900
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001901 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1902 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001903
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001904 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1905
1906 if (m_Parameters.m_BlockSize == 0)
1907 {
1908 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1909 }
1910
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001911 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1912 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1913 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1914 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001915
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001916 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001917 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001918 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001919 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1920 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001921 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001922
1923 const TensorShape& outputShape = outputTensorInfo.GetShape();
1924 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1925 {
1926 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1927 "must be divisible by the square of block size." );
1928 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001929}
1930
telsoa014fcda012018-03-09 14:13:49 +00001931void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1932{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001933 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001934
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001935 ValidateNumInputs(workloadInfo, descriptorName, 1);
1936 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1937
1938 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1939 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001940
1941 std::vector<DataType> supportedTypes =
1942 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001943 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001944 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001945 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001946 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001947 };
1948
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001949 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001950 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1951 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1952 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001953}
1954
telsoa01c577f2c2018-08-31 09:22:23 +01001955void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1956{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001957 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1958
1959 const std::string descriptorName{"LstmQueueDescriptor"};
1960
1961 // check dimensions of all inputs and outputs
1962 if (workloadInfo.m_InputTensorInfos.size() != 3)
1963 {
1964 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1965 }
1966 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1967 {
1968 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1969 }
1970
1971 std::vector<DataType> supportedTypes =
1972 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001973 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001974 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001975 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001976 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001977 };
1978
Jan Eilers38e05bd2019-06-26 13:10:09 +01001979 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001980 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1981
Jan Eilers38e05bd2019-06-26 13:10:09 +01001982 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001983 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001984 {
1985 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1986 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001987 descriptorName,
1988 "input_0",
1989 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001990 }
1991 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001992 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001993 {
1994 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1995 workloadInfo.m_OutputTensorInfos[i],
1996 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001997 "input_0",
1998 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001999 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002000
janeil0117d8d852019-11-15 15:00:16 +00002001 // Making sure clipping parameters have valid values.
2002 // == 0 means no clipping
2003 // > 0 means clipping
2004 if (m_Parameters.m_ClippingThresCell < 0.0f)
2005 {
2006 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2007 }
2008 if (m_Parameters.m_ClippingThresProj < 0.0f)
2009 {
2010 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2011 }
2012
Jan Eilers38e05bd2019-06-26 13:10:09 +01002013 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002014 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2015 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2016 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2017 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2018 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2019 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2020
Jan Eilers38e05bd2019-06-26 13:10:09 +01002021 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002022 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2023 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002024 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002025 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2026 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002027 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002028 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2029 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002030 // scratchBufferTensor
2031 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002032 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2033 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002034 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002035 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2036 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002037 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002038 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2039 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002040 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002041 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2042 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002043
Jan Eilers38e05bd2019-06-26 13:10:09 +01002044 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2045 if ( m_InputToInputWeights )
2046 {
2047 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2048 (n_cell * n_input), "InputLayerNormWeights");
2049 }
2050
2051 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2052 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2053 (n_cell * n_input), "InputToForgetWeights");
2054
2055 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2056 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2057 (n_cell * n_input), "InputToCellWeights");
2058
2059 if ( m_RecurrentToInputWeights )
2060 {
2061 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2062 (n_cell * n_output), "RecurrentToInputWeights");
2063 }
2064
2065 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2066 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2067 (n_cell * n_output), "RecurrentToForgetWeights");
2068
2069 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2070 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2071 (n_cell * n_output), "RecurrentToCellWeights");
2072
2073 // Make sure the input-gate's parameters are either both present (regular
2074 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2075 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2076 !m_Parameters.m_CifgEnabled) ||
2077 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2078 m_Parameters.m_CifgEnabled));
2079 if (!cifg_weights_all_or_none)
2080 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002081 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2082 "RecurrentToInputWeights must either both be present (regular LSTM) "
2083 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2084 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002085 }
2086
2087 if ( m_CellToInputWeights )
2088 {
2089 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2090 n_cell, "CellToInputWeights");
2091 }
2092 if ( m_CellToForgetWeights )
2093 {
2094 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2095 n_cell, "CellToForgetWeights");
2096 }
2097 if ( m_CellToOutputWeights )
2098 {
2099 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2100 n_cell, "CellToOutputWeights");
2101 }
2102
2103 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2104 bool peephole_weights_all_or_none =
2105 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2106 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2107 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2108 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2109 if (!peephole_weights_all_or_none)
2110 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002111 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002112 }
2113
2114 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2115 if (m_Parameters.m_CifgEnabled)
2116 {
2117 if (m_InputGateBias)
2118 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002119 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002120 }
2121 }
2122 else
2123 {
2124 if (!m_InputGateBias)
2125 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002126 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2127 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002128 }
2129 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2130 n_cell, "InputGateBias");
2131 }
2132
2133 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2134 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2135
2136 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2137 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2138
2139 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2140 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2141
2142 if (m_ProjectionWeights)
2143 {
2144 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2145 (n_cell * n_output), "ProjectionWeights");
2146 }
2147 if (m_ProjectionBias)
2148 {
2149 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2150 }
2151
2152 // Making sure the projection tensors are consistent:
2153 // 1) If projection weight is not present, then projection bias should not be
2154 // present.
2155 // 2) If projection weight is present, then projection bias is optional.
2156 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2157 !m_Parameters.m_ProjectionEnabled)
2158 || (m_ProjectionWeights && !m_ProjectionBias &&
2159 m_Parameters.m_ProjectionEnabled)
2160 || (m_ProjectionWeights && m_ProjectionBias &&
2161 m_Parameters.m_ProjectionEnabled));
2162 if (!projecton_tensors_consistent)
2163 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002164 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002165 }
2166
2167 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2168 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2169 // either all have values or none of them have values. Layer normalization is used when the values of all the
2170 // layer normalization weights are present
2171 if (m_InputLayerNormWeights)
2172 {
2173 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2174 }
2175 if (m_ForgetLayerNormWeights)
2176 {
2177 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2178 }
2179 if (m_CellLayerNormWeights)
2180 {
2181 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2182 }
2183 if (m_OutputLayerNormWeights)
2184 {
2185 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2186 }
2187
Jan Eilers38e05bd2019-06-26 13:10:09 +01002188 if (m_Parameters.m_LayerNormEnabled)
2189 {
2190 if (!m_Parameters.m_CifgEnabled)
2191 {
2192 if (!m_InputLayerNormWeights)
2193 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002194 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2195 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002196 }
2197 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2198 1, n_cell, "InputLayerNormWeights");
2199 }
2200 else if (m_InputLayerNormWeights)
2201 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002202 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2203 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002204 }
2205
2206 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2207 "ForgetLayerNormWeights");
2208 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2209
2210 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2211 "OutputLayerNormWeights");
2212 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2213
2214 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2215 "CellLayerNormWeights");
2216 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2217 }
2218 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2219 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002220 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2221 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002222 }
telsoa01c577f2c2018-08-31 09:22:23 +01002223}
2224
2225void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2226{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002227 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002228
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002229 ValidateNumInputs(workloadInfo, descriptorName, 1);
2230 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2231
2232 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2233 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2234
2235 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002236 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002237 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002238 }
2239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002240 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002241 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002242 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002243 }
2244
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002245 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002246}
2247
2248void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2249{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002250 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002251
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002252 ValidateNumInputs(workloadInfo, descriptorName, 1);
2253 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2254
2255 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2256 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2257
2258 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002259 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002261 }
2262
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002263 if (outputTensorInfo.GetDataType() != DataType::Float32)
2264 {
2265 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2266 }
2267
2268 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002269}
2270
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002271void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2272{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002273 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002274
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002275 ValidateNumInputs(workloadInfo, descriptorName, 2);
2276 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2277
2278 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2279 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2280 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2281
2282 std::vector<DataType> supportedTypes =
2283 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002284 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002285 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002286 DataType::Float32,
2287 DataType::QAsymmS8,
2288 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002289 DataType::QSymmS16,
2290 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002291 };
2292
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002293 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2294 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2295 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002296
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002297 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2298 inputTensorInfo1,
2299 outputTensorInfo,
2300 descriptorName,
2301 "input_0",
2302 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002303}
2304
David Beckc2044fe2018-09-05 15:00:38 +01002305void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2306{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002307 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002308
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002309 ValidateNumInputs(workloadInfo, descriptorName, 2);
2310 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2311
2312 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2313 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2314 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2315
2316 std::vector<DataType> supportedTypes =
2317 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002318 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002319 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002320 DataType::Float32,
2321 DataType::QAsymmS8,
2322 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002323 DataType::QSymmS16,
2324 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002325 };
2326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002327 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2328 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2329 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002330
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002331 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2332 inputTensorInfo1,
2333 outputTensorInfo,
2334 descriptorName,
2335 "input_0",
2336 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002337}
2338
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002339void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2340{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002341 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002343 ValidateNumInputs(workloadInfo, descriptorName, 2);
2344 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2345
2346 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2347 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2348 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2349
2350 std::vector<DataType> supportedTypes =
2351 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002352 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002353 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002354 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002355 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002356 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002357 DataType::QSymmS16,
2358 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002359 };
2360
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002361 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2362 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2363 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002364
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002365 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2366 inputTensorInfo1,
2367 outputTensorInfo,
2368 descriptorName,
2369 "input_0",
2370 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002371}
2372
narpra01a6bf9122018-09-10 09:50:09 +01002373void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2374{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002375 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 ValidateNumInputs(workloadInfo, descriptorName, 1);
2378 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2379
2380 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2381 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002382
2383 std::vector<DataType> supportedTypes =
2384 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002385 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002386 DataType::Float32,
2387 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002388 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002389 DataType::QAsymmU8,
2390 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002391 };
narpra01eb061912018-09-10 17:35:27 +01002392
James Conroy4d1ff582019-06-10 17:06:39 +01002393 // First check if input tensor data type is supported, then
2394 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002395 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2396 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002397
narpra0132b90462018-09-13 11:07:48 +01002398 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002399 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002400 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002401 }
narpra0132b90462018-09-13 11:07:48 +01002402 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002403 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002404 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002405 }
2406 else
2407 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002408 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002409 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002410 ValidateTensorNumDimensions(outputTensorInfo,
2411 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002412 outputDim > 0 ? outputDim : 1,
2413 "output");
2414 }
narpra01a6bf9122018-09-10 09:50:09 +01002415}
2416
jimfly012c9322a2018-09-19 10:59:49 +01002417void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2418{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002419 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002420
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002421 ValidateNumInputs(workloadInfo, descriptorName, 1);
2422 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2423
2424 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2425 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002426
jimfly012c9322a2018-09-19 10:59:49 +01002427 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002428 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2429
jimfly012c9322a2018-09-19 10:59:49 +01002430 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2432 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2433 "as there are dimensions in the input tensor that is " +
2434 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2435 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002436 }
2437}
2438
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002439void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2440{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002441 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002442
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002443 ValidateNumInputs(workloadInfo, descriptorName, 1);
2444 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002445
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002446 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2447 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2448
Sadik Armagan2208b602019-07-31 16:36:27 +01002449 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002450 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002451 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002452 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002453 DataType::Float16,
2454 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002455 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002456 DataType::QAsymmU8,
2457 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002458 };
2459
2460 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002461
Keith Davis0c2eeac2020-02-11 16:51:50 +00002462 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002463 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002464 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002465 }
2466}
2467
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002468void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2469{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002470 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002471
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002472 ValidateNumInputs(workloadInfo, descriptorName, 1);
2473 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002475 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2476 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002477
2478 std::vector<DataType> supportedTypes =
2479 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002480 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002481 DataType::Float32,
2482 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002483 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002484 DataType::QAsymmU8,
2485 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002486 };
2487
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002488 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2489 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002490}
2491
Conor Kennedy430b5d82018-11-14 15:28:28 +00002492void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2493{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002494 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002495
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002496 ValidateNumInputs(workloadInfo, descriptorName, 1);
2497 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2498
2499 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2500 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002501
2502 std::vector<DataType> supportedTypes =
2503 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002504 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002505 DataType::Float16,
2506 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002507 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002508 DataType::QAsymmU8,
2509 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002510 };
2511
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002512 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2513 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002514
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002515 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002516
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002517 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002518 if (rank > 4)
2519 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002520 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002521 }
2522
Conor Kennedy430b5d82018-11-14 15:28:28 +00002523 // Begin, End & Stride length must be of rank(input0)
2524 if (m_Parameters.m_Begin.size() != rank)
2525 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002526 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002527 }
2528
2529 if (m_Parameters.m_End.size() != rank)
2530 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002531 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002532 }
2533
2534 if (m_Parameters.m_Stride.size() != rank)
2535 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002536 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002537 }
2538
2539 // Stride entries must be non-zero
2540 for (auto& stride : m_Parameters.m_Stride)
2541 {
2542 if (stride == 0)
2543 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002544 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002545 }
2546 }
2547}
2548
kevmay0190539692018-11-29 08:40:19 +00002549void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2550{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002551 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002552
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002553 ValidateNumInputs(workloadInfo, descriptorName, 2);
2554 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2555
2556 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2557 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2558 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2559
2560 std::vector<DataType> supportedTypes =
2561 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002562 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002563 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002564 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002565 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002566 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002567 DataType::QSymmS16,
2568 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002569 };
2570
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002571 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2572 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2573 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002574
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002575 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2576 inputTensorInfo1,
2577 outputTensorInfo,
2578 descriptorName,
2579 "input_0",
2580 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002581}
2582
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002583void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2584{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002585 const std::string descriptorName{"DebugQueueDescriptor"};
2586
2587 ValidateNumInputs(workloadInfo, descriptorName, 1);
2588 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002589}
2590
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002591void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2592{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002593 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002595 ValidateNumInputs(workloadInfo, descriptorName, 2);
2596 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002598 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2599 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2600 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2601
2602 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2603 inputTensorInfo1,
2604 outputTensorInfo,
2605 descriptorName,
2606 "input_0",
2607 "input_1");
2608
2609 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002610 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002611 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002612 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002613}
2614
FrancisMurtagh878f0232018-12-19 10:56:15 +00002615void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2616{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002617 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002618
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002619 ValidateNumInputs(workloadInfo, descriptorName, 2);
2620 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002622 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2623 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2624 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2625
2626 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2627 inputTensorInfo1,
2628 outputTensorInfo,
2629 descriptorName,
2630 "input_0",
2631 "input_1");
2632
2633 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002634 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002635 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002636 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002637}
2638
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002639void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2640{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002641 const std::string descriptorName{"RsqrtQueueDescriptor"};
2642
2643 ValidateNumInputs(workloadInfo, descriptorName, 1);
2644 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2645
2646 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2647 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2648
2649 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002650
2651 std::vector<DataType> supportedTypes =
2652 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002653 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002654 DataType::Float16,
2655 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002656 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002657 DataType::QAsymmU8,
2658 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002659 };
2660
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002661 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2662 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002663}
2664
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002665void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2666{
2667 const std::string descriptorName{"GatherNdQueueDescriptor"};
2668
2669 ValidateNumInputs(workloadInfo, descriptorName, 2);
2670 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2671
2672 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2673 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2674 {
2675 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2676 }
2677
2678 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2679 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2680
2681 std::vector<DataType> supportedTypes =
2682 {
2683 DataType::BFloat16,
2684 DataType::Float16,
2685 DataType::Float32,
2686 DataType::QAsymmS8,
2687 DataType::QAsymmU8,
2688 DataType::QSymmS16,
2689 DataType::Signed32,
2690 };
2691
2692 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2693
2694 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2695
2696 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2697 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2698}
2699
narpra01b89b05f2019-01-16 09:53:09 +00002700void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2701{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002702 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002703
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002704 ValidateNumInputs(workloadInfo, descriptorName, 2);
2705 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002706
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002707 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2708 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002709 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002710 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002711 }
2712
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002713 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2714 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2715
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002716 std::vector<DataType> supportedTypes =
2717 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002718 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002719 DataType::Float16,
2720 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002721 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002722 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002723 DataType::QSymmS16,
2724 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002725 };
2726
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002727 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002728
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002729 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002730
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002731 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2732 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002733}
2734
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002735void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2736{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002737 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2738
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002739 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002740
2741 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2742 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002743 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002744 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2745 }
2746
2747 if (m_Anchors == nullptr)
2748 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002749 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002750 }
2751
2752 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002753 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2754 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2755
2756 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002757 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002758 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2759 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002760
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002761 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2762 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2763 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002764
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002765 const std::vector<DataType> supportedInputTypes =
2766 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002767 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002768 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002769 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002770 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002771 DataType::QAsymmU8,
2772 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002773 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002774
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002775 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2776 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2777 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2778
2779 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2780 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2781 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2782 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2783
2784 // NOTE: Output is always Float32 regardless of input type
2785 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2786 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2787 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2788 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002789
2790 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2791 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002792 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002793 "must be positive and less than or equal to 1.");
2794 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002795
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002796 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2797 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002798 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002799 "should be equal to number of classes + 1.");
2800 }
2801}
2802
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002803void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2804{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002805 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002806
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002807 ValidateNumInputs(workloadInfo, descriptorName, 1);
2808 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2809
2810 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2811 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2812
Teresa Charlin07307f32022-05-15 14:07:05 +01002813 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002814 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002815 DataType::QAsymmS8,
2816 DataType::QAsymmU8,
2817 DataType::QSymmS8,
2818 DataType::QSymmS16,
2819 DataType::Float16
2820 };
2821 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002822
Teresa Charlin07307f32022-05-15 14:07:05 +01002823 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002824 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002825 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002826 DataType::Float32,
2827 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002828 };
2829
Teresa Charlin07307f32022-05-15 14:07:05 +01002830 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002831}
2832
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002833void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2834{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002835 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002836
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002837 ValidateNumInputs(workloadInfo, descriptorName, 2);
2838 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002839
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002840 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2841 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2842 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002844 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2845 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2846
2847 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2848 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002849}
2850
Keith Davis3ae3f972021-05-21 16:33:48 +01002851void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2852{
2853 const std::string& descriptorName{"ShapeQueueDescriptor"};
2854
2855 ValidateNumInputs(workloadInfo, descriptorName, 1);
2856 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2857
2858 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2859 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2860
2861 std::vector<DataType> supportedTypes =
2862 {
2863 DataType::BFloat16,
2864 DataType::Float16,
2865 DataType::Float32,
2866 DataType::QAsymmS8,
2867 DataType::QAsymmU8,
2868 DataType::QAsymmS8,
2869 DataType::QSymmS8,
2870 DataType::QSymmS16,
2871 DataType::Signed32
2872 };
2873
2874 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2875 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2876}
2877
Sadik Armaganeff363d2019-04-05 15:25:46 +01002878void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2879{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002880 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002882 ValidateNumInputs(workloadInfo, descriptorName, 2);
2883 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2884
2885 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2886 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2887
2888 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2889 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2890
2891 std::vector<DataType> supportedTypes =
2892 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002893 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002894 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002895 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002896 DataType::QAsymmU8,
2897 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002898 };
2899
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002900 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2901 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002902
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002903 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2904 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002905
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002906 ValidateTensorShapesMatch(inputTensorInfo0,
2907 outputTensorInfo0,
2908 descriptorName,
2909 "input_0",
2910 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002911
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002912 ValidateTensorShapesMatch(inputTensorInfo0,
2913 outputTensorInfo1,
2914 descriptorName,
2915 "input_0",
2916 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002917}
2918
Derek Lamberti901ea112019-12-10 22:07:09 +00002919void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002920{
2921 // This is internally generated so it should not need validation.
2922}
2923
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002924void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2925{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002926 const std::string& descriptorName{"PreluQueueDescriptor"};
2927
2928 ValidateNumInputs(workloadInfo, descriptorName, 2);
2929 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2930
2931 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2932 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2933 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002934
2935 std::vector<DataType> supportedTypes
2936 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002937 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002938 DataType::Float16,
2939 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002940 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002941 DataType::QAsymmU8,
2942 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002943 };
2944
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002945 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2946 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002947
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002948 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002949
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002950 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2951 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002952
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002953 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2954 alphaTensorInfo,
2955 outputTensorInfo,
2956 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002957 "input",
2958 "alpha");
2959}
2960
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002961void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2962{
2963 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2964
2965 ValidateNumInputs(workloadInfo, descriptorName, 1);
2966 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2967
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002968 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2969 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2970
2971 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2972 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002973
2974 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002975
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002976 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2977 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002978
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002979 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2980
2981 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002982 if (m_Parameters.m_BiasEnabled)
2983 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002984 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002985
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002986 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2987 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002988
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002989 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002990 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002991 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002992
2993 ValidatePerAxisQuantization(inputTensorInfo,
2994 outputTensorInfo,
2995 weightTensorInfo,
2996 optionalBiasTensorInfo,
2997 descriptorName);
2998
2999 std::vector<DataType> supportedTypes =
3000 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003001 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003002 DataType::Float32,
3003 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003004 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003005 DataType::QAsymmU8,
3006 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003007 };
3008
3009 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3010 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003011}
3012
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003013void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3014{
3015 const std::string descriptorName{"TransposeQueueDescriptor"};
3016
3017 ValidateNumInputs(workloadInfo, descriptorName, 1);
3018 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3019
3020 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3021
3022 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3023 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3024
3025 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3026 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3027
3028 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3029 {
3030 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3031 {
3032 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3033 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3034 "must match dst dimension " + to_string(i) +
3035 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3036 }
3037 }
3038
3039 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3040}
3041
Simon Obute51f67772021-09-03 15:50:13 +01003042void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3043{
3044 const std::string descriptorName{"TransposeQueueDescriptor"};
3045
3046 ValidateNumInputs(workloadInfo, descriptorName, 1);
3047 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3048
3049 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3050 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3051
3052 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3053}
3054
James Conroy4f1f8992020-04-29 20:01:10 +01003055void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3056{
3057 const std::string descriptorName{"QLstmQueueDescriptor"};
3058
3059 // Validate number of inputs/outputs
3060 ValidateNumInputs(workloadInfo, descriptorName, 3);
3061 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3062
3063 // Input/output tensor info
3064 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3065 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3066 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3067
3068 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3069 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3070 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3071
3072 // Supported types for various tensors in QLSTM
3073 std::vector<DataType> inputOutputSupportedTypes =
3074 {
3075 DataType::QAsymmS8
3076 };
3077
3078 std::vector<DataType> cellStateSupportedTypes =
3079 {
3080 DataType::QSymmS16
3081 };
3082
3083 std::vector<DataType> weightsSupportedTypes =
3084 {
3085 DataType::QSymmS8
3086 };
3087
3088 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3089 {
3090 DataType::QSymmS16
3091 };
3092
3093 std::vector<DataType> biasSupportedTypes =
3094 {
3095 DataType::Signed32
3096 };
3097
3098 // Validate types of input/output tensors
3099 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3100 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3101 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3102
3103 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3104 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3105 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3106
3107 // Validate matching types of input/output tensors
3108 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3109 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3110 "outputStateIn", "outputStateOut");
3111 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3112
3113 // Infer number of batches, number of units, input size and output size from tensor dimensions
3114 const uint32_t numBatches = inputInfo.GetShape()[0];
3115 const uint32_t inputSize = inputInfo.GetShape()[1];
3116 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3117 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3118
3119 // Validate number of dimensions and number of elements for input/output tensors
3120 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3121 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3122 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3123
3124 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3125 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3126 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3127
3128 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3129 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3130 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3131 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3132
3133 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3134 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3135 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3136
3137 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3138 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3139 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3140
3141 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3142 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3143 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3144 " RecurrentToForgetWeights");
3145
3146 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3147 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3148 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3149
3150 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3151 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3152 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3153
3154 // Validate data types for MANDATORY weights tensors (all should match each other)
3155 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3156
3157 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3158 "inputToForgetWeights", "inputToCellWeights");
3159 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3160 "inputToForgetWeights", "inputToOutputWeights");
3161
3162 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3163 "inputToForgetWeights", "recurrentToForgeteights");
3164 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3165 "inputToForgetWeights", "recurrentToCellWeights");
3166 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3167 "inputToForgetWeights", "recurrentToOutputWeights");
3168
3169 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3170 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3171 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3172 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3173
3174 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3175 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3176 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3177
3178 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3179 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3180 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3181
3182 // Validate data types for MANDATORY bias tensors
3183 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3184
3185 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3186 "forgetGateBias", "cellBias");
3187 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3188 "forgetGateBias", "outputGateBias");
3189
3190 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3191 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3192 !m_Parameters.m_CifgEnabled) ||
3193 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3194 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3195
3196 if (!allCifgParamsPresentOrNot)
3197 {
3198 throw InvalidArgumentException(descriptorName +
3199 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3200 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3201 "set appropriately.");
3202 }
3203
3204 if (!m_Parameters.m_CifgEnabled)
3205 {
3206 // Validate number of dimensions and number of elements
3207 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3208 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3209
3210 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3211 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3212 " RecurrentToInputWeights");
3213
3214 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3215 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3216
3217 // Validate data types
3218 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3219 "inputToForgetWeights", "inputToInputWeights");
3220 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3221 "inputToForgetWeights", "recurrentToInputWeights");
3222 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3223 "forgetGateBias", "inputGateBias");
3224 }
3225
3226 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3227 bool allPeepholeWeightsPresentOrNot =
3228 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3229 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3230 || (!m_CellToInputWeights && !m_CellToForgetWeights
3231 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3232
3233 if (!allPeepholeWeightsPresentOrNot)
3234 {
3235 throw InvalidArgumentException(descriptorName +
3236 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3237 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3238 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3239 "appropriately.");
3240 }
3241
3242 if (m_Parameters.m_PeepholeEnabled)
3243 {
3244 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3245 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3246 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3247
3248 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3249 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3250 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3251 "cellToForgetWeight", "cellToOutputWeights");
3252
3253 if (!m_Parameters.m_CifgEnabled)
3254 {
3255 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3256 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3257 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3258 "cellToForgetWeights", "cellToInputWeights");
3259 }
3260 }
3261
3262 // Validate OPTIONAL params: Layer Norm Weights
3263 bool allLayerNormWeightsPresentOrNot =
3264 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3265 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3266 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3267 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3268
3269 if (!allLayerNormWeightsPresentOrNot)
3270 {
3271 throw InvalidArgumentException(descriptorName +
3272 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3273 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3274 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3275 "only be present when Layer Norm is enabled and CIFG is disabled. "
3276 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3277 }
3278
3279 if (m_Parameters.m_LayerNormEnabled)
3280 {
3281 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3282 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3283 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3284
3285 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3286 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3287 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3288 "forgetLayerNormWeights", "cellLayerNormWeights");
3289
3290 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3291 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3292 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3293 "forgetLayerNormWeights", "outputLayerNormWeights");
3294
3295 if (!m_Parameters.m_CifgEnabled)
3296 {
3297 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3298 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3299 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3300 "forgetLayerNormWeights", "inputLayerNormWeights");
3301 }
3302 }
3303
3304 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3305 bool correctProjectionTensorsPresent =
3306 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3307 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3308 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3309
3310 if (!correctProjectionTensorsPresent)
3311 {
3312 throw InvalidArgumentException(descriptorName +
3313 ": If projection is enabled, ProjectionWeights should be present and "
3314 "ProjectionBias is optional. If projection is disabled, neither "
3315 "ProjectionWeights nor ProjectionBias should be present.");
3316 }
3317
3318 if (m_Parameters.m_ProjectionEnabled)
3319 {
3320 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3321 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3322 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3323
3324 if (m_ProjectionBias)
3325 {
3326 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003327 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003328 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3329 }
3330
3331 }
3332 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3333 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3334 throw InvalidArgumentException(descriptorName +
3335 ": If projection is disabled, output quantization info (scale, offset) "
3336 "should match HiddenStateScale and HiddenStateZeroPoint.");
3337 }
3338
3339}
3340
James Conroy9c3cae82019-08-01 16:01:48 +01003341void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3342{
3343 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3344
3345 // Validate number of inputs/outputs
3346 ValidateNumInputs(workloadInfo, descriptorName, 3);
3347 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3348
3349 // Input/output tensor infos
3350 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3351 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3352 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3353
3354 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3355 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3356
3357 std::vector<DataType> inputOutputSupportedTypes =
3358 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003359 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003360 };
3361
3362 std::vector<DataType> cellStateSupportedTypes =
3363 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003364 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003365 };
3366
3367 std::vector<DataType> weightsSupportedTypes =
3368 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003369 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003370 };
3371
3372 std::vector<DataType> biasSupportedTypes =
3373 {
3374 DataType::Signed32
3375 };
3376
3377 // Validate types of input/output tensors
3378 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3379 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3380 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3381
3382 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3383 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3384
3385 // Validate matching types of input/output tensors
3386 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3387 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3388 "outputStateIn", "outputStateOut");
3389 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3390
3391 // Validate matching quantization info for input/output tensors
3392 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3393 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3394 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003395
James Conroy9c3cae82019-08-01 16:01:48 +01003396 // Infer number of batches, input size and output size from tensor dimensions
3397 const uint32_t numBatches = inputInfo.GetShape()[0];
3398 const uint32_t inputSize = inputInfo.GetShape()[1];
3399 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3400
3401 // Validate number of dimensions and number of elements for input/output tensors
3402 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3403 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3404 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3405 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3406 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3407
3408 // Validate number of dimensions and number of elements for weights tensors
3409 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3410 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3411 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3412
3413 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3414 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3415 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3416
3417 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3418 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3419 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3420
3421 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3422 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3423 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3424
3425 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3426 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3427 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3428
3429 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3430 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3431 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3432 " RecurrentToForgetWeights");
3433
3434 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3435 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3436 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3437
3438 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3439 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3440 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3441
3442 // Validate data types for weights tensors (all should match each other)
3443 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3444
3445 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3446 "inputToInputWeights", "inputToForgetWeights");
3447 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3448 "inputToInputWeights", "inputToCellWeights");
3449 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3450 "inputToInputWeights", "inputToOutputWeights");
3451
3452 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3453 "inputToInputWeights", "recurrentToInputWeights");
3454 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3455 "inputToInputWeights", "recurrentToForgeteights");
3456 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3457 "inputToInputWeights", "recurrentToCellWeights");
3458 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3459 "inputToInputWeights", "recurrentToOutputWeights");
3460
3461 // Validate matching quantization info for weight tensors (all should match each other)
3462 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3463 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3464 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3465 descriptorName, "inputToInputWeights", "inputToCellWeights");
3466 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3467 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3468
3469 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3470 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3471 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3472 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3473 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3474 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3475 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3476 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3477
3478 // Validate number of dimensions and number of elements in bias tensors
3479 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3480 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3481 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3482
3483 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3484 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3485 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3486
3487 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3488 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3489 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3490
3491 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3492 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3493 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3494
3495 // Validate data types for bias tensors (all should match each other)
3496 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3497
3498 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3499 "inputGateBias", "forgetGateBias");
3500 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3501 "inputGateBias", "cellBias");
3502 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3503 "inputGateBias", "outputGateBias");
3504
3505 // Validate bias tensor quantization info
3506 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3507 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3508 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3509 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3510}
3511
Kevin May868eb142019-09-04 17:29:31 +01003512void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3513{
3514 const std::string descriptorName{"AbsQueueDescriptor"};
3515
3516 ValidateNumInputs(workloadInfo, descriptorName, 1);
3517 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3518
3519 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3520 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3521
3522 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3523
3524 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003525 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003526 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003527 DataType::Float16,
3528 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003529 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003530 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003531 DataType::QSymmS16,
3532 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003533 };
Kevin May868eb142019-09-04 17:29:31 +01003534
3535 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3536 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3537}
3538
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003539void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3540{
3541 const std::string descriptorName{"SliceQueueDescriptor"};
3542
3543 ValidateNumInputs(workloadInfo, descriptorName, 1);
3544 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3545
3546 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3547 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3548
3549 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3550
3551 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3552 if (rank > 4)
3553 {
3554 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3555 }
3556
3557 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3558
3559 // Check if m_Begin and m_Size have the expected length
3560 if (m_Parameters.m_Begin.size() != rank)
3561 {
3562 throw InvalidArgumentException(descriptorName +
3563 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3564 }
3565 if (m_Parameters.m_Size.size() != rank)
3566 {
3567 throw InvalidArgumentException(descriptorName +
3568 ": Length of size descriptor must equal rank " + std::to_string(rank));
3569 }
3570
3571 // Check if the shape of the output tensor matches m_Size
3572 const TensorShape& outputShape = outputTensorInfo.GetShape();
3573 for (unsigned int i = 0u; i < rank; ++i)
3574 {
3575 if (m_Parameters.m_Size[i] != outputShape[i])
3576 {
3577 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3578 }
3579 }
3580
3581 // Check if the sum of begin offset and size in a given dimension
3582 // does not exceed the size of corresponding input
3583 const TensorShape& inputShape = inputTensorInfo.GetShape();
3584 for(unsigned int i = 0u; i < rank; ++i)
3585 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003586 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003587 {
3588 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3589 std::to_string(i) + " exceeds input size.");
3590 }
3591 }
3592}
3593
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003594void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3595{
3596 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3597
3598 ValidateNumInputs(workloadInfo, descriptorName, 1);
3599 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3600
3601 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3602 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3603
3604 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3605 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3606
3607 std::vector<DataType> supportedTypes =
3608 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003609 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003610 DataType::Float32,
3611 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003612 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003613 DataType::QAsymmU8,
3614 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003615 };
3616
3617 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3618 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3619
3620 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3621
3622 if (m_Parameters.m_BlockSize == 0)
3623 {
3624 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3625 }
3626
3627 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3628 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3629 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3630 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3631
3632 const TensorShape& outputShape = outputInfo.GetShape();
3633 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3634 {
3635 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3636 "must be divisible by block size.");
3637 }
3638
3639 const TensorShape& inputShape = inputInfo.GetShape();
3640 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3641 {
3642 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3643 "must be divisible by the square of block size." );
3644 }
3645}
3646
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003647void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3648{
3649 const std::string descriptorName{"ComparisonQueueDescriptor"};
3650
3651 ValidateNumInputs(workloadInfo, descriptorName, 2);
3652 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3653
3654 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3655 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3656 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3657
3658 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3659 inputTensorInfo1,
3660 outputTensorInfo,
3661 descriptorName,
3662 "input_0",
3663 "input_1");
3664
3665 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3666 {
3667 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3668 }
3669}
3670
josh minor4a3c6102020-01-06 16:40:46 -06003671void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3672{
3673 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3674
3675 ValidateNumInputs(workloadInfo, descriptorName, 1);
3676 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3677
3678 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3679 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3680
3681 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3682
3683 std::vector<DataType> supportedTypes =
3684 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003685 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003686 DataType::Float16,
3687 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003688 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003689 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003690 DataType::QSymmS16,
3691 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003692 };
3693
James Conroyaba90cd2020-11-06 16:28:18 +00003694 std::vector<DataType> logicalSupportedTypes =
3695 {
3696 DataType::Boolean
3697 };
3698
3699 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3700 {
3701 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3702 }
3703 else
3704 {
3705 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3706 }
3707
3708
josh minor4a3c6102020-01-06 16:40:46 -06003709 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3710}
3711
Finn Williams2605b232020-06-10 15:53:46 +01003712void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3713{
3714 const std::string descriptorName{"RankQueueDescriptor"};
3715
3716 ValidateNumInputs(workloadInfo, descriptorName, 1);
3717 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3718
3719 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3720 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3721
3722 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3723 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3724
3725 std::vector<DataType> supportedTypes =
3726 {
3727 DataType::BFloat16,
3728 DataType::Float16,
3729 DataType::Float32,
3730 DataType::QAsymmS8,
3731 DataType::QAsymmU8,
3732 DataType::QSymmS8,
3733 DataType::QSymmS16,
3734 DataType::Signed32
3735 };
3736
3737 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3738 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3739}
3740
James Conroyaba90cd2020-11-06 16:28:18 +00003741void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3742{
3743 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3744
3745 ValidateNumInputs(workloadInfo, descriptorName, 2);
3746 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3747
3748 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3749 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3750 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3751
3752 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3753 inputTensorInfo1,
3754 outputTensorInfo,
3755 descriptorName,
3756 "input_0",
3757 "input_1");
3758
3759 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3760 {
3761 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3762 }
3763
3764 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3765 {
3766 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3767 }
3768
3769 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3770 {
3771 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3772 }
3773}
3774
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003775void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3776{
3777 const std::string descriptorName{"ReduceQueueDescriptor"};
3778
3779 ValidateNumInputs(workloadInfo, descriptorName, 1);
3780 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3781
3782 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3783 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3784
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003785 std::vector<DataType> supportedTypes =
3786 {
3787 DataType::BFloat16,
3788 DataType::Float16,
3789 DataType::Float32,
3790 DataType::QAsymmS8,
3791 DataType::QAsymmU8,
3792 DataType::QSymmS16,
3793 DataType::Signed32
3794 };
3795
3796 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3797 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3798}
3799
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003800void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3801{
3802 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3803
3804 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3805
3806 // check dimensions of all inputs and outputs
3807 if (workloadInfo.m_InputTensorInfos.size() != 3)
3808 {
3809 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3810 }
Mike Kelly12994962022-04-21 11:57:09 +01003811 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003812 {
3813 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3814 }
3815
3816 std::vector<DataType> supportedTypes =
3817 {
Mike Kelly12994962022-04-21 11:57:09 +01003818 DataType::Float32,
3819 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003820 };
3821
3822 // check for supported type of one input and match them with all the other input and output
3823 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3824
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003825 // Making sure clipping parameters have valid values.
3826 // == 0 means no clipping
3827 // > 0 means clipping
3828 if (m_Parameters.m_ClippingThresCell < 0.0f)
3829 {
3830 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3831 }
3832 if (m_Parameters.m_ClippingThresProj < 0.0f)
3833 {
3834 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3835 }
3836
3837 unsigned int batchIndx = 0;
3838 unsigned int inputIndx = 1;
3839 uint32_t timeStep = 1;
3840 unsigned int timeIndx = 1;
3841 inputIndx = 2;
3842 if (m_Parameters.m_TimeMajor)
3843 {
3844 batchIndx = 1;
3845 timeIndx = 0;
3846
3847 }
3848 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3849
3850 // Inferring batch size, number of outputs and number of cells from the inputs.
3851 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3852 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3853 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3854 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3855 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3856 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3857
3858 // input tensor
3859 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3860 descriptorName + " input_0");
3861 // outputStateInTensor
3862 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3863 descriptorName + " input_1");
3864 // outputStateInTensor
3865 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3866 descriptorName + " input_2");
3867
3868 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003869 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003870 descriptorName + " output_0");
3871
3872 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3873 if ( m_InputToInputWeights )
3874 {
3875 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3876 (n_cell * n_input), "InputLayerNormWeights");
3877 }
3878
3879 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3880 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3881 (n_cell * n_input), "InputToForgetWeights");
3882
3883 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3884 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3885 (n_cell * n_input), "InputToCellWeights");
3886
3887 if ( m_RecurrentToInputWeights )
3888 {
3889 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3890 (n_cell * n_output), "RecurrentToInputWeights");
3891 }
3892
3893 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3894 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3895 (n_cell * n_output), "RecurrentToForgetWeights");
3896
3897 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3898 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3899 (n_cell * n_output), "RecurrentToCellWeights");
3900
3901 // Make sure the input-gate's parameters are either both present (regular
3902 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3903 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3904 !m_Parameters.m_CifgEnabled) ||
3905 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3906 m_Parameters.m_CifgEnabled));
3907 if (!cifg_weights_all_or_none)
3908 {
3909 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3910 "RecurrentToInputWeights must either both be present (regular LSTM) "
3911 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3912 "accordingly.");
3913 }
3914
3915 if ( m_CellToInputWeights )
3916 {
3917 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3918 n_cell, "CellToInputWeights");
3919 }
3920 if ( m_CellToForgetWeights )
3921 {
3922 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3923 n_cell, "CellToForgetWeights");
3924 }
3925 if ( m_CellToOutputWeights )
3926 {
3927 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3928 n_cell, "CellToOutputWeights");
3929 }
3930
3931 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3932 bool peephole_weights_all_or_none =
3933 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3934 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3935 || ( !m_CellToInputWeights && !m_CellToForgetWeights
3936 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3937 if (!peephole_weights_all_or_none)
3938 {
3939 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3940 }
3941
3942 // Make sure the input gate bias is present only when not a CIFG-LSTM.
3943 if (m_Parameters.m_CifgEnabled)
3944 {
3945 if (m_InputGateBias)
3946 {
3947 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3948 }
3949 }
3950 else
3951 {
3952 if (!m_InputGateBias)
3953 {
3954 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
3955 "must be present.");
3956 }
3957 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
3958 n_cell, "InputGateBias");
3959 }
3960
3961 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
3962 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
3963
3964 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
3965 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
3966
3967 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
3968 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
3969
3970 if (m_ProjectionWeights)
3971 {
3972 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
3973 (n_cell * n_output), "ProjectionWeights");
3974 }
3975 if (m_ProjectionBias)
3976 {
3977 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
3978 }
3979
3980 // Making sure the projection tensors are consistent:
3981 // 1) If projection weight is not present, then projection bias should not be
3982 // present.
3983 // 2) If projection weight is present, then projection bias is optional.
3984 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
3985 !m_Parameters.m_ProjectionEnabled)
3986 || (m_ProjectionWeights && !m_ProjectionBias &&
3987 m_Parameters.m_ProjectionEnabled)
3988 || (m_ProjectionWeights && m_ProjectionBias &&
3989 m_Parameters.m_ProjectionEnabled));
3990 if (!projecton_tensors_consistent)
3991 {
3992 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
3993 }
3994
3995 // The four layer normalization weights either all have values or none of them have values. Additionally, if
3996 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
3997 // either all have values or none of them have values. Layer normalization is used when the values of all the
3998 // layer normalization weights are present
3999 if (m_InputLayerNormWeights)
4000 {
4001 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4002 }
4003 if (m_ForgetLayerNormWeights)
4004 {
4005 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4006 }
4007 if (m_CellLayerNormWeights)
4008 {
4009 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4010 }
4011 if (m_OutputLayerNormWeights)
4012 {
4013 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4014 }
4015
4016 if (m_Parameters.m_LayerNormEnabled)
4017 {
4018 if (!m_Parameters.m_CifgEnabled)
4019 {
4020 if (!m_InputLayerNormWeights)
4021 {
4022 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4023 "disabled but InputLayerNormWeights are not present");
4024 }
4025 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4026 1, n_cell, "InputLayerNormWeights");
4027 }
4028 else if (m_InputLayerNormWeights)
4029 {
4030 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4031 "enabled");
4032 }
4033
4034 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4035 "ForgetLayerNormWeights");
4036 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4037
4038 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4039 "OutputLayerNormWeights");
4040 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4041
4042 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4043 "CellLayerNormWeights");
4044 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4045 }
4046 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4047 {
4048 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4049 "normalisation weights are present.");
4050 }
4051}
4052
Samuel Yap6b478092022-07-06 15:36:03 +01004053void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4054{
4055 const std::string descriptorName{"BatchMatMulDescriptor"};
4056
4057 ValidateNumInputs(workloadInfo, descriptorName, 2);
4058 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4059
4060 // Inputs must be: both 2D+
4061 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4062 // axes N and I must be the same size
4063
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004064 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4065 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4066 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4067 // Output info has already been inferred
Samuel Yap6b478092022-07-06 15:36:03 +01004068
4069 std::vector<DataType> supportedTypes =
4070 {
4071 DataType::BFloat16,
4072 DataType::Float16,
4073 DataType::Float32,
4074 DataType::QAsymmS8,
4075 DataType::QAsymmU8,
4076 DataType::QSymmS16
4077 };
4078
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004079 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4080 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4081 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
Samuel Yap6b478092022-07-06 15:36:03 +01004082
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004083 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4084 (inputYInfoBeforeParams.GetNumDimensions() < 2))
Samuel Yap6b478092022-07-06 15:36:03 +01004085 {
4086 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4087 }
4088
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004089 TensorInfo inputXInfoAfterParams;
4090 TensorInfo inputYInfoAfterParams;
4091
4092 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4093 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
Samuel Yap6b478092022-07-06 15:36:03 +01004094 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004095 throw InvalidArgumentException(descriptorName +
4096 ": Invalid descriptor parameters - Transpose and Adjoint "
4097 "cannot both be true for a given input tensor.");
4098 }
4099 if(m_Parameters.m_TransposeX)
4100 {
4101 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4102 BatchMatMulDescriptor::GetPermuteVec(
4103 m_Parameters.m_DataLayoutX,
4104 inputXInfoBeforeParams.GetShape()));
4105 }
4106 else if(m_Parameters.m_AdjointX)
4107 {
4108 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4109 inputXInfoBeforeParams.GetShape());
4110 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4111 inputXInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004112 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004113 throw InvalidArgumentException(descriptorName +
4114 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004115 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004116 // Shape remains the same as it's square
4117 inputXInfoAfterParams = inputXInfoBeforeParams;
4118 }
4119 else
4120 {
4121 inputXInfoAfterParams = inputXInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004122 }
4123
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004124 if(m_Parameters.m_TransposeY)
Samuel Yap6b478092022-07-06 15:36:03 +01004125 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004126 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4127 BatchMatMulDescriptor::GetPermuteVec(
4128 m_Parameters.m_DataLayoutY,
4129 inputYInfoBeforeParams.GetShape()));
4130 }
4131 else if(m_Parameters.m_AdjointY)
4132 {
4133 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4134 inputYInfoBeforeParams.GetShape());
4135 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4136 inputYInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004137 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004138 throw InvalidArgumentException(descriptorName +
4139 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004140 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004141 // Shape remains the same as it's square
4142 inputYInfoAfterParams = inputYInfoBeforeParams;
4143 }
4144 else
4145 {
4146 inputYInfoAfterParams = inputYInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004147 }
4148
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004149 switch(m_Parameters.m_DataLayoutX)
4150 {
4151 case DataLayout::NCDHW:
4152 case DataLayout::NDHWC:
4153 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4154 {
4155 throw InvalidArgumentException(descriptorName +
4156 ": Input tensor X does not have the correct "
4157 "number of dimensions for the Data Layout that it has been assigned.");
4158 }
4159 break;
4160 case DataLayout::NCHW:
4161 case DataLayout::NHWC:
4162 default:
4163 break;
4164 }
Samuel Yap6b478092022-07-06 15:36:03 +01004165
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004166 switch(m_Parameters.m_DataLayoutY)
4167 {
4168 case DataLayout::NCDHW:
4169 case DataLayout::NDHWC:
4170 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4171 {
4172 throw InvalidArgumentException(descriptorName +
4173 ": Input tensor Y does not have the correct "
4174 "number of dimensions for the Data Layout that it has been assigned.");
4175 }
4176 break;
4177 case DataLayout::NCHW:
4178 case DataLayout::NHWC:
4179 default:
4180 break;
4181 }
4182
4183 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4184 inputXInfoAfterParams.GetShape());
4185 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4186 inputXInfoBeforeParams.GetShape());
4187
4188 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4189 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
Samuel Yap6b478092022-07-06 15:36:03 +01004190 {
4191 throw InvalidArgumentException(descriptorName +
4192 ": The final axis of input tensor X must be the same size as "
4193 "the second last axis of input tensor Y.");
4194 }
4195
Samuel Yap6b478092022-07-06 15:36:03 +01004196 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4197 // e.g. NHWC isnt compatible with NCHW as of now
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004198 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4199 DataLayout yLayout = m_Parameters.m_DataLayoutY;
Samuel Yap6b478092022-07-06 15:36:03 +01004200
4201 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4202 {
4203 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4204 {
4205 throw InvalidArgumentException(descriptorName +
4206 ": Invalid input tensor data layout combination.");
4207 }
4208 }
4209 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4210 {
4211 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4212 {
4213 throw InvalidArgumentException(descriptorName +
4214 ": Invalid input tensor data layout combination.");
4215 }
4216 }
4217 }
4218
4219 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004220 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4221 inputYInfoAfterParams.GetNumDimensions());
Samuel Yap6b478092022-07-06 15:36:03 +01004222 if(outputTensorDimSize-2 > 0)
4223 {
4224 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4225 DataType::Float32);
4226 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4227 DataType::Float32);
4228 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4229 DataType::Float32);
4230
4231 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4232 {
4233 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4234
4235 for(unsigned int i = 0; i < sizeDiff; i++)
4236 {
4237 axisIndices.insert(axisIndices.begin(), 1);
4238 }
4239
4240 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4241 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004242 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
Samuel Yap6b478092022-07-06 15:36:03 +01004243 }
4244 };
4245
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004246 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4247 inputXInfoAfterParams.GetShape());
4248 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4249 inputYInfoAfterParams.GetShape());
4250
4251 doAxisExtension(axesXNotMul, tiXNotMul);
4252 doAxisExtension(axesYNotMul, tiYNotMul);
Samuel Yap6b478092022-07-06 15:36:03 +01004253
4254 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4255 {
4256 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4257 tiYNotMul.GetShape()[i]);
4258 }
4259
4260 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4261 tiYNotMul,
4262 tiOutNotMul,
4263 descriptorName,
4264 "input_X",
4265 "input_Y");
4266 }
Samuel Yap6b478092022-07-06 15:36:03 +01004267}
4268
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01004269
mathad01df9a3222021-04-28 11:42:57 +01004270} // namespace armnn