blob: 289f780fbad14844cdb893115013605ab7af1b48 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010012#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000013
telsoa014fcda012018-03-09 14:13:49 +000014#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000016#include <string>
17#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000018
James Ward47fce872020-09-10 11:57:28 +010019#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000020
Matteo Martincigh21350152018-11-28 16:22:22 +000021using namespace armnnUtils;
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
26//---------------------------------------------------------------
27DataType GetBiasDataType(DataType inputDataType)
28{
29 switch (inputDataType)
30 {
telsoa01c577f2c2018-08-31 09:22:23 +010031 case DataType::Float16:
32 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000033 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000034 case DataType::Float32:
35 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000036 case DataType::QAsymmS8:
37 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000038 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000039 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000040 case DataType::QSymmS8:
41 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000042 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010043 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000044 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010045 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000046 return DataType::Float32;
47 }
48}
49
50namespace
51{
52
53//---------------------------------------------------------------
// Stream-based replacement for std::to_string, which the Android NDK does not provide.
// Formats the value with operator<< using default stream settings.
template <typename T>
std::string to_string(T value)
{
    std::stringstream stream;
    stream << value;
    return stream.str();
}
62
63//---------------------------------------------------------------
64void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
65{
66 if (!ptr)
67 {
68 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
69 paramName + " parameter must be set.");
70 }
71}
72
73//---------------------------------------------------------------
74void ValidateTensorShapesMatch(const TensorInfo& first,
75 const TensorInfo& second,
76 std::string const& descName,
77 std::string const& firstName,
78 std::string const& secondName)
79{
80 if (first.GetShape() != second.GetShape())
81 {
82 throw InvalidArgumentException(descName + ": "
83 + firstName + " & " + secondName + " must have identical shapes");
84 }
85}
86
87//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010088void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089{
Sadik Armaganeff363d2019-04-05 15:25:46 +010090 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000091 {
92 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010093 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000094 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
95 }
96}
97
98//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010099void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100101 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000102 {
103 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100104 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000105 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
106 }
107}
108
109//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100112 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000113 std::string const& tensorName)
114{
115 if (tensor.GetNumDimensions() != numDimensions)
116 {
117 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
118 to_string(tensor.GetNumDimensions()) + " dimensions for " +
119 tensorName + " tensor.");
120 }
121}
122
123//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100124void ValidateTensorNumElements(const TensorInfo& tensor,
125 std::string const& descName,
126 unsigned int numElements,
127 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100128{
129 if (tensor.GetNumElements() != numElements)
130 {
131 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100132 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100133 tensorName + " tensor.");
134 }
135}
136
137//---------------------------------------------------------------
138void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100139 unsigned int numDimension,
140 unsigned int numElements,
141 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100142{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100143 const std::string functionName{"ValidateTensorNumDimNumElem"};
144 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
145 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100146}
147
148//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000149void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
150 const std::string& descName, std::string const& tensorName)
151{
152 if (tensor.GetDataType() != dataType)
153 {
154 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
155 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
156 }
157}
158
Derek Lambertid466a542020-01-22 15:37:29 +0000159void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
160{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100161 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
Derek Lambertid466a542020-01-22 15:37:29 +0000167}
168
telsoa014fcda012018-03-09 14:13:49 +0000169//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100170void ValidateTensorQuantizationSpace(const TensorInfo& first,
171 const TensorInfo& second,
172 const std::string& descName,
173 std::string const& firstName,
174 std::string const& secondName)
175{
176 if (!first.IsQuantized() ||
177 !second.IsQuantized())
178 {
179 // Not a quantized type, ignore the validation
180 return;
181 }
182
183 DataType firstDataType = first.GetDataType();
184 DataType secondDataType = second.GetDataType();
185
186 if (firstDataType != secondDataType)
187 {
188 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
189 " must be of the same quantized type, " +
190 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
191 secondName + " is " + GetDataTypeName(secondDataType));
192 }
193
194 if (!first.IsTypeSpaceMatch(second))
195 {
196 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
197 " must have the same quantization space, " +
198 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
199 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
200 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
201 " and scale " + to_string(second.GetQuantizationScale()));
202 }
203}
204
205//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100206void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
207 const TensorInfo& inputTensorInfo,
208 const TensorInfo& weightsTensorInfo,
209 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000210{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000211 // Helper lambda function to validate a single bias quantization scale value
212 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
213 {
mathad01df9a3222021-04-28 11:42:57 +0100214 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000215 if (std::abs(biasScale - expectedScale) > tolerance)
216 {
217 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100218 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
219 " for bias quantization scale (product of input and weight scales), but got " <<
220 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000221 }
222 };
223
telsoa014fcda012018-03-09 14:13:49 +0000224 if (biasTensor.GetQuantizationOffset() != 0)
225 {
226 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
227 to_string(biasTensor.GetQuantizationOffset()));
228 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000229
James Conroy8502ade2020-11-12 19:26:29 +0000230 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000231 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232 // Validate per-axis quantization scales
233 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
234 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
235
236 if (weightScales.size() != biasScales.size())
237 {
238 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000239 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
240 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
241 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000242 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
243 }
244
245 for (size_t i = 0ul; i < biasScales.size(); ++i)
246 {
247 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
248 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
249 }
250 }
251 else
252 {
253 // Validate per-tensor quantization scale
254 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
255 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000256 }
257}
258
259//---------------------------------------------------------------
260void ValidateTensors(const std::vector<ITensorHandle*>& vec,
261 unsigned int numExpected,
262 const std::string& descName,
263 const std::string& varName)
264{
265 if (vec.empty() && numExpected > 0)
266 {
267 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
268 }
269
270 for (unsigned int i = 0; i < numExpected; ++i)
271 {
272 if (!vec[i])
273 {
274 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
275 }
276 }
277}
278
279//---------------------------------------------------------------
280void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
281 const TensorInfo& second,
282 const TensorInfo& output,
283 std::string const& descName,
284 std::string const& firstName,
285 std::string const& secondName)
286{
287 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
288 // broadcasted.
289 if (first.GetNumDimensions() != second.GetNumDimensions())
290 {
291 throw InvalidArgumentException(descName + ": Tensors "
292 + firstName + " & " + secondName
293 + " must have the same number of dimensions in order to be broadcasted");
294 }
295 uint32_t numDims = first.GetNumDimensions();
296 std::vector<uint32_t> outputDims(numDims, 0u);
297 for (uint32_t i = 0; i < numDims; i++)
298 {
299 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
300 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
301 if (dimsNotEqual && dimsNotOne)
302 {
303 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
304 }
305 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
306 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100307 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000308 if (broadcastShape != output.GetShape())
309 {
310 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
311 + firstName + " & " + secondName
312 + " does not match the output shape");
313 }
314}
315
316//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100317void ValidateDataTypes(const TensorInfo& info,
318 const std::vector<armnn::DataType>& supportedTypes,
319 std::string const& descName)
320{
321 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
322 if (iterator == supportedTypes.end())
323 {
324 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
325 }
326}
327
James Conroy4d1ff582019-06-10 17:06:39 +0100328//---------------------------------------------------------------
329void ValidateTensorDataTypesMatch(const TensorInfo& first,
330 const TensorInfo& second,
331 std::string const& descName,
332 std::string const& firstName,
333 std::string const& secondName)
334{
335 if (first.GetDataType() != second.GetDataType())
336 {
337 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
338 " must have identical data types.");
339 }
340}
341
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100342//---------------------------------------------------------------
343void ValidateTensorNumElementsMatch(const TensorInfo& first,
344 const TensorInfo& second,
345 std::string const& descName,
346 std::string const& firstName,
347 std::string const& secondName)
348{
349 if (first.GetNumElements() != second.GetNumElements())
350 {
351 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
352 " must have the same number of elements.");
353 }
354}
355
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000356void ValidateWeightDataType(const TensorInfo& inputInfo,
357 const TensorInfo& weightInfo,
358 const std::string& descName)
359{
360 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000361 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000362 {
363 const std::vector<DataType> validTypes =
364 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000365 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100366 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100367 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000368 };
369
370 ValidateDataTypes(weightInfo, validTypes, descName);
371 }
372 else
373 {
374 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
375 }
376}
377
378void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
379 const std::string& descName,
380 const std::string& tensorName)
381{
382 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
383 if (!quantizationDim.has_value())
384 {
James Ward47fce872020-09-10 11:57:28 +0100385 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
386 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000387 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000388}
389
390void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
391 const std::string& descName,
392 const std::string& tensorName)
393{
394 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
395 if (quantizationOffset != 0)
396 {
James Ward47fce872020-09-10 11:57:28 +0100397 throw InvalidArgumentException(fmt::format(
398 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
399 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000400 }
401}
402
403void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
404 const TensorInfo& outputInfo,
405 const TensorInfo& weightInfo,
406 const Optional<TensorInfo>& optionalBiasInfo,
407 const std::string& descName)
408{
409 if (weightInfo.HasPerAxisQuantization())
410 {
411 const DataType inputDataType = inputInfo.GetDataType();
412 const DataType outputDataType = outputInfo.GetDataType();
413
Keith Davis0c2eeac2020-02-11 16:51:50 +0000414 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000415
416 if (!canHavePerAxisQuantization)
417 {
James Ward47fce872020-09-10 11:57:28 +0100418 throw InvalidArgumentException(fmt::format(
419 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
420 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000421 }
422
Derek Lambertid466a542020-01-22 15:37:29 +0000423
424 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000425 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
426 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
427
428 if (optionalBiasInfo.has_value())
429 {
430 const TensorInfo& biasInfo = optionalBiasInfo.value();
431 if (!biasInfo.HasPerAxisQuantization())
432 {
James Ward47fce872020-09-10 11:57:28 +0100433 throw InvalidArgumentException(fmt::format(
434 "{}: Per-axis quantization parameters not set on bias tensor, "
435 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000436 }
437
438 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
439 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
440 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
441 }
442 }
443}
444
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100445} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000446
447void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
448 unsigned int numExpectedIn, unsigned int numExpectedOut) const
449{
450 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
451 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
452}
453
454//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100455void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
456{
457 const std::string descriptorName{"MapQueueDescriptor"};
458
459 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100460 ValidateNumOutputs(workloadInfo, descriptorName, 0);
461
462 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
463 {
464 if (!m_Inputs[i])
465 {
466 throw InvalidArgumentException(
467 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
468 }
469 }
470}
471
472//---------------------------------------------------------------
473void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
474{
475 const std::string descriptorName{"UnmapQueueDescriptor"};
476
477 ValidateNumInputs(workloadInfo, descriptorName, 1);
478 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100479
480 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
481 {
482 if (!m_Inputs[i])
483 {
484 throw InvalidArgumentException(
485 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
486 }
487 }
488}
489
490//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000491void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
492{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100493 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100495 ValidateNumInputs(workloadInfo, descriptorName, 1);
496 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
499 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
500
501 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
502 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000503
504 if (m_Inputs.size() != m_Outputs.size())
505 {
James Ward47fce872020-09-10 11:57:28 +0100506 throw InvalidArgumentException(fmt::format(
507 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
508 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000509 }
510
511 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
512 {
513 if (!m_Inputs[i])
514 {
James Ward47fce872020-09-10 11:57:28 +0100515 throw InvalidArgumentException(fmt::format(
516 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000517 }
518
519 if (!m_Outputs[i])
520 {
James Ward47fce872020-09-10 11:57:28 +0100521 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000522 }
523 }
524}
525
Derek Lambertif674aa02019-08-01 15:56:25 +0100526//---------------------------------------------------------------
527void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
528{
529 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
530 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
531
532 if (workloadInfo.m_InputTensorInfos.size() != 1)
533 {
James Ward47fce872020-09-10 11:57:28 +0100534 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
535 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100536
537 }
538
539 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
540 {
James Ward47fce872020-09-10 11:57:28 +0100541 throw InvalidArgumentException(fmt::format(
542 "Number of input infos ({0}) does not match the number of output infos ({1})",
543 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100544 }
545
546 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
547 {
548 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
549 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
550 {
James Ward47fce872020-09-10 11:57:28 +0100551 throw InvalidArgumentException(fmt::format(
552 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100553 }
554 }
555
556 if (m_Inputs.size() != 1)
557 {
James Ward47fce872020-09-10 11:57:28 +0100558 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100559 }
560
561 if (m_Inputs.size() != m_Outputs.size())
562 {
James Ward47fce872020-09-10 11:57:28 +0100563 throw InvalidArgumentException(fmt::format(
564 "Number of inputs ({0}) does not match the number of outputs ({1})",
565 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100566 }
567
568 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
569 {
570 if (!m_Inputs[i])
571 {
James Ward47fce872020-09-10 11:57:28 +0100572 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100573 }
574
575 if (!m_Outputs[i])
576 {
James Ward47fce872020-09-10 11:57:28 +0100577 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100578 }
579 }
580}
581
582//---------------------------------------------------------------
583void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
584{
585 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100586
Derek Lambertif674aa02019-08-01 15:56:25 +0100587 if (m_Inputs.size() != 1)
588 {
James Ward47fce872020-09-10 11:57:28 +0100589 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100590 }
591
592 if (m_Outputs.size() != 0)
593 {
James Ward47fce872020-09-10 11:57:28 +0100594 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100595 }
596
597 if (!m_Inputs[0])
598 {
James Ward47fce872020-09-10 11:57:28 +0100599 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100600 }
601}
602
603//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000604void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
605{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100608 ValidateNumInputs(workloadInfo, descriptorName, 1);
609 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100611 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
612 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100613
614 std::vector<DataType> supportedTypes =
615 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000616 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100617 DataType::Float16,
618 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000620 DataType::QAsymmU8,
621 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100622 };
623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100624 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
625 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
626 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000627}
628
Nikhil Rajee391d52019-09-05 17:50:44 +0100629void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
630{
631 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
632
633 ValidateNumInputs(workloadInfo, descriptorName, 1);
634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
635
636 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
638
Inki Daed4619e22020-09-10 15:33:54 +0900639 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
640 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100641 {
Inki Daed4619e22020-09-10 15:33:54 +0900642 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100643 }
644
James Conroyd47a0642019-09-17 14:22:06 +0100645 std::vector<DataType> supportedInputTypes =
646 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000647 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100648 DataType::Float16,
649 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100650 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000651 DataType::QAsymmU8,
652 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900653 DataType::Signed32,
654 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100655 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100656
James Conroyd47a0642019-09-17 14:22:06 +0100657 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100658
659 auto inputShape = inputTensorInfo.GetShape();
660 auto outputShape = outputTensorInfo.GetShape();
661
662 auto inputNumDimensions = inputShape.GetNumDimensions();
663 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
664
665 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
666
667 // 1D input shape results in scalar output shape
668 if (inputShape.GetNumDimensions() == 1)
669 {
670 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
671 {
672 throw InvalidArgumentException(descriptorName + outputShapeError);
673 }
674 }
675 else
676 {
677 for (unsigned int i = 0; i < unsignedAxis; ++i)
678 {
679 if (outputShape[i] != inputShape[i])
680 {
681 throw InvalidArgumentException(descriptorName + outputShapeError);
682 }
683 }
684
685 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
686 {
687 if (outputShape[i - 1] != inputShape[i])
688 {
689 throw InvalidArgumentException(descriptorName + outputShapeError);
690 }
691 }
692 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100693}
694
mathad01b392e982021-04-07 12:07:30 +0100695void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
696{
697 const std::string descriptorName{"CastQueueDescriptor"};
698
699 ValidateNumInputs(workloadInfo, descriptorName, 1);
700 ValidateNumOutputs(workloadInfo, descriptorName, 1);
701
702 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
703 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
704
705 std::vector<DataType> supportedTypes =
706 {
707 DataType::BFloat16,
708 DataType::Float16,
709 DataType::Float32,
710 DataType::QAsymmS8,
711 DataType::QAsymmU8,
712 DataType::QSymmS8,
713 DataType::QSymmS16,
714 DataType::Signed32,
715 DataType::Signed64
716 };
717
718 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
719 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
720}
721
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100722void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
723{
724 const std::string descriptorName{"SoftmaxQueueDescriptor"};
725
726 ValidateNumInputs(workloadInfo, descriptorName, 1);
727 ValidateNumOutputs(workloadInfo, descriptorName, 1);
728
729 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
730 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
731
732 std::vector<DataType> supportedTypes =
733 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000734 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100735 DataType::Float16,
736 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000737 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000738 DataType::QAsymmU8,
739 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100740 };
741
742 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
743 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
744 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
745}
746
telsoa014fcda012018-03-09 14:13:49 +0000747void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
748{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100749 const std::string descriptorName{"SplitterQueueDescriptor"};
750
751 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000752
Ruomei Yan25339c32019-05-28 16:48:20 +0100753 // Check the supported data types
754 std::vector<DataType> supportedTypes =
755 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000756 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100757 DataType::Float32,
758 DataType::Float16,
759 DataType::Boolean,
760 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100761 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000762 DataType::QAsymmU8,
763 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100764 };
765
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100766 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
767 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100769 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
770 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
771
772 const std::string outputName = "output_" + std::to_string(i);
773 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100774 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100775
telsoa014fcda012018-03-09 14:13:49 +0000776 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
777 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000779 }
780
781 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
782 {
783 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000785 "has to match number of workloadInfo.m_OutputTensorInfos. "
786 "Number of windows: " +
787 to_string(m_ViewOrigins.size()) +
788 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
789 }
790
telsoa01c577f2c2018-08-31 09:22:23 +0100791 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000792 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
793 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
794 {
telsoa01c577f2c2018-08-31 09:22:23 +0100795 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000796 ViewOrigin const& e = m_ViewOrigins[w];
797 if (e.m_Origin.size() != inputDims)
798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100799 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000800 "have the same dimensionality as the input tensor. "
801 "Window origin (index: " +
802 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
803 " dimensions, the input "
804 "tensor has " +
805 to_string(inputDims) + " dimensions.");
806 }
807 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
808 {
809 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
810 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
811 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100812 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000813 "be smaller or equal than the size of the input in that coord.");
814 }
815 }
816 }
817}
818
Jim Flynne242f2d2019-05-22 14:24:13 +0100819void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000820{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100821 const std::string descriptorName{"ConcatQueueDescriptor"};
822
823 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000824
825 if (m_Inputs.size() <= 0)
826 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000828 }
829 if (m_Outputs.size() <= 0)
830 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100831 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000832 }
833
834 if (workloadInfo.m_InputTensorInfos.size() <= 0)
835 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100836 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000837 }
838 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
839 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100840 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000841 }
842
Nikhil Raj8599a412018-11-19 14:51:07 +0000843 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
844 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100845 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000846 }
847
848 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
849 {
850 return;
851 }
852
telsoa014fcda012018-03-09 14:13:49 +0000853 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
854 {
855 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100856 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000857 "has to match number of workloadInfo.m_InputTensorInfos. "
858 "Number of windows: " +
859 to_string(m_ViewOrigins.size()) +
860 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
861 }
862
telsoa01c577f2c2018-08-31 09:22:23 +0100863 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000864 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
865 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
866 {
telsoa01c577f2c2018-08-31 09:22:23 +0100867 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000868 ViewOrigin const& e = m_ViewOrigins[w];
869 if (e.m_Origin.size() != outputDims)
870 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100871 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000872 "have the same dimensionality as the output tensor. "
873 "Window origin (index: " +
874 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
875 " dimensions, the output "
876 "tensor has " +
877 to_string(outputDims) + " dimensions.");
878 }
telsoa01c577f2c2018-08-31 09:22:23 +0100879 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000880 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
881 {
882 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
883 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
884 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100885 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000886 "be smaller or equal than the size of the output in that coord.");
887 }
888 }
889 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100890
891 // Check the supported data types
892 std::vector<DataType> supportedTypes =
893 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000894 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100895 DataType::Float32,
896 DataType::Float16,
897 DataType::Boolean,
898 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100899 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000900 DataType::QAsymmU8,
901 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100902 };
903
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100904 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
905 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100906 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100907 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
908 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
909
910 const std::string inputName = "input_" + std::to_string(i);
911 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100912 }
telsoa014fcda012018-03-09 14:13:49 +0000913}
914
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100915void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
916{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100917 const std::string descriptorName{"StackQueueDescriptor"};
918
919 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100920
921 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
922 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100923 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100924 }
925
926 // All inputs must have the same shape, which is defined in parameters
927 const TensorShape& inputShape = m_Parameters.m_InputShape;
928 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
929 {
930 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
931 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100932 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100933 }
934 }
935
Matthew Jacksondba634f2019-08-15 15:14:18 +0100936 if (inputShape.GetNumDimensions() > 4)
937 {
938 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
939 }
940
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100941 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
942 // since the output tensor has an additional dimension.
943 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
944 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100945 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100946 "than the number of input dimensions.");
947 }
948
949 // Output shape must be as inferred from the input shape
950 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
951 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
952 {
953 if (outputShape[i] != inputShape[i])
954 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100955 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100956 "match shape inferred from input tensor.");
957 }
958 }
959
960 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
961 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100962 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100963 "match shape inferred from input tensor.");
964 }
965
966 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
967 {
968 if (outputShape[i] != inputShape[i-1])
969 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100970 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100971 "match shape inferred from input tensor.");
972 }
973 }
974
Matthew Jacksondba634f2019-08-15 15:14:18 +0100975 if (outputShape.GetNumDimensions() > 5)
976 {
977 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
978 }
979
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100980 // Check the supported data types
981 std::vector<DataType> supportedTypes =
982 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000983 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100984 DataType::Float32,
985 DataType::Float16,
986 DataType::Boolean,
987 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100988 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000989 DataType::QAsymmU8,
990 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100991 };
992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100993 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100994
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100995 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100996 {
997 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
998 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100999 descriptorName,
1000 "input_0",
1001 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001002 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001003
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001004 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1005 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001006 descriptorName,
1007 "input_0",
1008 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001009}
1010
Ryan OSheaec6c6802020-06-05 17:17:06 +01001011void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1012{
1013 const std::string descriptorName{"FillQueueDescriptor"};
1014
1015 ValidateNumInputs(workloadInfo, descriptorName, 1);
1016 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1017
1018 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1019 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1020
1021 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1022
1023 std::vector<DataType> supportedTypes =
1024 {
1025 DataType::BFloat16,
1026 DataType::Float32,
1027 DataType::Float16,
1028 DataType::Signed32
1029 };
1030
1031 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1032}
1033
telsoa014fcda012018-03-09 14:13:49 +00001034void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1035{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001036 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001037
Matthew Sloyan81beae32021-07-13 19:46:11 +01001038 uint32_t numInputs = 2;
1039 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001040 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001041 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001042 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001043
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001044 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001045 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1046
1047 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1048 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1049
1050 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1051
1052 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001053 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001054 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001055 }
1056
Matthew Sloyan81beae32021-07-13 19:46:11 +01001057 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001058 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001059
1060 if (m_Parameters.m_BiasEnabled)
1061 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001062 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001063 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001065 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1066 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001067 }
1068
Francis Murtagh46c09d02019-05-28 08:15:28 +01001069 // Check the supported data types
1070 std::vector<DataType> supportedTypes =
1071 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001072 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001073 DataType::Float32,
1074 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001075 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001076 DataType::QAsymmU8,
1077 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001078 };
1079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001080 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001081
1082 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1083 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1084 {
1085 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1086 {
1087 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1088 "for BFloat16 input.");
1089 }
1090 }
1091 else
1092 {
1093 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1094 }
telsoa014fcda012018-03-09 14:13:49 +00001095}
1096
telsoa014fcda012018-03-09 14:13:49 +00001097void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1098{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 const std::string descriptorName{"NormalizationQueueDescriptor"};
1100
1101 ValidateNumInputs(workloadInfo, descriptorName, 1);
1102 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1103
1104 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1105 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001106
1107 // Check the supported data types
1108 std::vector<DataType> supportedTypes =
1109 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001110 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001111 DataType::Float16,
1112 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001113 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001114 DataType::QAsymmU8,
1115 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001116 };
1117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001119
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001120 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001121
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001122 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001123}
1124
1125void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1126{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001127 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001128
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001129 ValidateNumInputs(workloadInfo, descriptorName, 2);
1130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1131
1132 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1133 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1134 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1135
1136 std::vector<DataType> supportedTypes =
1137 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001138 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001139 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001140 DataType::Float16,
1141 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001142 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001143 DataType::QSymmS16,
1144 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001145 };
1146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001147 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1148 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1149 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001150
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001151 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1152 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1155 inputTensorInfo1,
1156 outputTensorInfo,
1157 descriptorName,
1158 "input_0",
1159 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001160}
1161
telsoa014fcda012018-03-09 14:13:49 +00001162void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1163{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001164 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001166 ValidateNumInputs(workloadInfo, descriptorName, 2);
1167 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1168
1169 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1170 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1171 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1172
1173 std::vector<DataType> supportedTypes =
1174 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001175 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001176 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001177 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001178 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001179 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001180 DataType::QSymmS16,
1181 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001182 };
1183
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001184 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1185 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1186 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001187
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001188 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1189 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001190
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001191 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1192 inputTensorInfo1,
1193 outputTensorInfo,
1194 descriptorName,
1195 "input_0",
1196 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001197}
1198
1199void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1200{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001201 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001202
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001203 ValidateNumInputs(workloadInfo, descriptorName, 1);
1204 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1205
1206 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1207 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001208
1209 std::vector<DataType> supportedTypes =
1210 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001211 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001212 DataType::Float16,
1213 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001214 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001215 DataType::QAsymmU8,
1216 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001217 };
1218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1220 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001223 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidatePointer(m_Mean, descriptorName, "mean");
1226 ValidatePointer(m_Variance, descriptorName, "variance");
1227 ValidatePointer(m_Beta, descriptorName, "beta");
1228 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001229
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001230 const TensorInfo& mean = m_Mean->GetTensorInfo();
1231 const TensorInfo& variance = m_Variance->GetTensorInfo();
1232 const TensorInfo& beta = m_Beta->GetTensorInfo();
1233 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001234
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001235 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1236 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1237 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1238 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1241 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1242 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001243}
1244
1245void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1246{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001248
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001249 ValidateNumInputs(workloadInfo, descriptorName, 1);
1250 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001251
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001252 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1253 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001254
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1256 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001257
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001258 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1261 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001262
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001263 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001264
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001265 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001266 if (m_Parameters.m_BiasEnabled)
1267 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001269
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001270 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1271 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272
1273 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1274 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001275 }
1276
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001277 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1278 {
1279 throw InvalidArgumentException(
1280 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1281 "cannot be either negative or 0.",
1282 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1283 }
1284
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001285 ValidatePerAxisQuantization(inputTensorInfo,
1286 outputTensorInfo,
1287 weightTensorInfo,
1288 optionalBiasTensorInfo,
1289 descriptorName);
1290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001291 std::vector<DataType> supportedTypes =
1292 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001293 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001294 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001295 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001296 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001297 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001298 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001299 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001300 };
1301
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001303
1304 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1305 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1306 {
1307 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1308 {
1309 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1310 "for BFloat16 input.");
1311 }
1312 }
1313 else
1314 {
1315 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1316 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001317}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001318
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001319void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1320{
1321 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1322
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001323 uint32_t numInputs = 2;
1324 if (m_Parameters.m_BiasEnabled)
1325 {
1326 numInputs = 3;
1327 }
1328 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001329 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1330
1331 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1332 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1333
1334 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1335 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1336
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001337 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001338 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1339
1340 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1341
1342 Optional<TensorInfo> optionalBiasTensorInfo;
1343 if (m_Parameters.m_BiasEnabled)
1344 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001345 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001346 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1347
1348 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1349 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1350 }
1351
1352 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1353 {
1354 throw InvalidArgumentException(
1355 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1356 "cannot be either negative or 0.",
1357 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1358 }
1359
1360 ValidatePerAxisQuantization(inputTensorInfo,
1361 outputTensorInfo,
1362 weightTensorInfo,
1363 optionalBiasTensorInfo,
1364 descriptorName);
1365
1366 std::vector<DataType> supportedTypes =
1367 {
1368 DataType::BFloat16,
1369 DataType::Float16,
1370 DataType::Float32,
1371 DataType::QAsymmS8,
1372 DataType::QAsymmU8,
1373 DataType::QSymmS16,
1374 DataType::QSymmS8
1375 };
1376
1377 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1378 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1379}
1380
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001381void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1382{
1383 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1384
Cathal Corbett06902652022-04-14 17:55:11 +01001385 uint32_t numInputs = 2;
1386 if (m_Parameters.m_BiasEnabled)
1387 {
1388 numInputs = 3;
1389 }
1390
1391 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001392 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1393
1394 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1395 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1396
1397 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1398 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1399
Cathal Corbett06902652022-04-14 17:55:11 +01001400 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001401 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1402
1403 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1404 {
1405 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001406 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1407 "cannot be smaller than 1.",
1408 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001409 }
1410
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001411 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1412 {
1413 throw InvalidArgumentException(
1414 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1415 "cannot be either negative or 0.",
1416 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1417 }
1418
Jan Eilers53ef7952021-06-02 12:01:25 +01001419 if (weightTensorInfo.GetShape()[0] != 1)
1420 {
1421 throw InvalidArgumentException(fmt::format(
1422 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1423 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1424 descriptorName,
1425 weightTensorInfo.GetShape()[0],
1426 weightTensorInfo.GetShape()[1],
1427 weightTensorInfo.GetShape()[2],
1428 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001429 }
1430
Cathal Corbett4b19d222022-05-11 20:12:17 +01001431 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1432 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1433 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1434 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1435
1436 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1437 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1438 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1439
1440 if (!(validRefFormat || validAclFormat))
1441 {
1442 throw InvalidArgumentException(fmt::format(
1443 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1444 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1445 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1446 descriptorName,
1447 numOutputChannels,
1448 weightTensorInfo.GetShape()[0],
1449 weightTensorInfo.GetShape()[1],
1450 weightTensorInfo.GetShape()[2],
1451 weightTensorInfo.GetShape()[3]));
1452 }
1453
Teresa Charlind8df0262019-11-11 12:28:15 +00001454 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455
Teresa Charlind8df0262019-11-11 12:28:15 +00001456 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457 if (m_Parameters.m_BiasEnabled)
1458 {
Cathal Corbett06902652022-04-14 17:55:11 +01001459 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001460 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001461
1462 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1463 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1464 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001465 ValidatePerAxisQuantization(inputTensorInfo,
1466 outputTensorInfo,
1467 weightTensorInfo,
1468 optionalBiasTensorInfo,
1469 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001470
1471 std::vector<DataType> supportedTypes =
1472 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001473 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001474 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001475 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001476 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001477 DataType::QAsymmU8,
1478 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001479 };
1480
1481 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1482 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001483}
1484
1485void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1486{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001487 const std::string descriptorName{"PermuteQueueDescriptor"};
1488
1489 ValidateNumInputs(workloadInfo, descriptorName, 1);
1490 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001491
1492 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1495 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001497 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1498 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001500 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001501 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001502 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001503 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1505 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1506 "must match dst dimension " + to_string(mapping[i]) +
1507 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001508 }
1509 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510
1511 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001512}
1513
1514void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1515{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001516 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001517
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 ValidateNumInputs(workloadInfo, descriptorName, 1);
1519 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1520
1521 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1522 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1523
1524 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1525 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001526
1527 std::vector<DataType> supportedTypes =
1528 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001529 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001530 DataType::Float32,
1531 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001532 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001533 DataType::QAsymmU8,
1534 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001535 };
1536
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001537 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1538 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001539}
1540
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001541void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1542{
1543 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1544
1545 ValidateNumInputs(workloadInfo, descriptorName, 1);
1546 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1547
1548 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1549 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1550
1551 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1552 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1553
1554 std::vector<DataType> supportedTypes =
1555 {
1556 DataType::BFloat16,
1557 DataType::Float32,
1558 DataType::Float16,
1559 DataType::QAsymmS8,
1560 DataType::QAsymmU8,
1561 DataType::QSymmS16
1562 };
1563
1564 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1565 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1566}
1567
1568
telsoa014fcda012018-03-09 14:13:49 +00001569void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1570{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001571 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001572
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001573 ValidateNumInputs(workloadInfo, descriptorName, 1);
1574 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1575
1576 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1577 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1578
1579 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1580 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001581
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001582 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001583 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001584 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001585 DataType::Float16,
1586 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001587 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001588 DataType::QAsymmU8,
1589 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001590 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001591
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001592 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1593 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001595 // ResizeBilinear only changes width and height: batch and channel count must match.
1596 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1597 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001598 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001599 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001600 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001601 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1602 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001603 }
1604
Teresa Charlin970f43b2019-07-01 13:51:07 +01001605 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001606 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1607 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001608 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001609 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001610 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001611 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1612 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001613 }
1614}
1615
1616void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1617{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001618 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001619
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001620 ValidateNumInputs(workloadInfo, descriptorName, 1);
1621 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1622
1623 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1624 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1625
1626 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1627 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001628
1629 std::vector<DataType> supportedTypes =
1630 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001631 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001632 DataType::Float16,
1633 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001634 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001635 DataType::QAsymmU8,
1636 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001637 };
1638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1640 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001641
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001642 // Resize only changes width and height: batch and channel count must match.
1643 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1644 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001645 if (inputBatchSize != outputBatchSize)
1646 {
1647 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001648 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1649 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001650 }
1651
1652 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001653 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1654 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001655 if (inputChannelCount != outputChannelCount)
1656 {
1657 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001658 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1659 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001660 }
1661}
1662
1663void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1664{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001665 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001666
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001667 ValidateNumInputs(workloadInfo, descriptorName, 1);
1668 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1669
1670 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1671 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1672
1673 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1674 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1675
1676 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1677
telsoa014fcda012018-03-09 14:13:49 +00001678 if (m_Parameters.m_Min > m_Parameters.m_Max)
1679 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001680 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001681 }
telsoa014fcda012018-03-09 14:13:49 +00001682}
1683
Kevin Mayce5045a2019-10-02 14:07:47 +01001684void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1685{
1686 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1687
1688 ValidateNumInputs(workloadInfo, descriptorName, 1);
1689 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1690
1691 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1692 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1693
1694 if (inputTensorInfo.GetNumDimensions() > 4)
1695 {
1696 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1697 }
1698
1699 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1700
1701 // Check the supported data types
1702 std::vector<DataType> supportedTypes =
1703 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001704 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001705 DataType::Float32,
1706 DataType::Float16
1707 };
1708
1709 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001710 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001711}
1712
telsoa014fcda012018-03-09 14:13:49 +00001713void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1714{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001715 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001716
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001717 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001718 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001720 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1721 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1722
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001723 if (inputTensorInfo.GetNumDimensions() > 4)
1724 {
1725 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1726 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001727
1728 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001729
1730 // Check the supported data types
1731 std::vector<DataType> supportedTypes =
1732 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001733 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001734 DataType::Float32,
1735 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001736 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001737 DataType::QAsymmU8,
1738 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001739 };
1740
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001741 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001742 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1743}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001744
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001745void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1746{
1747 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1748
1749 ValidateNumInputs(workloadInfo, descriptorName, 1);
1750 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1751
1752 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1753 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1754
1755 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1756
1757 std::vector<DataType> supportedTypes =
1758 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001759 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001760 DataType::Float32,
1761 DataType::Float16,
1762 };
1763
1764 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001765 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001766}
1767
1768void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1769{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001770 const std::string descriptorName{"ConstantQueueDescriptor"};
1771
1772 ValidateNumInputs(workloadInfo, descriptorName, 0);
1773 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001774
1775 if (!m_LayerOutput)
1776 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001777 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001778 }
1779
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001780 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1781 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001782
1783 // Check the supported data types
1784 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001785 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001786 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001787 DataType::Float32,
1788 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001789 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001790 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001791 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001792 DataType::QSymmS16,
1793 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001794 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001795
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001796 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001797}
1798
1799void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1800{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001801 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001802
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001803 ValidateNumInputs(workloadInfo, descriptorName, 1);
1804 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1805
1806 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1807 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1808
1809 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001810
1811 // Check the supported data types
1812 std::vector<DataType> supportedTypes =
1813 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001814 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001815 DataType::Float32,
1816 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001817 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001818 DataType::QAsymmU8,
1819 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001820 DataType::Signed32,
1821 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001822 };
1823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001824 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1825 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001826}
1827
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001828void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1829{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001830 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001831
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001832 ValidateNumInputs(workloadInfo, descriptorName, 1);
1833 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1834
1835 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1836 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1837
1838 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1839 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001840
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001841 if (m_Parameters.m_BlockShape.size() != 2)
1842 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001843 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001844 }
1845
1846 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1847 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001848 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1849 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001850 }
1851
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001852 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001853
1854 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001855 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001856
Matthew Bentham8800c002018-11-19 13:19:28 +00001857 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001858
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001859 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1860 widthPad.first + widthPad.second;
1861 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1862 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001863
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001864 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1865 inputShape[dimensionIndices.GetChannelsIndex()];
1866 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001867
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001868 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001869 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001870 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001871 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001872 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001873 }
1874
1875 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001876 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001877 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1878 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001879 }
nikraj01120522a2019-05-31 11:33:07 +01001880
1881 std::vector<DataType> supportedTypes =
1882 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001883 DataType::BFloat16,
1884 DataType::Float16,
1885 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001886 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001887 DataType::QAsymmU8,
1888 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001889 };
1890
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001891 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1892 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001893}
1894
Keith Davisa57eccb2019-06-14 17:33:22 +01001895void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1896{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001897 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001898
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001899 ValidateNumInputs(workloadInfo, descriptorName, 1);
1900 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001901
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001902 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1903 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1904
1905 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1906 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001907
1908 std::vector<DataType> supportedTypes =
1909 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001910 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001911 DataType::Float32,
1912 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001913 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001914 DataType::QAsymmU8,
1915 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001916 };
1917
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001918 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1919 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001920
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001921 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1922
1923 if (m_Parameters.m_BlockSize == 0)
1924 {
1925 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1926 }
1927
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001928 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1929 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1930 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1931 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001932
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001933 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001934 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001935 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001936 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1937 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001938 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001939
1940 const TensorShape& outputShape = outputTensorInfo.GetShape();
1941 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1942 {
1943 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1944 "must be divisible by the square of block size." );
1945 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001946}
1947
telsoa014fcda012018-03-09 14:13:49 +00001948void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1949{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001950 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001952 ValidateNumInputs(workloadInfo, descriptorName, 1);
1953 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1954
1955 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1956 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001957
1958 std::vector<DataType> supportedTypes =
1959 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001960 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001961 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001962 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001963 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001964 };
1965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001966 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001967 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1968 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1969 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001970}
1971
telsoa01c577f2c2018-08-31 09:22:23 +01001972void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1973{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1975
1976 const std::string descriptorName{"LstmQueueDescriptor"};
1977
1978 // check dimensions of all inputs and outputs
1979 if (workloadInfo.m_InputTensorInfos.size() != 3)
1980 {
1981 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1982 }
1983 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1984 {
1985 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1986 }
1987
1988 std::vector<DataType> supportedTypes =
1989 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001990 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001991 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001992 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001993 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001994 };
1995
Jan Eilers38e05bd2019-06-26 13:10:09 +01001996 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001997 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1998
Jan Eilers38e05bd2019-06-26 13:10:09 +01001999 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002000 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002001 {
2002 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2003 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 descriptorName,
2005 "input_0",
2006 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002007 }
2008 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002009 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002010 {
2011 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2012 workloadInfo.m_OutputTensorInfos[i],
2013 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002014 "input_0",
2015 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002016 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002017
janeil0117d8d852019-11-15 15:00:16 +00002018 // Making sure clipping parameters have valid values.
2019 // == 0 means no clipping
2020 // > 0 means clipping
2021 if (m_Parameters.m_ClippingThresCell < 0.0f)
2022 {
2023 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2024 }
2025 if (m_Parameters.m_ClippingThresProj < 0.0f)
2026 {
2027 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2028 }
2029
Jan Eilers38e05bd2019-06-26 13:10:09 +01002030 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002031 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2032 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2033 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2034 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2035 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2036 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2037
Jan Eilers38e05bd2019-06-26 13:10:09 +01002038 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002039 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2040 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002041 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002042 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2043 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002044 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002045 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2046 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002047 // scratchBufferTensor
2048 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002049 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2050 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002051 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002052 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2053 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002054 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002055 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2056 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002057 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002058 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2059 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002060
Jan Eilers38e05bd2019-06-26 13:10:09 +01002061 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2062 if ( m_InputToInputWeights )
2063 {
2064 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2065 (n_cell * n_input), "InputLayerNormWeights");
2066 }
2067
2068 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2069 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2070 (n_cell * n_input), "InputToForgetWeights");
2071
2072 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2073 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2074 (n_cell * n_input), "InputToCellWeights");
2075
2076 if ( m_RecurrentToInputWeights )
2077 {
2078 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2079 (n_cell * n_output), "RecurrentToInputWeights");
2080 }
2081
2082 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2083 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2084 (n_cell * n_output), "RecurrentToForgetWeights");
2085
2086 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2087 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2088 (n_cell * n_output), "RecurrentToCellWeights");
2089
2090 // Make sure the input-gate's parameters are either both present (regular
2091 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2092 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2093 !m_Parameters.m_CifgEnabled) ||
2094 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2095 m_Parameters.m_CifgEnabled));
2096 if (!cifg_weights_all_or_none)
2097 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002098 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2099 "RecurrentToInputWeights must either both be present (regular LSTM) "
2100 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2101 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002102 }
2103
2104 if ( m_CellToInputWeights )
2105 {
2106 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2107 n_cell, "CellToInputWeights");
2108 }
2109 if ( m_CellToForgetWeights )
2110 {
2111 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2112 n_cell, "CellToForgetWeights");
2113 }
2114 if ( m_CellToOutputWeights )
2115 {
2116 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2117 n_cell, "CellToOutputWeights");
2118 }
2119
2120 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2121 bool peephole_weights_all_or_none =
2122 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2123 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2124 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2125 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2126 if (!peephole_weights_all_or_none)
2127 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002129 }
2130
2131 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2132 if (m_Parameters.m_CifgEnabled)
2133 {
2134 if (m_InputGateBias)
2135 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002136 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002137 }
2138 }
2139 else
2140 {
2141 if (!m_InputGateBias)
2142 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002143 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2144 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002145 }
2146 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2147 n_cell, "InputGateBias");
2148 }
2149
2150 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2151 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2152
2153 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2154 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2155
2156 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2157 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2158
2159 if (m_ProjectionWeights)
2160 {
2161 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2162 (n_cell * n_output), "ProjectionWeights");
2163 }
2164 if (m_ProjectionBias)
2165 {
2166 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2167 }
2168
2169 // Making sure the projection tensors are consistent:
2170 // 1) If projection weight is not present, then projection bias should not be
2171 // present.
2172 // 2) If projection weight is present, then projection bias is optional.
2173 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2174 !m_Parameters.m_ProjectionEnabled)
2175 || (m_ProjectionWeights && !m_ProjectionBias &&
2176 m_Parameters.m_ProjectionEnabled)
2177 || (m_ProjectionWeights && m_ProjectionBias &&
2178 m_Parameters.m_ProjectionEnabled));
2179 if (!projecton_tensors_consistent)
2180 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002181 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002182 }
2183
2184 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2185 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2186 // either all have values or none of them have values. Layer normalization is used when the values of all the
2187 // layer normalization weights are present
2188 if (m_InputLayerNormWeights)
2189 {
2190 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2191 }
2192 if (m_ForgetLayerNormWeights)
2193 {
2194 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2195 }
2196 if (m_CellLayerNormWeights)
2197 {
2198 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2199 }
2200 if (m_OutputLayerNormWeights)
2201 {
2202 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2203 }
2204
Jan Eilers38e05bd2019-06-26 13:10:09 +01002205 if (m_Parameters.m_LayerNormEnabled)
2206 {
2207 if (!m_Parameters.m_CifgEnabled)
2208 {
2209 if (!m_InputLayerNormWeights)
2210 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002211 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2212 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002213 }
2214 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2215 1, n_cell, "InputLayerNormWeights");
2216 }
2217 else if (m_InputLayerNormWeights)
2218 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002219 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2220 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002221 }
2222
2223 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2224 "ForgetLayerNormWeights");
2225 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2226
2227 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2228 "OutputLayerNormWeights");
2229 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2230
2231 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2232 "CellLayerNormWeights");
2233 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2234 }
2235 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2236 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002237 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2238 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002239 }
telsoa01c577f2c2018-08-31 09:22:23 +01002240}
2241
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002242void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2243{
2244 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2245
2246 ValidateNumInputs(workloadInfo, descriptorName, 1);
2247 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2248
2249 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2250 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2251
2252 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2253 {
2254 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2255 }
2256
2257 if (outputTensorInfo.GetDataType() != DataType::Float32)
2258 {
2259 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2260 }
2261
2262 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2263}
2264
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002265void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2266{
2267 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2268
2269 ValidateNumInputs(workloadInfo, descriptorName, 1);
2270 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2271
2272 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2273 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2274
2275 if (inputTensorInfo.GetDataType() != DataType::Float32)
2276 {
2277 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2278 }
2279
2280 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2281 {
2282 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2283 }
2284
2285 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2286}
2287
telsoa01c577f2c2018-08-31 09:22:23 +01002288void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2289{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002291
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002292 ValidateNumInputs(workloadInfo, descriptorName, 1);
2293 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2294
2295 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2296 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2297
2298 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002299 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002301 }
2302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002303 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002304 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002305 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002306 }
2307
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002308 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002309}
2310
2311void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2312{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002313 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002314
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 ValidateNumInputs(workloadInfo, descriptorName, 1);
2316 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2317
2318 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2319 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2320
2321 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002322 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002323 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002324 }
2325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002326 if (outputTensorInfo.GetDataType() != DataType::Float32)
2327 {
2328 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2329 }
2330
2331 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002332}
2333
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002334void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2335{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002336 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002337
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002338 ValidateNumInputs(workloadInfo, descriptorName, 2);
2339 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2340
2341 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2342 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2343 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2344
2345 std::vector<DataType> supportedTypes =
2346 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002347 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002348 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002349 DataType::Float32,
2350 DataType::QAsymmS8,
2351 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002352 DataType::QSymmS16,
2353 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002354 };
2355
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002356 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2357 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2358 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002359
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002360 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2361 inputTensorInfo1,
2362 outputTensorInfo,
2363 descriptorName,
2364 "input_0",
2365 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002366}
2367
David Beckc2044fe2018-09-05 15:00:38 +01002368void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2369{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002370 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002371
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002372 ValidateNumInputs(workloadInfo, descriptorName, 2);
2373 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2374
2375 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2376 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2377 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2378
2379 std::vector<DataType> supportedTypes =
2380 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002381 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002382 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002383 DataType::Float32,
2384 DataType::QAsymmS8,
2385 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002386 DataType::QSymmS16,
2387 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002388 };
2389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002390 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2391 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2392 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002393
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002394 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2395 inputTensorInfo1,
2396 outputTensorInfo,
2397 descriptorName,
2398 "input_0",
2399 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002400}
2401
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002402void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2403{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002404 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002405
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002406 ValidateNumInputs(workloadInfo, descriptorName, 2);
2407 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2408
2409 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2410 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2411 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2412
2413 std::vector<DataType> supportedTypes =
2414 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002415 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002416 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002417 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002418 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002419 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002420 DataType::QSymmS16,
2421 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002422 };
2423
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002424 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2425 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2426 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002427
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002428 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2429 inputTensorInfo1,
2430 outputTensorInfo,
2431 descriptorName,
2432 "input_0",
2433 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002434}
2435
narpra01a6bf9122018-09-10 09:50:09 +01002436void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2437{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002439
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002440 ValidateNumInputs(workloadInfo, descriptorName, 1);
2441 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2442
2443 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2444 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002445
2446 std::vector<DataType> supportedTypes =
2447 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002448 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002449 DataType::Float32,
2450 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002451 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002452 DataType::QAsymmU8,
2453 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002454 };
narpra01eb061912018-09-10 17:35:27 +01002455
James Conroy4d1ff582019-06-10 17:06:39 +01002456 // First check if input tensor data type is supported, then
2457 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002458 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2459 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002460
narpra0132b90462018-09-13 11:07:48 +01002461 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002462 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002463 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002464 }
narpra0132b90462018-09-13 11:07:48 +01002465 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002466 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002468 }
2469 else
2470 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002471 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002472 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002473 ValidateTensorNumDimensions(outputTensorInfo,
2474 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002475 outputDim > 0 ? outputDim : 1,
2476 "output");
2477 }
narpra01a6bf9122018-09-10 09:50:09 +01002478}
2479
jimfly012c9322a2018-09-19 10:59:49 +01002480void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2481{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002482 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002483
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002484 ValidateNumInputs(workloadInfo, descriptorName, 1);
2485 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2486
2487 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2488 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002489
jimfly012c9322a2018-09-19 10:59:49 +01002490 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002491 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2492
jimfly012c9322a2018-09-19 10:59:49 +01002493 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002494 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2495 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2496 "as there are dimensions in the input tensor that is " +
2497 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2498 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002499 }
2500}
2501
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002502void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2503{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002504 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002505
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002506 ValidateNumInputs(workloadInfo, descriptorName, 1);
2507 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002508
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002509 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2510 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2511
Sadik Armagan2208b602019-07-31 16:36:27 +01002512 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002513 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002514 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002515 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002516 DataType::Float16,
2517 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002518 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002519 DataType::QAsymmU8,
2520 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002521 };
2522
2523 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002524
Keith Davis0c2eeac2020-02-11 16:51:50 +00002525 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002526 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002527 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002528 }
2529}
2530
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002531void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2532{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002533 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002534
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002535 ValidateNumInputs(workloadInfo, descriptorName, 1);
2536 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002537
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002538 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2539 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002540
2541 std::vector<DataType> supportedTypes =
2542 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002543 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002544 DataType::Float32,
2545 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002546 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002547 DataType::QAsymmU8,
2548 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002549 };
2550
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002551 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2552 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002553}
2554
Conor Kennedy430b5d82018-11-14 15:28:28 +00002555void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2556{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002557 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002558
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002559 ValidateNumInputs(workloadInfo, descriptorName, 1);
2560 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2561
2562 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2563 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002564
2565 std::vector<DataType> supportedTypes =
2566 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002567 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002568 DataType::Float16,
2569 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002570 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002571 DataType::QAsymmU8,
2572 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002573 };
2574
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002575 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2576 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002577
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002578 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002579
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002581 if (rank > 4)
2582 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002583 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002584 }
2585
Conor Kennedy430b5d82018-11-14 15:28:28 +00002586 // Begin, End & Stride length must be of rank(input0)
2587 if (m_Parameters.m_Begin.size() != rank)
2588 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002589 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002590 }
2591
2592 if (m_Parameters.m_End.size() != rank)
2593 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002594 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002595 }
2596
2597 if (m_Parameters.m_Stride.size() != rank)
2598 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002599 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002600 }
2601
2602 // Stride entries must be non-zero
2603 for (auto& stride : m_Parameters.m_Stride)
2604 {
2605 if (stride == 0)
2606 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002607 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002608 }
2609 }
2610}
2611
kevmay0190539692018-11-29 08:40:19 +00002612void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2613{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002614 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002615
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002616 ValidateNumInputs(workloadInfo, descriptorName, 2);
2617 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2618
2619 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2620 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2621 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2622
2623 std::vector<DataType> supportedTypes =
2624 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002625 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002626 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002627 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002628 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002629 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002630 DataType::QSymmS16,
2631 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002632 };
2633
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002634 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2635 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2636 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002637
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002638 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2639 inputTensorInfo1,
2640 outputTensorInfo,
2641 descriptorName,
2642 "input_0",
2643 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002644}
2645
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002646void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2647{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002648 const std::string descriptorName{"DebugQueueDescriptor"};
2649
2650 ValidateNumInputs(workloadInfo, descriptorName, 1);
2651 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002652}
2653
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002654void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2655{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002656 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002657
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002658 ValidateNumInputs(workloadInfo, descriptorName, 2);
2659 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002660
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002661 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2662 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2663 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2664
2665 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2666 inputTensorInfo1,
2667 outputTensorInfo,
2668 descriptorName,
2669 "input_0",
2670 "input_1");
2671
2672 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002673 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002674 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002675 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002676}
2677
FrancisMurtagh878f0232018-12-19 10:56:15 +00002678void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2679{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002680 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002681
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002682 ValidateNumInputs(workloadInfo, descriptorName, 2);
2683 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002684
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002685 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2686 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2687 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2688
2689 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2690 inputTensorInfo1,
2691 outputTensorInfo,
2692 descriptorName,
2693 "input_0",
2694 "input_1");
2695
2696 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002697 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002698 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002699 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002700}
2701
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002702void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2703{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002704 const std::string descriptorName{"RsqrtQueueDescriptor"};
2705
2706 ValidateNumInputs(workloadInfo, descriptorName, 1);
2707 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2708
2709 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2710 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2711
2712 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002713
2714 std::vector<DataType> supportedTypes =
2715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002716 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002717 DataType::Float16,
2718 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002719 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002720 DataType::QAsymmU8,
2721 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002722 };
2723
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002724 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2725 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002726}
2727
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002728void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2729{
2730 const std::string descriptorName{"GatherNdQueueDescriptor"};
2731
2732 ValidateNumInputs(workloadInfo, descriptorName, 2);
2733 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2734
2735 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2736 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2737 {
2738 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2739 }
2740
2741 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2742 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2743
2744 std::vector<DataType> supportedTypes =
2745 {
2746 DataType::BFloat16,
2747 DataType::Float16,
2748 DataType::Float32,
2749 DataType::QAsymmS8,
2750 DataType::QAsymmU8,
2751 DataType::QSymmS16,
2752 DataType::Signed32,
2753 };
2754
2755 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2756
2757 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2758
2759 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2760 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2761}
2762
narpra01b89b05f2019-01-16 09:53:09 +00002763void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2764{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002765 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002766
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002767 ValidateNumInputs(workloadInfo, descriptorName, 2);
2768 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002769
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002770 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2771 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002772 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002773 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002774 }
2775
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002776 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2777 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2778
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002779 std::vector<DataType> supportedTypes =
2780 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002781 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002782 DataType::Float16,
2783 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002784 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002785 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002786 DataType::QSymmS16,
2787 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002788 };
2789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002790 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002791
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002792 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002794 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2795 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002796}
2797
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002798void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2799{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002800 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2801
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002802 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002803
2804 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2805 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002806 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002807 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2808 }
2809
2810 if (m_Anchors == nullptr)
2811 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002812 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002813 }
2814
2815 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002816 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2817 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2818
2819 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002820 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002821 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2822 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002823
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002824 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2825 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2826 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002827
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002828 const std::vector<DataType> supportedInputTypes =
2829 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002830 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002831 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002832 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002833 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002834 DataType::QAsymmU8,
2835 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002836 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002837
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002838 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2839 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2840 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2841
2842 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2843 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2844 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2845 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2846
2847 // NOTE: Output is always Float32 regardless of input type
2848 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2849 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2850 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2851 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002852
2853 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2854 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002855 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002856 "must be positive and less than or equal to 1.");
2857 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002858
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002859 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2860 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002861 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002862 "should be equal to number of classes + 1.");
2863 }
2864}
2865
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002866void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2867{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002868 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002869
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002870 ValidateNumInputs(workloadInfo, descriptorName, 1);
2871 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2872
2873 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2874 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2875
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002876 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002877 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002878 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002879 }
2880
Sadik Armagan2208b602019-07-31 16:36:27 +01002881 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002882 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002883 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002884 DataType::Float32,
2885 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002886 };
2887
2888 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002889}
2890
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002891void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2892{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002893 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002894
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002895 ValidateNumInputs(workloadInfo, descriptorName, 2);
2896 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002898 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2899 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2900 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002901
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002902 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2903 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2904
2905 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2906 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002907}
2908
Keith Davis3ae3f972021-05-21 16:33:48 +01002909void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2910{
2911 const std::string& descriptorName{"ShapeQueueDescriptor"};
2912
2913 ValidateNumInputs(workloadInfo, descriptorName, 1);
2914 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2915
2916 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2917 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2918
2919 std::vector<DataType> supportedTypes =
2920 {
2921 DataType::BFloat16,
2922 DataType::Float16,
2923 DataType::Float32,
2924 DataType::QAsymmS8,
2925 DataType::QAsymmU8,
2926 DataType::QAsymmS8,
2927 DataType::QSymmS8,
2928 DataType::QSymmS16,
2929 DataType::Signed32
2930 };
2931
2932 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2933 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2934}
2935
Sadik Armaganeff363d2019-04-05 15:25:46 +01002936void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2937{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002938 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002940 ValidateNumInputs(workloadInfo, descriptorName, 2);
2941 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2942
2943 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2944 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2945
2946 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2947 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2948
2949 std::vector<DataType> supportedTypes =
2950 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002951 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002952 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002953 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002954 DataType::QAsymmU8,
2955 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002956 };
2957
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002958 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2959 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002960
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002961 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2962 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002963
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002964 ValidateTensorShapesMatch(inputTensorInfo0,
2965 outputTensorInfo0,
2966 descriptorName,
2967 "input_0",
2968 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002970 ValidateTensorShapesMatch(inputTensorInfo0,
2971 outputTensorInfo1,
2972 descriptorName,
2973 "input_0",
2974 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002975}
2976
Derek Lamberti901ea112019-12-10 22:07:09 +00002977void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002978{
2979 // This is internally generated so it should not need validation.
2980}
2981
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002982void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2983{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002984 const std::string& descriptorName{"PreluQueueDescriptor"};
2985
2986 ValidateNumInputs(workloadInfo, descriptorName, 2);
2987 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2988
2989 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2990 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2991 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002992
2993 std::vector<DataType> supportedTypes
2994 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002995 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002996 DataType::Float16,
2997 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002998 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002999 DataType::QAsymmU8,
3000 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003001 };
3002
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003003 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3004 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003005
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003006 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003007
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003008 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
3009 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003011 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3012 alphaTensorInfo,
3013 outputTensorInfo,
3014 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003015 "input",
3016 "alpha");
3017}
3018
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003019void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3020{
3021 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3022
3023 ValidateNumInputs(workloadInfo, descriptorName, 1);
3024 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3025
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003026 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3027 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3028
3029 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3030 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003031
3032 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003033
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003034 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3035 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003036
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003037 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3038
3039 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003040 if (m_Parameters.m_BiasEnabled)
3041 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003042 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003043
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003044 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3045 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003046
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003047 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003048 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003049 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003050
3051 ValidatePerAxisQuantization(inputTensorInfo,
3052 outputTensorInfo,
3053 weightTensorInfo,
3054 optionalBiasTensorInfo,
3055 descriptorName);
3056
3057 std::vector<DataType> supportedTypes =
3058 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003059 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003060 DataType::Float32,
3061 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003062 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003063 DataType::QAsymmU8,
3064 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003065 };
3066
3067 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3068 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003069}
3070
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003071void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3072{
3073 const std::string descriptorName{"TransposeQueueDescriptor"};
3074
3075 ValidateNumInputs(workloadInfo, descriptorName, 1);
3076 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3077
3078 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3079
3080 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3081 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3082
3083 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3084 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3085
3086 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3087 {
3088 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3089 {
3090 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3091 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3092 "must match dst dimension " + to_string(i) +
3093 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3094 }
3095 }
3096
3097 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3098}
3099
Simon Obute51f67772021-09-03 15:50:13 +01003100void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3101{
3102 const std::string descriptorName{"TransposeQueueDescriptor"};
3103
3104 ValidateNumInputs(workloadInfo, descriptorName, 1);
3105 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3106
3107 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3108 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3109
3110 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3111}
3112
James Conroy4f1f8992020-04-29 20:01:10 +01003113void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3114{
3115 const std::string descriptorName{"QLstmQueueDescriptor"};
3116
3117 // Validate number of inputs/outputs
3118 ValidateNumInputs(workloadInfo, descriptorName, 3);
3119 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3120
3121 // Input/output tensor info
3122 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3123 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3124 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3125
3126 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3127 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3128 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3129
3130 // Supported types for various tensors in QLSTM
3131 std::vector<DataType> inputOutputSupportedTypes =
3132 {
3133 DataType::QAsymmS8
3134 };
3135
3136 std::vector<DataType> cellStateSupportedTypes =
3137 {
3138 DataType::QSymmS16
3139 };
3140
3141 std::vector<DataType> weightsSupportedTypes =
3142 {
3143 DataType::QSymmS8
3144 };
3145
3146 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3147 {
3148 DataType::QSymmS16
3149 };
3150
3151 std::vector<DataType> biasSupportedTypes =
3152 {
3153 DataType::Signed32
3154 };
3155
3156 // Validate types of input/output tensors
3157 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3158 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3159 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3160
3161 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3162 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3163 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3164
3165 // Validate matching types of input/output tensors
3166 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3167 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3168 "outputStateIn", "outputStateOut");
3169 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3170
3171 // Infer number of batches, number of units, input size and output size from tensor dimensions
3172 const uint32_t numBatches = inputInfo.GetShape()[0];
3173 const uint32_t inputSize = inputInfo.GetShape()[1];
3174 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3175 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3176
3177 // Validate number of dimensions and number of elements for input/output tensors
3178 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3179 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3180 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3181
3182 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3183 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3184 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3185
3186 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3187 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3188 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3189 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3190
3191 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3192 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3193 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3194
3195 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3196 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3197 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3198
3199 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3200 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3201 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3202 " RecurrentToForgetWeights");
3203
3204 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3205 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3206 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3207
3208 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3209 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3210 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3211
3212 // Validate data types for MANDATORY weights tensors (all should match each other)
3213 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3214
3215 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3216 "inputToForgetWeights", "inputToCellWeights");
3217 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3218 "inputToForgetWeights", "inputToOutputWeights");
3219
3220 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3221 "inputToForgetWeights", "recurrentToForgeteights");
3222 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3223 "inputToForgetWeights", "recurrentToCellWeights");
3224 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3225 "inputToForgetWeights", "recurrentToOutputWeights");
3226
3227 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3228 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3229 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3230 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3231
3232 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3233 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3234 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3235
3236 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3237 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3238 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3239
3240 // Validate data types for MANDATORY bias tensors
3241 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3242
3243 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3244 "forgetGateBias", "cellBias");
3245 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3246 "forgetGateBias", "outputGateBias");
3247
3248 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3249 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3250 !m_Parameters.m_CifgEnabled) ||
3251 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3252 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3253
3254 if (!allCifgParamsPresentOrNot)
3255 {
3256 throw InvalidArgumentException(descriptorName +
3257 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3258 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3259 "set appropriately.");
3260 }
3261
3262 if (!m_Parameters.m_CifgEnabled)
3263 {
3264 // Validate number of dimensions and number of elements
3265 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3266 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3267
3268 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3269 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3270 " RecurrentToInputWeights");
3271
3272 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3273 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3274
3275 // Validate data types
3276 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3277 "inputToForgetWeights", "inputToInputWeights");
3278 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3279 "inputToForgetWeights", "recurrentToInputWeights");
3280 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3281 "forgetGateBias", "inputGateBias");
3282 }
3283
3284 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3285 bool allPeepholeWeightsPresentOrNot =
3286 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3287 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3288 || (!m_CellToInputWeights && !m_CellToForgetWeights
3289 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3290
3291 if (!allPeepholeWeightsPresentOrNot)
3292 {
3293 throw InvalidArgumentException(descriptorName +
3294 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3295 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3296 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3297 "appropriately.");
3298 }
3299
3300 if (m_Parameters.m_PeepholeEnabled)
3301 {
3302 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3303 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3304 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3305
3306 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3307 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3308 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3309 "cellToForgetWeight", "cellToOutputWeights");
3310
3311 if (!m_Parameters.m_CifgEnabled)
3312 {
3313 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3314 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3315 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3316 "cellToForgetWeights", "cellToInputWeights");
3317 }
3318 }
3319
3320 // Validate OPTIONAL params: Layer Norm Weights
3321 bool allLayerNormWeightsPresentOrNot =
3322 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3323 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3324 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3325 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3326
3327 if (!allLayerNormWeightsPresentOrNot)
3328 {
3329 throw InvalidArgumentException(descriptorName +
3330 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3331 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3332 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3333 "only be present when Layer Norm is enabled and CIFG is disabled. "
3334 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3335 }
3336
3337 if (m_Parameters.m_LayerNormEnabled)
3338 {
3339 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3340 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3341 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3342
3343 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3344 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3345 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3346 "forgetLayerNormWeights", "cellLayerNormWeights");
3347
3348 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3349 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3350 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3351 "forgetLayerNormWeights", "outputLayerNormWeights");
3352
3353 if (!m_Parameters.m_CifgEnabled)
3354 {
3355 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3356 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3357 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3358 "forgetLayerNormWeights", "inputLayerNormWeights");
3359 }
3360 }
3361
3362 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3363 bool correctProjectionTensorsPresent =
3364 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3365 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3366 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3367
3368 if (!correctProjectionTensorsPresent)
3369 {
3370 throw InvalidArgumentException(descriptorName +
3371 ": If projection is enabled, ProjectionWeights should be present and "
3372 "ProjectionBias is optional. If projection is disabled, neither "
3373 "ProjectionWeights nor ProjectionBias should be present.");
3374 }
3375
3376 if (m_Parameters.m_ProjectionEnabled)
3377 {
3378 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3379 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3380 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3381
3382 if (m_ProjectionBias)
3383 {
3384 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003385 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003386 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3387 }
3388
3389 }
3390 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3391 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3392 throw InvalidArgumentException(descriptorName +
3393 ": If projection is disabled, output quantization info (scale, offset) "
3394 "should match HiddenStateScale and HiddenStateZeroPoint.");
3395 }
3396
3397}
3398
James Conroy9c3cae82019-08-01 16:01:48 +01003399void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3400{
3401 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3402
3403 // Validate number of inputs/outputs
3404 ValidateNumInputs(workloadInfo, descriptorName, 3);
3405 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3406
3407 // Input/output tensor infos
3408 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3409 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3410 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3411
3412 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3413 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3414
3415 std::vector<DataType> inputOutputSupportedTypes =
3416 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003417 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003418 };
3419
3420 std::vector<DataType> cellStateSupportedTypes =
3421 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003422 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003423 };
3424
3425 std::vector<DataType> weightsSupportedTypes =
3426 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003427 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003428 };
3429
3430 std::vector<DataType> biasSupportedTypes =
3431 {
3432 DataType::Signed32
3433 };
3434
3435 // Validate types of input/output tensors
3436 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3437 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3438 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3439
3440 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3441 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3442
3443 // Validate matching types of input/output tensors
3444 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3445 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3446 "outputStateIn", "outputStateOut");
3447 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3448
3449 // Validate matching quantization info for input/output tensors
3450 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3451 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3452 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003453
James Conroy9c3cae82019-08-01 16:01:48 +01003454 // Infer number of batches, input size and output size from tensor dimensions
3455 const uint32_t numBatches = inputInfo.GetShape()[0];
3456 const uint32_t inputSize = inputInfo.GetShape()[1];
3457 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3458
3459 // Validate number of dimensions and number of elements for input/output tensors
3460 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3461 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3462 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3463 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3464 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3465
3466 // Validate number of dimensions and number of elements for weights tensors
3467 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3468 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3469 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3470
3471 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3472 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3473 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3474
3475 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3476 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3477 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3478
3479 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3480 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3481 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3482
3483 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3484 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3485 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3486
3487 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3488 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3489 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3490 " RecurrentToForgetWeights");
3491
3492 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3493 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3494 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3495
3496 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3497 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3498 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3499
3500 // Validate data types for weights tensors (all should match each other)
3501 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3502
3503 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3504 "inputToInputWeights", "inputToForgetWeights");
3505 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3506 "inputToInputWeights", "inputToCellWeights");
3507 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3508 "inputToInputWeights", "inputToOutputWeights");
3509
3510 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3511 "inputToInputWeights", "recurrentToInputWeights");
3512 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3513 "inputToInputWeights", "recurrentToForgeteights");
3514 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3515 "inputToInputWeights", "recurrentToCellWeights");
3516 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3517 "inputToInputWeights", "recurrentToOutputWeights");
3518
3519 // Validate matching quantization info for weight tensors (all should match each other)
3520 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3521 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3522 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3523 descriptorName, "inputToInputWeights", "inputToCellWeights");
3524 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3525 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3526
3527 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3528 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3529 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3530 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3531 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3532 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3533 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3534 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3535
3536 // Validate number of dimensions and number of elements in bias tensors
3537 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3538 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3539 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3540
3541 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3542 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3543 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3544
3545 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3546 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3547 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3548
3549 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3550 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3551 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3552
3553 // Validate data types for bias tensors (all should match each other)
3554 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3555
3556 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3557 "inputGateBias", "forgetGateBias");
3558 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3559 "inputGateBias", "cellBias");
3560 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3561 "inputGateBias", "outputGateBias");
3562
3563 // Validate bias tensor quantization info
3564 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3565 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3566 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3567 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3568}
3569
Kevin May868eb142019-09-04 17:29:31 +01003570void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3571{
3572 const std::string descriptorName{"AbsQueueDescriptor"};
3573
3574 ValidateNumInputs(workloadInfo, descriptorName, 1);
3575 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3576
3577 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3578 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3579
3580 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3581
3582 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003583 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003584 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003585 DataType::Float16,
3586 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003587 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003588 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003589 DataType::QSymmS16,
3590 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003591 };
Kevin May868eb142019-09-04 17:29:31 +01003592
3593 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3594 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3595}
3596
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003597void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3598{
3599 const std::string descriptorName{"SliceQueueDescriptor"};
3600
3601 ValidateNumInputs(workloadInfo, descriptorName, 1);
3602 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3603
3604 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3605 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3606
3607 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3608
3609 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3610 if (rank > 4)
3611 {
3612 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3613 }
3614
3615 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3616
3617 // Check if m_Begin and m_Size have the expected length
3618 if (m_Parameters.m_Begin.size() != rank)
3619 {
3620 throw InvalidArgumentException(descriptorName +
3621 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3622 }
3623 if (m_Parameters.m_Size.size() != rank)
3624 {
3625 throw InvalidArgumentException(descriptorName +
3626 ": Length of size descriptor must equal rank " + std::to_string(rank));
3627 }
3628
3629 // Check if the shape of the output tensor matches m_Size
3630 const TensorShape& outputShape = outputTensorInfo.GetShape();
3631 for (unsigned int i = 0u; i < rank; ++i)
3632 {
3633 if (m_Parameters.m_Size[i] != outputShape[i])
3634 {
3635 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3636 }
3637 }
3638
3639 // Check if the sum of begin offset and size in a given dimension
3640 // does not exceed the size of corresponding input
3641 const TensorShape& inputShape = inputTensorInfo.GetShape();
3642 for(unsigned int i = 0u; i < rank; ++i)
3643 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003644 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003645 {
3646 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3647 std::to_string(i) + " exceeds input size.");
3648 }
3649 }
3650}
3651
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003652void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3653{
3654 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3655
3656 ValidateNumInputs(workloadInfo, descriptorName, 1);
3657 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3658
3659 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3660 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3661
3662 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3663 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3664
3665 std::vector<DataType> supportedTypes =
3666 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003667 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003668 DataType::Float32,
3669 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003670 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003671 DataType::QAsymmU8,
3672 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003673 };
3674
3675 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3676 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3677
3678 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3679
3680 if (m_Parameters.m_BlockSize == 0)
3681 {
3682 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3683 }
3684
3685 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3686 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3687 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3688 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3689
3690 const TensorShape& outputShape = outputInfo.GetShape();
3691 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3692 {
3693 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3694 "must be divisible by block size.");
3695 }
3696
3697 const TensorShape& inputShape = inputInfo.GetShape();
3698 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3699 {
3700 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3701 "must be divisible by the square of block size." );
3702 }
3703}
3704
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003705void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3706{
3707 const std::string descriptorName{"ComparisonQueueDescriptor"};
3708
3709 ValidateNumInputs(workloadInfo, descriptorName, 2);
3710 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3711
3712 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3713 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3714 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3715
3716 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3717 inputTensorInfo1,
3718 outputTensorInfo,
3719 descriptorName,
3720 "input_0",
3721 "input_1");
3722
3723 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3724 {
3725 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3726 }
3727}
3728
josh minor4a3c6102020-01-06 16:40:46 -06003729void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3730{
3731 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3732
3733 ValidateNumInputs(workloadInfo, descriptorName, 1);
3734 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3735
3736 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3737 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3738
3739 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3740
3741 std::vector<DataType> supportedTypes =
3742 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003743 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003744 DataType::Float16,
3745 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003746 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003747 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003748 DataType::QSymmS16,
3749 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003750 };
3751
James Conroyaba90cd2020-11-06 16:28:18 +00003752 std::vector<DataType> logicalSupportedTypes =
3753 {
3754 DataType::Boolean
3755 };
3756
3757 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3758 {
3759 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3760 }
3761 else
3762 {
3763 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3764 }
3765
3766
josh minor4a3c6102020-01-06 16:40:46 -06003767 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3768}
3769
Finn Williams2605b232020-06-10 15:53:46 +01003770void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3771{
3772 const std::string descriptorName{"RankQueueDescriptor"};
3773
3774 ValidateNumInputs(workloadInfo, descriptorName, 1);
3775 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3776
3777 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3778 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3779
3780 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3781 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3782
3783 std::vector<DataType> supportedTypes =
3784 {
3785 DataType::BFloat16,
3786 DataType::Float16,
3787 DataType::Float32,
3788 DataType::QAsymmS8,
3789 DataType::QAsymmU8,
3790 DataType::QSymmS8,
3791 DataType::QSymmS16,
3792 DataType::Signed32
3793 };
3794
3795 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3796 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3797}
3798
James Conroyaba90cd2020-11-06 16:28:18 +00003799void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3800{
3801 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3802
3803 ValidateNumInputs(workloadInfo, descriptorName, 2);
3804 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3805
3806 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3807 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3808 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3809
3810 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3811 inputTensorInfo1,
3812 outputTensorInfo,
3813 descriptorName,
3814 "input_0",
3815 "input_1");
3816
3817 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3818 {
3819 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3820 }
3821
3822 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3823 {
3824 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3825 }
3826
3827 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3828 {
3829 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3830 }
3831}
3832
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003833void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3834{
3835 const std::string descriptorName{"ReduceQueueDescriptor"};
3836
3837 ValidateNumInputs(workloadInfo, descriptorName, 1);
3838 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3839
3840 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3841 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3842
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003843 std::vector<DataType> supportedTypes =
3844 {
3845 DataType::BFloat16,
3846 DataType::Float16,
3847 DataType::Float32,
3848 DataType::QAsymmS8,
3849 DataType::QAsymmU8,
3850 DataType::QSymmS16,
3851 DataType::Signed32
3852 };
3853
3854 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3855 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3856}
3857
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003858void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3859{
3860 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3861
3862 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3863
3864 // check dimensions of all inputs and outputs
3865 if (workloadInfo.m_InputTensorInfos.size() != 3)
3866 {
3867 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3868 }
Mike Kelly12994962022-04-21 11:57:09 +01003869 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003870 {
3871 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3872 }
3873
3874 std::vector<DataType> supportedTypes =
3875 {
Mike Kelly12994962022-04-21 11:57:09 +01003876 DataType::Float32,
3877 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003878 };
3879
3880 // check for supported type of one input and match them with all the other input and output
3881 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3882
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003883 // Making sure clipping parameters have valid values.
3884 // == 0 means no clipping
3885 // > 0 means clipping
3886 if (m_Parameters.m_ClippingThresCell < 0.0f)
3887 {
3888 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3889 }
3890 if (m_Parameters.m_ClippingThresProj < 0.0f)
3891 {
3892 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3893 }
3894
3895 unsigned int batchIndx = 0;
3896 unsigned int inputIndx = 1;
3897 uint32_t timeStep = 1;
3898 unsigned int timeIndx = 1;
3899 inputIndx = 2;
3900 if (m_Parameters.m_TimeMajor)
3901 {
3902 batchIndx = 1;
3903 timeIndx = 0;
3904
3905 }
3906 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3907
3908 // Inferring batch size, number of outputs and number of cells from the inputs.
3909 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3910 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3911 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3912 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3913 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3914 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3915
3916 // input tensor
3917 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3918 descriptorName + " input_0");
3919 // outputStateInTensor
3920 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3921 descriptorName + " input_1");
3922 // outputStateInTensor
3923 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3924 descriptorName + " input_2");
3925
3926 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003927 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003928 descriptorName + " output_0");
3929
3930 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3931 if ( m_InputToInputWeights )
3932 {
3933 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3934 (n_cell * n_input), "InputLayerNormWeights");
3935 }
3936
3937 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3938 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3939 (n_cell * n_input), "InputToForgetWeights");
3940
3941 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3942 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3943 (n_cell * n_input), "InputToCellWeights");
3944
3945 if ( m_RecurrentToInputWeights )
3946 {
3947 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3948 (n_cell * n_output), "RecurrentToInputWeights");
3949 }
3950
3951 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3952 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3953 (n_cell * n_output), "RecurrentToForgetWeights");
3954
3955 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3956 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3957 (n_cell * n_output), "RecurrentToCellWeights");
3958
3959 // Make sure the input-gate's parameters are either both present (regular
3960 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3961 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3962 !m_Parameters.m_CifgEnabled) ||
3963 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3964 m_Parameters.m_CifgEnabled));
3965 if (!cifg_weights_all_or_none)
3966 {
3967 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3968 "RecurrentToInputWeights must either both be present (regular LSTM) "
3969 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3970 "accordingly.");
3971 }
3972
3973 if ( m_CellToInputWeights )
3974 {
3975 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3976 n_cell, "CellToInputWeights");
3977 }
3978 if ( m_CellToForgetWeights )
3979 {
3980 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3981 n_cell, "CellToForgetWeights");
3982 }
3983 if ( m_CellToOutputWeights )
3984 {
3985 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3986 n_cell, "CellToOutputWeights");
3987 }
3988
3989 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3990 bool peephole_weights_all_or_none =
3991 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3992 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3993 || ( !m_CellToInputWeights && !m_CellToForgetWeights
3994 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3995 if (!peephole_weights_all_or_none)
3996 {
3997 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3998 }
3999
4000 // Make sure the input gate bias is present only when not a CIFG-LSTM.
4001 if (m_Parameters.m_CifgEnabled)
4002 {
4003 if (m_InputGateBias)
4004 {
4005 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
4006 }
4007 }
4008 else
4009 {
4010 if (!m_InputGateBias)
4011 {
4012 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4013 "must be present.");
4014 }
4015 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4016 n_cell, "InputGateBias");
4017 }
4018
4019 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4020 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4021
4022 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4023 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4024
4025 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4026 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4027
4028 if (m_ProjectionWeights)
4029 {
4030 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4031 (n_cell * n_output), "ProjectionWeights");
4032 }
4033 if (m_ProjectionBias)
4034 {
4035 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4036 }
4037
4038 // Making sure the projection tensors are consistent:
4039 // 1) If projection weight is not present, then projection bias should not be
4040 // present.
4041 // 2) If projection weight is present, then projection bias is optional.
4042 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4043 !m_Parameters.m_ProjectionEnabled)
4044 || (m_ProjectionWeights && !m_ProjectionBias &&
4045 m_Parameters.m_ProjectionEnabled)
4046 || (m_ProjectionWeights && m_ProjectionBias &&
4047 m_Parameters.m_ProjectionEnabled));
4048 if (!projecton_tensors_consistent)
4049 {
4050 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4051 }
4052
4053 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4054 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4055 // either all have values or none of them have values. Layer normalization is used when the values of all the
4056 // layer normalization weights are present
4057 if (m_InputLayerNormWeights)
4058 {
4059 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4060 }
4061 if (m_ForgetLayerNormWeights)
4062 {
4063 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4064 }
4065 if (m_CellLayerNormWeights)
4066 {
4067 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4068 }
4069 if (m_OutputLayerNormWeights)
4070 {
4071 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4072 }
4073
4074 if (m_Parameters.m_LayerNormEnabled)
4075 {
4076 if (!m_Parameters.m_CifgEnabled)
4077 {
4078 if (!m_InputLayerNormWeights)
4079 {
4080 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4081 "disabled but InputLayerNormWeights are not present");
4082 }
4083 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4084 1, n_cell, "InputLayerNormWeights");
4085 }
4086 else if (m_InputLayerNormWeights)
4087 {
4088 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4089 "enabled");
4090 }
4091
4092 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4093 "ForgetLayerNormWeights");
4094 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4095
4096 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4097 "OutputLayerNormWeights");
4098 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4099
4100 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4101 "CellLayerNormWeights");
4102 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4103 }
4104 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4105 {
4106 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4107 "normalisation weights are present.");
4108 }
4109}
4110
4111
mathad01df9a3222021-04-28 11:42:57 +01004112} // namespace armnn