blob: 7055092be2d1f432eb9c5506cc74bb309410aa08 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010011#include <armnnUtils/Permute.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010013#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000014
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000017#include <string>
18#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000019
James Ward47fce872020-09-10 11:57:28 +010020#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000021
Matteo Martincigh21350152018-11-28 16:22:22 +000022using namespace armnnUtils;
23
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
27//---------------------------------------------------------------
28DataType GetBiasDataType(DataType inputDataType)
29{
30 switch (inputDataType)
31 {
telsoa01c577f2c2018-08-31 09:22:23 +010032 case DataType::Float16:
33 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000034 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000035 case DataType::Float32:
36 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000037 case DataType::QAsymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000038 case DataType::QAsymmU8:
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Colm Donelanb4ef1632024-02-01 15:00:43 +000043 throw InvalidArgumentException("GetBiasDataType(): Unsupported data type.");
telsoa014fcda012018-03-09 14:13:49 +000044 }
45}
46
47namespace
48{
49
50//---------------------------------------------------------------
// Local stand-in for std::to_string, which the Android NDK does not provide.
// Formats any streamable value through a std::ostringstream.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatter;
    formatter << value;
    return formatter.str();
}
59
60//---------------------------------------------------------------
61void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
62{
63 if (!ptr)
64 {
65 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
66 paramName + " parameter must be set.");
67 }
68}
69
70//---------------------------------------------------------------
71void ValidateTensorShapesMatch(const TensorInfo& first,
72 const TensorInfo& second,
73 std::string const& descName,
74 std::string const& firstName,
75 std::string const& secondName)
76{
77 if (first.GetShape() != second.GetShape())
78 {
79 throw InvalidArgumentException(descName + ": "
80 + firstName + " & " + secondName + " must have identical shapes");
81 }
82}
83
84//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010085void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000086{
Sadik Armaganeff363d2019-04-05 15:25:46 +010087 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088 {
89 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010090 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000091 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
92 }
93}
94
95//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010096void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000097{
Sadik Armaganeff363d2019-04-05 15:25:46 +010098 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099 {
100 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100101 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000102 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
103 }
104}
105
106//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumElements(const TensorInfo& tensor,
110 std::string const& descName,
111 unsigned int numElements,
112 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100113{
114 if (tensor.GetNumElements() != numElements)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100117 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000123void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
124 const std::string& descName, std::string const& tensorName)
125{
126 if (tensor.GetDataType() != dataType)
127 {
128 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
129 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
130 }
131}
132
Derek Lambertid466a542020-01-22 15:37:29 +0000133void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
134{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100135 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000136 {
137 throw InvalidArgumentException(descName +
138 ": Expected data type which supports per-axis quantization scheme but got " +
139 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
140 }
Derek Lambertid466a542020-01-22 15:37:29 +0000141}
142
telsoa014fcda012018-03-09 14:13:49 +0000143//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100144void ValidateTensorQuantizationSpace(const TensorInfo& first,
145 const TensorInfo& second,
146 const std::string& descName,
147 std::string const& firstName,
148 std::string const& secondName)
149{
150 if (!first.IsQuantized() ||
151 !second.IsQuantized())
152 {
153 // Not a quantized type, ignore the validation
154 return;
155 }
156
157 DataType firstDataType = first.GetDataType();
158 DataType secondDataType = second.GetDataType();
159
160 if (firstDataType != secondDataType)
161 {
162 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
163 " must be of the same quantized type, " +
164 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
165 secondName + " is " + GetDataTypeName(secondDataType));
166 }
167
168 if (!first.IsTypeSpaceMatch(second))
169 {
170 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
171 " must have the same quantization space, " +
172 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
173 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
174 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
175 " and scale " + to_string(second.GetQuantizationScale()));
176 }
177}
178
179//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100180void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100181 const TensorInfo& weightsTensorInfo,
182 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000183{
184 if (biasTensor.GetQuantizationOffset() != 0)
185 {
186 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
187 to_string(biasTensor.GetQuantizationOffset()));
188 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000189
James Conroy8502ade2020-11-12 19:26:29 +0000190 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000191 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000192 // Validate per-axis quantization scales
193 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
194 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
195
196 if (weightScales.size() != biasScales.size())
197 {
198 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000199 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
200 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
201 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000202 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
203 }
telsoa014fcda012018-03-09 14:13:49 +0000204 }
205}
206
207//---------------------------------------------------------------
208void ValidateTensors(const std::vector<ITensorHandle*>& vec,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100209 unsigned int numExpected,
210 const std::string& descName,
211 const std::string& varName)
telsoa014fcda012018-03-09 14:13:49 +0000212{
213 if (vec.empty() && numExpected > 0)
214 {
215 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
216 }
217
218 for (unsigned int i = 0; i < numExpected; ++i)
219 {
220 if (!vec[i])
221 {
222 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
223 }
224 }
225}
226
227//---------------------------------------------------------------
228void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
229 const TensorInfo& second,
230 const TensorInfo& output,
231 std::string const& descName,
232 std::string const& firstName,
233 std::string const& secondName)
234{
235 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
236 // broadcasted.
Colm Donelan02300aa2024-04-04 11:20:29 +0100237 // NOTE: This check is dependent on the AddBroadcastReshapeLayerImpl optimization having been applied to the layer.
telsoa014fcda012018-03-09 14:13:49 +0000238 if (first.GetNumDimensions() != second.GetNumDimensions())
239 {
240 throw InvalidArgumentException(descName + ": Tensors "
241 + firstName + " & " + secondName
242 + " must have the same number of dimensions in order to be broadcasted");
243 }
244 uint32_t numDims = first.GetNumDimensions();
245 std::vector<uint32_t> outputDims(numDims, 0u);
246 for (uint32_t i = 0; i < numDims; i++)
247 {
248 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
249 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
250 if (dimsNotEqual && dimsNotOne)
251 {
252 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
253 }
254 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
255 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100256 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000257 if (broadcastShape != output.GetShape())
258 {
259 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
260 + firstName + " & " + secondName
261 + " does not match the output shape");
262 }
263}
264
265//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100266void ValidateDataTypes(const TensorInfo& info,
267 const std::vector<armnn::DataType>& supportedTypes,
268 std::string const& descName)
269{
270 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
271 if (iterator == supportedTypes.end())
272 {
Colm Donelan02300aa2024-04-04 11:20:29 +0100273 throw InvalidArgumentException(descName + ": " + " Tensor type " + GetDataTypeName(info.GetDataType()) +
274 " is not supported.");
Sadik Armaganeff363d2019-04-05 15:25:46 +0100275 }
276}
277
James Conroy4d1ff582019-06-10 17:06:39 +0100278//---------------------------------------------------------------
279void ValidateTensorDataTypesMatch(const TensorInfo& first,
280 const TensorInfo& second,
281 std::string const& descName,
282 std::string const& firstName,
283 std::string const& secondName)
284{
285 if (first.GetDataType() != second.GetDataType())
286 {
287 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
288 " must have identical data types.");
289 }
290}
291
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100292//---------------------------------------------------------------
293void ValidateTensorNumElementsMatch(const TensorInfo& first,
294 const TensorInfo& second,
295 std::string const& descName,
296 std::string const& firstName,
297 std::string const& secondName)
298{
299 if (first.GetNumElements() != second.GetNumElements())
300 {
301 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
302 " must have the same number of elements.");
303 }
304}
305
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000306void ValidateWeightDataType(const TensorInfo& inputInfo,
307 const TensorInfo& weightInfo,
308 const std::string& descName)
309{
310 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000311 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000312 {
313 const std::vector<DataType> validTypes =
314 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000315 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100316 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100317 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000318 };
319
320 ValidateDataTypes(weightInfo, validTypes, descName);
321 }
322 else
323 {
324 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
325 }
326}
327
328void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
329 const std::string& descName,
330 const std::string& tensorName)
331{
332 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
333 if (!quantizationDim.has_value())
334 {
James Ward47fce872020-09-10 11:57:28 +0100335 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
336 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000337 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000338}
339
340void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
341 const std::string& descName,
342 const std::string& tensorName)
343{
344 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
345 if (quantizationOffset != 0)
346 {
James Ward47fce872020-09-10 11:57:28 +0100347 throw InvalidArgumentException(fmt::format(
348 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
349 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000350 }
351}
352
353void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
354 const TensorInfo& outputInfo,
355 const TensorInfo& weightInfo,
356 const Optional<TensorInfo>& optionalBiasInfo,
357 const std::string& descName)
358{
359 if (weightInfo.HasPerAxisQuantization())
360 {
361 const DataType inputDataType = inputInfo.GetDataType();
362 const DataType outputDataType = outputInfo.GetDataType();
363
Keith Davis0c2eeac2020-02-11 16:51:50 +0000364 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365
366 if (!canHavePerAxisQuantization)
367 {
James Ward47fce872020-09-10 11:57:28 +0100368 throw InvalidArgumentException(fmt::format(
369 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
370 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000371 }
372
Derek Lambertid466a542020-01-22 15:37:29 +0000373
374 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000375 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
376 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
377
378 if (optionalBiasInfo.has_value())
379 {
380 const TensorInfo& biasInfo = optionalBiasInfo.value();
381 if (!biasInfo.HasPerAxisQuantization())
382 {
James Ward47fce872020-09-10 11:57:28 +0100383 throw InvalidArgumentException(fmt::format(
384 "{}: Per-axis quantization parameters not set on bias tensor, "
385 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000386 }
387
388 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
389 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
390 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
391 }
392 }
393}
394
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100395} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000396
Mike Kelly80512b02022-05-16 23:10:42 +0100397//---------------------------------------------------------------
398void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
399 std::string const& descName,
400 unsigned int numDimensions,
401 std::string const& tensorName) const
402{
403 // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
404 // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
405 // numDimensions.
406 if (m_AllowExpandedDims)
407 {
408 unsigned int squeezedDims = 0;
409
410 for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
411 {
412 if (tensor.GetShape()[i] != 1)
413 {
414 ++squeezedDims;
415 }
416 }
417 if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
418 {
419 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
420 to_string(tensor.GetNumDimensions()) + " dimensions for " +
421 tensorName + " tensor.");
422 }
423 }
424 else
425 {
426 if (tensor.GetNumDimensions() != numDimensions)
427 {
428 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
429 to_string(tensor.GetNumDimensions()) + " dimensions for " +
430 tensorName + " tensor.");
431 }
432 }
433}
434
435//---------------------------------------------------------------
436void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100437 unsigned int numDimension,
438 unsigned int numElements,
439 std::string const& tensorName) const
Mike Kelly80512b02022-05-16 23:10:42 +0100440{
441 const std::string functionName{"ValidateTensorNumDimNumElem"};
442 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
443 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
444}
445
446//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000447void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
448 unsigned int numExpectedIn, unsigned int numExpectedOut) const
449{
450 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
451 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
452}
453
454//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100455void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
456{
457 const std::string descriptorName{"MapQueueDescriptor"};
458
459 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100460 ValidateNumOutputs(workloadInfo, descriptorName, 0);
461
462 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
463 {
464 if (!m_Inputs[i])
465 {
466 throw InvalidArgumentException(
467 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
468 }
469 }
470}
471
472//---------------------------------------------------------------
473void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
474{
475 const std::string descriptorName{"UnmapQueueDescriptor"};
476
477 ValidateNumInputs(workloadInfo, descriptorName, 1);
478 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100479
480 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
481 {
482 if (!m_Inputs[i])
483 {
484 throw InvalidArgumentException(
485 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
486 }
487 }
488}
489
490//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000491void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
492{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100493 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100495 ValidateNumInputs(workloadInfo, descriptorName, 1);
496 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
499 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
500
501 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
502 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000503
504 if (m_Inputs.size() != m_Outputs.size())
505 {
James Ward47fce872020-09-10 11:57:28 +0100506 throw InvalidArgumentException(fmt::format(
507 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
508 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000509 }
510
511 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
512 {
513 if (!m_Inputs[i])
514 {
James Ward47fce872020-09-10 11:57:28 +0100515 throw InvalidArgumentException(fmt::format(
516 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000517 }
518
519 if (!m_Outputs[i])
520 {
James Ward47fce872020-09-10 11:57:28 +0100521 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000522 }
523 }
524}
525
Derek Lambertif674aa02019-08-01 15:56:25 +0100526//---------------------------------------------------------------
527void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
528{
529 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
530 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
531
532 if (workloadInfo.m_InputTensorInfos.size() != 1)
533 {
James Ward47fce872020-09-10 11:57:28 +0100534 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
535 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100536
537 }
538
539 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
540 {
James Ward47fce872020-09-10 11:57:28 +0100541 throw InvalidArgumentException(fmt::format(
542 "Number of input infos ({0}) does not match the number of output infos ({1})",
543 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100544 }
545
546 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
547 {
548 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
549 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
550 {
James Ward47fce872020-09-10 11:57:28 +0100551 throw InvalidArgumentException(fmt::format(
552 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100553 }
554 }
555
556 if (m_Inputs.size() != 1)
557 {
James Ward47fce872020-09-10 11:57:28 +0100558 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100559 }
560
561 if (m_Inputs.size() != m_Outputs.size())
562 {
James Ward47fce872020-09-10 11:57:28 +0100563 throw InvalidArgumentException(fmt::format(
564 "Number of inputs ({0}) does not match the number of outputs ({1})",
565 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100566 }
567
568 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
569 {
570 if (!m_Inputs[i])
571 {
James Ward47fce872020-09-10 11:57:28 +0100572 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100573 }
574
575 if (!m_Outputs[i])
576 {
James Ward47fce872020-09-10 11:57:28 +0100577 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100578 }
579 }
580}
581
582//---------------------------------------------------------------
583void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
584{
585 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100586
Derek Lambertif674aa02019-08-01 15:56:25 +0100587 if (m_Inputs.size() != 1)
588 {
James Ward47fce872020-09-10 11:57:28 +0100589 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100590 }
591
592 if (m_Outputs.size() != 0)
593 {
James Ward47fce872020-09-10 11:57:28 +0100594 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100595 }
596
597 if (!m_Inputs[0])
598 {
James Ward47fce872020-09-10 11:57:28 +0100599 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100600 }
601}
602
603//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000604void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
605{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100608 ValidateNumInputs(workloadInfo, descriptorName, 1);
609 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100611 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
612 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100613
614 std::vector<DataType> supportedTypes =
615 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000616 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100617 DataType::Float16,
618 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000620 DataType::QAsymmU8,
621 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100622 };
623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100624 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
625 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
626 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000627}
628
Nikhil Rajee391d52019-09-05 17:50:44 +0100629void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
630{
631 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
632
633 ValidateNumInputs(workloadInfo, descriptorName, 1);
634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
635
636 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
638
Inki Daed4619e22020-09-10 15:33:54 +0900639 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
640 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100641 {
Inki Daed4619e22020-09-10 15:33:54 +0900642 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100643 }
644
James Conroyd47a0642019-09-17 14:22:06 +0100645 std::vector<DataType> supportedInputTypes =
646 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000647 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100648 DataType::Float16,
649 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100650 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000651 DataType::QAsymmU8,
652 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900653 DataType::Signed32,
654 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100655 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100656
James Conroyd47a0642019-09-17 14:22:06 +0100657 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100658
659 auto inputShape = inputTensorInfo.GetShape();
660 auto outputShape = outputTensorInfo.GetShape();
661
662 auto inputNumDimensions = inputShape.GetNumDimensions();
663 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
664
665 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
666
667 // 1D input shape results in scalar output shape
668 if (inputShape.GetNumDimensions() == 1)
669 {
670 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
671 {
672 throw InvalidArgumentException(descriptorName + outputShapeError);
673 }
674 }
675 else
676 {
677 for (unsigned int i = 0; i < unsignedAxis; ++i)
678 {
679 if (outputShape[i] != inputShape[i])
680 {
681 throw InvalidArgumentException(descriptorName + outputShapeError);
682 }
683 }
684
685 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
686 {
687 if (outputShape[i - 1] != inputShape[i])
688 {
689 throw InvalidArgumentException(descriptorName + outputShapeError);
690 }
691 }
692 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100693}
694
mathad01b392e982021-04-07 12:07:30 +0100695void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
696{
697 const std::string descriptorName{"CastQueueDescriptor"};
698
699 ValidateNumInputs(workloadInfo, descriptorName, 1);
700 ValidateNumOutputs(workloadInfo, descriptorName, 1);
701
702 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
703 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
704
705 std::vector<DataType> supportedTypes =
706 {
707 DataType::BFloat16,
708 DataType::Float16,
709 DataType::Float32,
710 DataType::QAsymmS8,
711 DataType::QAsymmU8,
712 DataType::QSymmS8,
713 DataType::QSymmS16,
714 DataType::Signed32,
Colm Donelan02300aa2024-04-04 11:20:29 +0100715 DataType::Signed64,
716 DataType::Boolean
mathad01b392e982021-04-07 12:07:30 +0100717 };
718
719 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
720 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
721}
722
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100723void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
724{
725 const std::string descriptorName{"SoftmaxQueueDescriptor"};
726
727 ValidateNumInputs(workloadInfo, descriptorName, 1);
728 ValidateNumOutputs(workloadInfo, descriptorName, 1);
729
730 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
731 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
732
733 std::vector<DataType> supportedTypes =
734 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000735 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100736 DataType::Float16,
737 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000738 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000739 DataType::QAsymmU8,
740 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100741 };
742
743 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
744 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
745 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
746}
747
telsoa014fcda012018-03-09 14:13:49 +0000748void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
749{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100750 const std::string descriptorName{"SplitterQueueDescriptor"};
751
752 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000753
Ruomei Yan25339c32019-05-28 16:48:20 +0100754 // Check the supported data types
755 std::vector<DataType> supportedTypes =
756 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000757 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100758 DataType::Float32,
759 DataType::Float16,
760 DataType::Boolean,
761 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100762 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000763 DataType::QAsymmU8,
764 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100765 };
766
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100767 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
768 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100769 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100770 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
771 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
772
773 const std::string outputName = "output_" + std::to_string(i);
774 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100775 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100776
telsoa014fcda012018-03-09 14:13:49 +0000777 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
778 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100779 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000780 }
781
782 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
783 {
784 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000786 "has to match number of workloadInfo.m_OutputTensorInfos. "
787 "Number of windows: " +
788 to_string(m_ViewOrigins.size()) +
789 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
790 }
791
telsoa01c577f2c2018-08-31 09:22:23 +0100792 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000793 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
794 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
795 {
telsoa01c577f2c2018-08-31 09:22:23 +0100796 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000797 ViewOrigin const& e = m_ViewOrigins[w];
798 if (e.m_Origin.size() != inputDims)
799 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100800 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000801 "have the same dimensionality as the input tensor. "
802 "Window origin (index: " +
803 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
804 " dimensions, the input "
805 "tensor has " +
806 to_string(inputDims) + " dimensions.");
807 }
808 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
809 {
810 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
811 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000814 "be smaller or equal than the size of the input in that coord.");
815 }
816 }
817 }
818}
819
Jim Flynne242f2d2019-05-22 14:24:13 +0100820void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000821{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100822 const std::string descriptorName{"ConcatQueueDescriptor"};
823
824 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000825
826 if (m_Inputs.size() <= 0)
827 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100828 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000829 }
830 if (m_Outputs.size() <= 0)
831 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100832 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000833 }
834
835 if (workloadInfo.m_InputTensorInfos.size() <= 0)
836 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100837 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000838 }
839 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
840 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000842 }
843
Nikhil Raj8599a412018-11-19 14:51:07 +0000844 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
845 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100846 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000847 }
848
849 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
850 {
851 return;
852 }
853
telsoa014fcda012018-03-09 14:13:49 +0000854 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
855 {
856 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100857 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000858 "has to match number of workloadInfo.m_InputTensorInfos. "
859 "Number of windows: " +
860 to_string(m_ViewOrigins.size()) +
861 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
862 }
863
telsoa01c577f2c2018-08-31 09:22:23 +0100864 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000865 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
866 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
867 {
telsoa01c577f2c2018-08-31 09:22:23 +0100868 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000869 ViewOrigin const& e = m_ViewOrigins[w];
870 if (e.m_Origin.size() != outputDims)
871 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000873 "have the same dimensionality as the output tensor. "
874 "Window origin (index: " +
875 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
876 " dimensions, the output "
877 "tensor has " +
878 to_string(outputDims) + " dimensions.");
879 }
telsoa01c577f2c2018-08-31 09:22:23 +0100880 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000881 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
882 {
883 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
884 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
885 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100886 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000887 "be smaller or equal than the size of the output in that coord.");
888 }
889 }
890 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100891
892 // Check the supported data types
893 std::vector<DataType> supportedTypes =
894 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000895 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100896 DataType::Float32,
897 DataType::Float16,
898 DataType::Boolean,
899 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100900 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000901 DataType::QAsymmU8,
902 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100903 };
904
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100905 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
906 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100907 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100908 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
909 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
910
911 const std::string inputName = "input_" + std::to_string(i);
912 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100913 }
telsoa014fcda012018-03-09 14:13:49 +0000914}
915
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100916void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
917{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 const std::string descriptorName{"StackQueueDescriptor"};
919
920 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100921
922 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
923 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100924 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100925 }
926
927 // All inputs must have the same shape, which is defined in parameters
928 const TensorShape& inputShape = m_Parameters.m_InputShape;
929 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
930 {
931 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
932 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100933 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100934 }
935 }
936
Matthew Jacksondba634f2019-08-15 15:14:18 +0100937 if (inputShape.GetNumDimensions() > 4)
938 {
939 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
940 }
941
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100942 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
943 // since the output tensor has an additional dimension.
944 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
945 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100946 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100947 "than the number of input dimensions.");
948 }
949
950 // Output shape must be as inferred from the input shape
951 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
952 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
953 {
954 if (outputShape[i] != inputShape[i])
955 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100956 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100957 "match shape inferred from input tensor.");
958 }
959 }
960
961 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
962 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100963 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100964 "match shape inferred from input tensor.");
965 }
966
967 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
968 {
969 if (outputShape[i] != inputShape[i-1])
970 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100971 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100972 "match shape inferred from input tensor.");
973 }
974 }
975
Matthew Jacksondba634f2019-08-15 15:14:18 +0100976 if (outputShape.GetNumDimensions() > 5)
977 {
978 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
979 }
980
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100981 // Check the supported data types
982 std::vector<DataType> supportedTypes =
983 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000984 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100985 DataType::Float32,
986 DataType::Float16,
987 DataType::Boolean,
988 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100989 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000990 DataType::QAsymmU8,
991 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100992 };
993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100994 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100996 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100997 {
998 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
999 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001000 descriptorName,
1001 "input_0",
1002 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001003 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001004
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001005 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1006 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001007 descriptorName,
1008 "input_0",
1009 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001010}
1011
Ryan OSheaec6c6802020-06-05 17:17:06 +01001012void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1013{
1014 const std::string descriptorName{"FillQueueDescriptor"};
1015
1016 ValidateNumInputs(workloadInfo, descriptorName, 1);
1017 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1018
1019 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1020 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1021
1022 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1023
1024 std::vector<DataType> supportedTypes =
1025 {
1026 DataType::BFloat16,
1027 DataType::Float32,
1028 DataType::Float16,
1029 DataType::Signed32
1030 };
1031
1032 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1033}
1034
telsoa014fcda012018-03-09 14:13:49 +00001035void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1036{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001037 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001038
Matthew Sloyan81beae32021-07-13 19:46:11 +01001039 uint32_t numInputs = 2;
1040 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001041 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001042 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001043 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001044
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001045 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001046 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1047
1048 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1049 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1050
1051 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1052
1053 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001054 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001055 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001056 }
1057
Matthew Sloyan81beae32021-07-13 19:46:11 +01001058 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001059 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001060
1061 if (m_Parameters.m_BiasEnabled)
1062 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001063 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001064 // Validates type and quantization values.
Ryan OSheaf183acd2023-07-06 11:41:25 +01001065 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1067 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001068 }
1069
Francis Murtagh46c09d02019-05-28 08:15:28 +01001070 // Check the supported data types
1071 std::vector<DataType> supportedTypes =
1072 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001073 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001074 DataType::Float32,
1075 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001076 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001077 DataType::QAsymmU8,
1078 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001079 };
1080
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001081 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001082
1083 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1084 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1085 {
1086 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1087 {
1088 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1089 "for BFloat16 input.");
1090 }
1091 }
1092 else
1093 {
1094 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1095 }
telsoa014fcda012018-03-09 14:13:49 +00001096}
1097
Teresa Charlin9145e382023-08-17 18:44:58 +01001098void FusedQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
1099{
1100 // This is internally generated, so it should not need validation.
1101}
1102
telsoa014fcda012018-03-09 14:13:49 +00001103void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1104{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001105 const std::string descriptorName{"NormalizationQueueDescriptor"};
1106
1107 ValidateNumInputs(workloadInfo, descriptorName, 1);
1108 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1109
1110 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1111 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001112
1113 // Check the supported data types
1114 std::vector<DataType> supportedTypes =
1115 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001116 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001117 DataType::Float16,
1118 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001119 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001120 DataType::QAsymmU8,
1121 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001122 };
1123
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001124 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001125
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001127
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001129}
1130
1131void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1132{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001133 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001134
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001135 ValidateNumInputs(workloadInfo, descriptorName, 2);
1136 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1137
1138 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1139 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1140 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1141
1142 std::vector<DataType> supportedTypes =
1143 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001144 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001145 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001146 DataType::Float16,
1147 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001148 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001149 DataType::QSymmS16,
1150 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001151 };
1152
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001153 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1154 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1155 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1158 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001159
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001160 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1161 inputTensorInfo1,
1162 outputTensorInfo,
1163 descriptorName,
1164 "input_0",
1165 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001166}
1167
telsoa014fcda012018-03-09 14:13:49 +00001168void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1169{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001170 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 ValidateNumInputs(workloadInfo, descriptorName, 2);
1173 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1174
1175 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1176 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1177 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1178
1179 std::vector<DataType> supportedTypes =
1180 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001181 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001182 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001183 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001184 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001185 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001186 DataType::QSymmS16,
1187 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001188 };
1189
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001190 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1191 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1192 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1195 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001197 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1198 inputTensorInfo1,
1199 outputTensorInfo,
1200 descriptorName,
1201 "input_0",
1202 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001203}
1204
1205void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1206{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001207 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001208
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001209 ValidateNumInputs(workloadInfo, descriptorName, 1);
1210 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1211
1212 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1213 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001214
1215 std::vector<DataType> supportedTypes =
1216 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001217 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001218 DataType::Float16,
1219 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001220 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001221 DataType::QAsymmU8,
1222 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001223 };
1224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1226 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001227
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001228 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001230
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001231 ValidatePointer(m_Mean, descriptorName, "mean");
1232 ValidatePointer(m_Variance, descriptorName, "variance");
1233 ValidatePointer(m_Beta, descriptorName, "beta");
1234 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001235
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001236 const TensorInfo& mean = m_Mean->GetTensorInfo();
1237 const TensorInfo& variance = m_Variance->GetTensorInfo();
1238 const TensorInfo& beta = m_Beta->GetTensorInfo();
1239 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001240
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001241 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1242 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1243 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1244 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001245
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001246 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1247 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1248 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001249}
1250
1251void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1252{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001254
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001255 uint32_t numInputs = 2;
1256 if (m_Parameters.m_BiasEnabled)
1257 {
1258 numInputs = 3;
1259 }
1260
1261 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001262 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001263
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001264 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1265 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001266
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001267 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1268 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001269
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001270 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001271
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001273
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001274 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001275
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001276 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001277 if (m_Parameters.m_BiasEnabled)
1278 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001279 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001280 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001281
1282 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001283 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001284 }
1285
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001286 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1287 {
1288 throw InvalidArgumentException(
1289 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1290 "cannot be either negative or 0.",
1291 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1292 }
1293
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001294 ValidatePerAxisQuantization(inputTensorInfo,
1295 outputTensorInfo,
1296 weightTensorInfo,
1297 optionalBiasTensorInfo,
1298 descriptorName);
1299
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001300 std::vector<DataType> supportedTypes =
1301 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001302 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001303 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001304 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001305 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001306 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001307 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001308 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001309 };
1310
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001311 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001312
1313 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1314 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1315 {
1316 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1317 {
1318 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1319 "for BFloat16 input.");
1320 }
1321 }
1322 else
1323 {
1324 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1325 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001326}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001327
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001328void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1329{
1330 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1331
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001332 uint32_t numInputs = 2;
1333 if (m_Parameters.m_BiasEnabled)
1334 {
1335 numInputs = 3;
1336 }
1337 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001338 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1339
1340 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1341 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1342
1343 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1344 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1345
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001346 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001347 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1348
1349 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1350
1351 Optional<TensorInfo> optionalBiasTensorInfo;
1352 if (m_Parameters.m_BiasEnabled)
1353 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001354 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001355 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1356
1357 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001358 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001359 }
1360
1361 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1362 {
1363 throw InvalidArgumentException(
1364 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1365 "cannot be either negative or 0.",
1366 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1367 }
1368
1369 ValidatePerAxisQuantization(inputTensorInfo,
1370 outputTensorInfo,
1371 weightTensorInfo,
1372 optionalBiasTensorInfo,
1373 descriptorName);
1374
1375 std::vector<DataType> supportedTypes =
1376 {
1377 DataType::BFloat16,
1378 DataType::Float16,
1379 DataType::Float32,
1380 DataType::QAsymmS8,
1381 DataType::QAsymmU8,
1382 DataType::QSymmS16,
1383 DataType::QSymmS8
1384 };
1385
1386 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1387 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1388}
1389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001390void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1391{
1392 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1393
Cathal Corbett06902652022-04-14 17:55:11 +01001394 uint32_t numInputs = 2;
1395 if (m_Parameters.m_BiasEnabled)
1396 {
1397 numInputs = 3;
1398 }
1399
1400 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001401 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1402
1403 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1404 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1405
1406 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1407 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1408
Cathal Corbett06902652022-04-14 17:55:11 +01001409 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001410 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1411
1412 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1413 {
1414 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001415 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1416 "cannot be smaller than 1.",
1417 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001418 }
1419
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001420 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1421 {
1422 throw InvalidArgumentException(
1423 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1424 "cannot be either negative or 0.",
1425 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1426 }
1427
Jan Eilers53ef7952021-06-02 12:01:25 +01001428 if (weightTensorInfo.GetShape()[0] != 1)
1429 {
1430 throw InvalidArgumentException(fmt::format(
1431 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1432 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1433 descriptorName,
1434 weightTensorInfo.GetShape()[0],
1435 weightTensorInfo.GetShape()[1],
1436 weightTensorInfo.GetShape()[2],
1437 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001438 }
1439
Cathal Corbett4b19d222022-05-11 20:12:17 +01001440 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1441 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1442 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1443 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1444
1445 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1446 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1447 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1448
1449 if (!(validRefFormat || validAclFormat))
1450 {
1451 throw InvalidArgumentException(fmt::format(
1452 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1453 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1454 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1455 descriptorName,
1456 numOutputChannels,
1457 weightTensorInfo.GetShape()[0],
1458 weightTensorInfo.GetShape()[1],
1459 weightTensorInfo.GetShape()[2],
1460 weightTensorInfo.GetShape()[3]));
1461 }
1462
Teresa Charlind8df0262019-11-11 12:28:15 +00001463 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001464
Teresa Charlind8df0262019-11-11 12:28:15 +00001465 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001466 if (m_Parameters.m_BiasEnabled)
1467 {
Cathal Corbett06902652022-04-14 17:55:11 +01001468 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001469 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001470
Ryan OSheaf183acd2023-07-06 11:41:25 +01001471 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001472 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1473 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001474 ValidatePerAxisQuantization(inputTensorInfo,
1475 outputTensorInfo,
1476 weightTensorInfo,
1477 optionalBiasTensorInfo,
1478 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001479
1480 std::vector<DataType> supportedTypes =
1481 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001482 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001483 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001484 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001485 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001486 DataType::QAsymmU8,
1487 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001488 };
1489
1490 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1491 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001492}
1493
1494void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1495{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 const std::string descriptorName{"PermuteQueueDescriptor"};
1497
1498 ValidateNumInputs(workloadInfo, descriptorName, 1);
1499 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001500
1501 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1504 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001505
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1507 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001508
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001509 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001510 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001511 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001512 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001513 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1514 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1515 "must match dst dimension " + to_string(mapping[i]) +
1516 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001517 }
1518 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001519
1520 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001521}
1522
1523void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1524{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001526
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001527 ValidateNumInputs(workloadInfo, descriptorName, 1);
1528 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1529
1530 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1531 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1532
1533 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1534 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001535
1536 std::vector<DataType> supportedTypes =
1537 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001538 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001539 DataType::Float32,
1540 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001541 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001542 DataType::QAsymmU8,
1543 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001544 };
1545
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001546 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1547 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001548}
1549
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001550void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1551{
1552 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1553
1554 ValidateNumInputs(workloadInfo, descriptorName, 1);
1555 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1556
1557 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1558 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1559
1560 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1561 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1562
1563 std::vector<DataType> supportedTypes =
1564 {
1565 DataType::BFloat16,
1566 DataType::Float32,
1567 DataType::Float16,
1568 DataType::QAsymmS8,
1569 DataType::QAsymmU8,
1570 DataType::QSymmS16
1571 };
1572
1573 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1574 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1575}
1576
Teresa Charlin970f43b2019-07-01 13:51:07 +01001577void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1578{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001579 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001580
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001581 ValidateNumInputs(workloadInfo, descriptorName, 1);
1582 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1583
1584 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1585 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1586
1587 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1588 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001589
1590 std::vector<DataType> supportedTypes =
1591 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001592 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001593 DataType::Float16,
1594 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001595 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001596 DataType::QAsymmU8,
Teresa Charlince655882023-11-21 15:44:13 +00001597 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001598 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001599 };
1600
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1602 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001603
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001604 // Resize only changes width and height: batch and channel count must match.
1605 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1606 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001607 if (inputBatchSize != outputBatchSize)
1608 {
1609 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001610 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1611 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001612 }
1613
1614 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001615 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1616 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001617 if (inputChannelCount != outputChannelCount)
1618 {
1619 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001620 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1621 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001622 }
1623}
1624
Teresa Charlin79a06a52023-07-13 17:16:45 +01001625void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
1626{
Tianle Cheng988354d2023-06-28 13:20:47 +01001627 const std::string descriptorName{"ReverseV2QueueDescriptor"};
1628
Tracy Narinebb8d7592023-07-13 16:50:54 +01001629 // Backend restriction
1630 const unsigned int maxDimensions = 4;
1631
1632 ValidateNumInputs(workloadInfo, descriptorName, 2);
Tianle Cheng988354d2023-06-28 13:20:47 +01001633 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1634
1635 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
Tracy Narinebb8d7592023-07-13 16:50:54 +01001636 const TensorInfo& axisTensorInfo = workloadInfo.m_InputTensorInfos[1];
Tianle Cheng988354d2023-06-28 13:20:47 +01001637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1638
Tracy Narinebb8d7592023-07-13 16:50:54 +01001639 const auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
1640 if (inputTensorNumDimensions > maxDimensions)
Tianle Cheng988354d2023-06-28 13:20:47 +01001641 {
1642 throw InvalidArgumentException(descriptorName +
1643 ": Input tensors with rank greater than " +
Tracy Narinebb8d7592023-07-13 16:50:54 +01001644 std::to_string(maxDimensions) + " are not supported.");
1645 }
1646
1647 const auto axisTensorNumDimensions = axisTensorInfo.GetNumDimensions();
1648 if (axisTensorNumDimensions > maxDimensions)
1649 {
1650 throw InvalidArgumentException(descriptorName +
1651 ": More than " + std::to_string(maxDimensions) + " axes cannot be specified.");
1652 }
1653
1654 if (axisTensorNumDimensions > inputTensorNumDimensions)
1655 {
1656 throw InvalidArgumentException(descriptorName +
1657 ": More axes specified than the number of axes on the input tensor.");
Tianle Cheng988354d2023-06-28 13:20:47 +01001658 }
1659
1660 std::vector<DataType> supportedTypes =
1661 {
1662 DataType::BFloat16,
1663 DataType::Float16,
1664 DataType::Float32,
1665 DataType::QAsymmS8,
1666 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01001667 DataType::QSymmS8,
1668 DataType::QSymmS16,
1669 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01001670 };
1671
1672 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Tracy Narinebb8d7592023-07-13 16:50:54 +01001673
1674 std::vector<DataType> axisSupportedTypes =
1675 {
1676 DataType::Signed32,
1677 };
1678
1679 ValidateDataTypes(axisTensorInfo, axisSupportedTypes, descriptorName);
1680
Tianle Cheng988354d2023-06-28 13:20:47 +01001681 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1682 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Tianle Cheng988354d2023-06-28 13:20:47 +01001683}
1684
telsoa014fcda012018-03-09 14:13:49 +00001685void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1686{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001687 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001688
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001689 ValidateNumInputs(workloadInfo, descriptorName, 1);
1690 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1691
1692 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1693 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1694
1695 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1696 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1697
1698 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1699
telsoa014fcda012018-03-09 14:13:49 +00001700 if (m_Parameters.m_Min > m_Parameters.m_Max)
1701 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001702 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001703 }
telsoa014fcda012018-03-09 14:13:49 +00001704}
1705
Kevin Mayce5045a2019-10-02 14:07:47 +01001706void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1707{
1708 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1709
1710 ValidateNumInputs(workloadInfo, descriptorName, 1);
1711 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1712
1713 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1714 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1715
1716 if (inputTensorInfo.GetNumDimensions() > 4)
1717 {
1718 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1719 }
1720
1721 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1722
1723 // Check the supported data types
1724 std::vector<DataType> supportedTypes =
1725 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001726 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001727 DataType::Float32,
1728 DataType::Float16
1729 };
1730
1731 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001732 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001733}
1734
telsoa014fcda012018-03-09 14:13:49 +00001735void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1736{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001737 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001740 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1741
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001742 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1743 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1744
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001745 if (inputTensorInfo.GetNumDimensions() > 4)
1746 {
1747 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1748 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001749
1750 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001751
1752 // Check the supported data types
1753 std::vector<DataType> supportedTypes =
1754 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001755 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001756 DataType::Float32,
1757 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001758 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001759 DataType::QAsymmU8,
1760 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001761 };
1762
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001763 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001764 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1765}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001766
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001767void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1768{
1769 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1770
1771 ValidateNumInputs(workloadInfo, descriptorName, 1);
1772 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1773
1774 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1775 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1776
1777 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1778
1779 std::vector<DataType> supportedTypes =
1780 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001781 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001782 DataType::Float32,
1783 DataType::Float16,
1784 };
1785
1786 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001787 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001788}
1789
1790void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1791{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001792 const std::string descriptorName{"ConstantQueueDescriptor"};
1793
1794 ValidateNumInputs(workloadInfo, descriptorName, 0);
1795 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001796
1797 if (!m_LayerOutput)
1798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001799 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001800 }
1801
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001802 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1803 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001804
1805 // Check the supported data types
1806 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001807 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001808 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001809 DataType::Float32,
1810 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001811 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001812 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001813 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001814 DataType::QSymmS16,
1815 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001816 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001817
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001818 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001819}
1820
1821void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1822{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001823 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001824
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001825 ValidateNumInputs(workloadInfo, descriptorName, 1);
1826 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1827
1828 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1829 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1830
1831 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001832
1833 // Check the supported data types
1834 std::vector<DataType> supportedTypes =
1835 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001836 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001837 DataType::Float32,
1838 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001839 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001840 DataType::QAsymmU8,
1841 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001842 DataType::Signed32,
1843 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001844 };
1845
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1847 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001848}
1849
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001850void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1851{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001852 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001853
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001854 ValidateNumInputs(workloadInfo, descriptorName, 1);
1855 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1856
1857 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1858 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1859
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001860 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1861 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001862 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1863 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001864 }
1865
Teresa Charlinf77cab52023-06-01 16:15:13 +01001866 if (m_Parameters.m_BlockShape.size() == 2)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001867 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001868 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1869 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1870 }
1871 else if (m_Parameters.m_BlockShape.size() == 1)
1872 {
1873 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
1874 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
1875 }
1876 else
1877 {
1878 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001879 }
1880
Teresa Charlinf77cab52023-06-01 16:15:13 +01001881 // Check input + padding and output have the same number of elements
1882 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1883 const unsigned int inputHeight = inputTensorInfo.GetShape()[dimensionIndices.GetHeightIndex()] +
1884 m_Parameters.m_PadList[0].first + m_Parameters.m_PadList[0].second;
1885 const unsigned int inputWidth = (inputTensorInfo.GetNumDimensions() == 3) ? 1 :
1886 inputTensorInfo.GetShape()[dimensionIndices.GetWidthIndex()] +
1887 m_Parameters.m_PadList[1].first + m_Parameters.m_PadList[1].second;
1888
1889 const int channelsIndex_int = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : -1;
1890 const unsigned int channelsIndex = channelsIndex_int < 0 ?
1891 static_cast<unsigned int>(channelsIndex_int) + inputTensorInfo.GetNumDimensions()
1892 : static_cast<unsigned int>(channelsIndex_int);
1893
1894 const unsigned int numInputElements = inputTensorInfo.GetShape()[0] *
1895 inputHeight *
1896 inputWidth *
1897 inputTensorInfo.GetShape()[channelsIndex];
1898
1899 if (outputTensorInfo.GetNumElements() != numInputElements)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001900 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001901 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1902 to_string(numInputElements) + " after padding but output tensor has " +
1903 to_string(outputTensorInfo.GetNumElements()) + " elements.");
1904 }
1905
1906 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
1907 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
1908 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
1909 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
1910 {
1911 unsigned int spatialDimension = firstSpatialDimension + i;
1912 auto inputSize = inputTensorInfo.GetShape()[spatialDimension] +
1913 m_Parameters.m_PadList[i].first +
1914 m_Parameters.m_PadList[i].second;
1915 if (inputSize % m_Parameters.m_BlockShape[i] != 0)
1916 {
1917 throw InvalidArgumentException(descriptorName + ": Input dimension size after padding must be "
1918 "divisible by Block Shape in dimension: " + to_string(spatialDimension) + ".");
1919 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001920 }
nikraj01120522a2019-05-31 11:33:07 +01001921
1922 std::vector<DataType> supportedTypes =
1923 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001924 DataType::BFloat16,
1925 DataType::Float16,
1926 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001927 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001928 DataType::QAsymmU8,
1929 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001930 };
1931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001932 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1933 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001934}
1935
Keith Davisa57eccb2019-06-14 17:33:22 +01001936void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1937{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001938 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001940 ValidateNumInputs(workloadInfo, descriptorName, 1);
1941 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001942
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001943 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1944 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1945
1946 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1947 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001948
1949 std::vector<DataType> supportedTypes =
1950 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001951 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001952 DataType::Float32,
1953 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001954 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001955 DataType::QAsymmU8,
1956 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001957 };
1958
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001959 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1960 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001961
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001962 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1963
1964 if (m_Parameters.m_BlockSize == 0)
1965 {
1966 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1967 }
1968
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001969 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1970 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1971 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1972 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001973
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001975 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001976 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001977 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1978 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001979 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001980
1981 const TensorShape& outputShape = outputTensorInfo.GetShape();
1982 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1983 {
1984 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1985 "must be divisible by the square of block size." );
1986 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001987}
1988
telsoa014fcda012018-03-09 14:13:49 +00001989void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1990{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001991 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001993 ValidateNumInputs(workloadInfo, descriptorName, 1);
1994 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1995
1996 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1997 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001998
1999 std::vector<DataType> supportedTypes =
2000 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002001 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002002 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002003 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01002004 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01002005 };
2006
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01002008 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2009 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2010 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00002011}
2012
telsoa01c577f2c2018-08-31 09:22:23 +01002013void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2014{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002015 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
2016
2017 const std::string descriptorName{"LstmQueueDescriptor"};
2018
2019 // check dimensions of all inputs and outputs
2020 if (workloadInfo.m_InputTensorInfos.size() != 3)
2021 {
2022 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
2023 }
2024 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2025 {
2026 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
2027 }
2028
2029 std::vector<DataType> supportedTypes =
2030 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002031 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01002032 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002033 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002034 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002035 };
2036
Jan Eilers38e05bd2019-06-26 13:10:09 +01002037 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002038 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
2039
Jan Eilers38e05bd2019-06-26 13:10:09 +01002040 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002041 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002042 {
2043 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2044 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002045 descriptorName,
2046 "input_0",
2047 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002048 }
2049 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002050 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002051 {
2052 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2053 workloadInfo.m_OutputTensorInfos[i],
2054 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002055 "input_0",
2056 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002057 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002058
janeil0117d8d852019-11-15 15:00:16 +00002059 // Making sure clipping parameters have valid values.
2060 // == 0 means no clipping
2061 // > 0 means clipping
2062 if (m_Parameters.m_ClippingThresCell < 0.0f)
2063 {
2064 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2065 }
2066 if (m_Parameters.m_ClippingThresProj < 0.0f)
2067 {
2068 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2069 }
2070
Jan Eilers38e05bd2019-06-26 13:10:09 +01002071 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002072 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2073 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2074 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2075 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2076 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2077 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2078
Jan Eilers38e05bd2019-06-26 13:10:09 +01002079 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002080 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2081 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002082 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002083 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2084 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002085 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002086 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2087 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002088 // scratchBufferTensor
2089 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2091 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002092 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002093 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2094 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002095 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2097 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002098 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002099 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2100 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002101
Jan Eilers38e05bd2019-06-26 13:10:09 +01002102 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2103 if ( m_InputToInputWeights )
2104 {
2105 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2106 (n_cell * n_input), "InputLayerNormWeights");
2107 }
2108
2109 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2110 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2111 (n_cell * n_input), "InputToForgetWeights");
2112
2113 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2114 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2115 (n_cell * n_input), "InputToCellWeights");
2116
2117 if ( m_RecurrentToInputWeights )
2118 {
2119 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2120 (n_cell * n_output), "RecurrentToInputWeights");
2121 }
2122
2123 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2124 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2125 (n_cell * n_output), "RecurrentToForgetWeights");
2126
2127 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2128 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2129 (n_cell * n_output), "RecurrentToCellWeights");
2130
2131 // Make sure the input-gate's parameters are either both present (regular
2132 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2133 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2134 !m_Parameters.m_CifgEnabled) ||
2135 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2136 m_Parameters.m_CifgEnabled));
2137 if (!cifg_weights_all_or_none)
2138 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002139 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2140 "RecurrentToInputWeights must either both be present (regular LSTM) "
2141 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2142 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002143 }
2144
2145 if ( m_CellToInputWeights )
2146 {
2147 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2148 n_cell, "CellToInputWeights");
2149 }
2150 if ( m_CellToForgetWeights )
2151 {
2152 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2153 n_cell, "CellToForgetWeights");
2154 }
2155 if ( m_CellToOutputWeights )
2156 {
2157 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2158 n_cell, "CellToOutputWeights");
2159 }
2160
2161 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2162 bool peephole_weights_all_or_none =
2163 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2164 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2165 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2166 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2167 if (!peephole_weights_all_or_none)
2168 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002169 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002170 }
2171
2172 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2173 if (m_Parameters.m_CifgEnabled)
2174 {
2175 if (m_InputGateBias)
2176 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002177 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002178 }
2179 }
2180 else
2181 {
2182 if (!m_InputGateBias)
2183 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002184 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2185 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002186 }
2187 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2188 n_cell, "InputGateBias");
2189 }
2190
2191 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2192 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2193
2194 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2195 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2196
2197 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2198 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2199
2200 if (m_ProjectionWeights)
2201 {
2202 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2203 (n_cell * n_output), "ProjectionWeights");
2204 }
2205 if (m_ProjectionBias)
2206 {
2207 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2208 }
2209
2210 // Making sure the projection tensors are consistent:
2211 // 1) If projection weight is not present, then projection bias should not be
2212 // present.
2213 // 2) If projection weight is present, then projection bias is optional.
2214 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2215 !m_Parameters.m_ProjectionEnabled)
2216 || (m_ProjectionWeights && !m_ProjectionBias &&
2217 m_Parameters.m_ProjectionEnabled)
2218 || (m_ProjectionWeights && m_ProjectionBias &&
2219 m_Parameters.m_ProjectionEnabled));
2220 if (!projecton_tensors_consistent)
2221 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002222 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002223 }
2224
2225 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2226 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2227 // either all have values or none of them have values. Layer normalization is used when the values of all the
2228 // layer normalization weights are present
2229 if (m_InputLayerNormWeights)
2230 {
2231 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2232 }
2233 if (m_ForgetLayerNormWeights)
2234 {
2235 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2236 }
2237 if (m_CellLayerNormWeights)
2238 {
2239 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2240 }
2241 if (m_OutputLayerNormWeights)
2242 {
2243 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2244 }
2245
Jan Eilers38e05bd2019-06-26 13:10:09 +01002246 if (m_Parameters.m_LayerNormEnabled)
2247 {
2248 if (!m_Parameters.m_CifgEnabled)
2249 {
2250 if (!m_InputLayerNormWeights)
2251 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002252 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2253 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002254 }
2255 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2256 1, n_cell, "InputLayerNormWeights");
2257 }
2258 else if (m_InputLayerNormWeights)
2259 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2261 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002262 }
2263
2264 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2265 "ForgetLayerNormWeights");
2266 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2267
2268 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2269 "OutputLayerNormWeights");
2270 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2271
2272 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2273 "CellLayerNormWeights");
2274 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2275 }
2276 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2277 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002278 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2279 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002280 }
telsoa01c577f2c2018-08-31 09:22:23 +01002281}
2282
2283void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2284{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002285 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002287 ValidateNumInputs(workloadInfo, descriptorName, 1);
2288 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2289
2290 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2291 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2292
2293 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002294 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002295 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002296 }
2297
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002299 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002301 }
2302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002303 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002304}
2305
2306void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2307{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002308 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002309
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002310 ValidateNumInputs(workloadInfo, descriptorName, 1);
2311 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2312
2313 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2314 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2315
2316 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002317 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002318 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002319 }
2320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002321 if (outputTensorInfo.GetDataType() != DataType::Float32)
2322 {
2323 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2324 }
2325
2326 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002327}
2328
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002329void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2330{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002331 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002332
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 ValidateNumInputs(workloadInfo, descriptorName, 2);
2334 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2335
2336 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2337 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2338 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2339
2340 std::vector<DataType> supportedTypes =
2341 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002342 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002343 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002344 DataType::Float32,
2345 DataType::QAsymmS8,
2346 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002347 DataType::QSymmS16,
2348 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002349 };
2350
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002351 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2352 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2353 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002355 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2356 inputTensorInfo1,
2357 outputTensorInfo,
2358 descriptorName,
2359 "input_0",
2360 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002361}
2362
David Beckc2044fe2018-09-05 15:00:38 +01002363void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2364{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002365 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002366
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002367 ValidateNumInputs(workloadInfo, descriptorName, 2);
2368 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2369
2370 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2371 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2372 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2373
2374 std::vector<DataType> supportedTypes =
2375 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002376 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002377 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002378 DataType::Float32,
2379 DataType::QAsymmS8,
2380 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002381 DataType::QSymmS16,
2382 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002383 };
2384
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002385 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2386 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2387 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002388
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002389 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2390 inputTensorInfo1,
2391 outputTensorInfo,
2392 descriptorName,
2393 "input_0",
2394 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002395}
2396
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002397void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2398{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002399 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002400
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002401 ValidateNumInputs(workloadInfo, descriptorName, 2);
2402 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2403
2404 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2405 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2406 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2407
2408 std::vector<DataType> supportedTypes =
2409 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002410 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002411 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002412 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002413 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002414 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002415 DataType::QSymmS16,
2416 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002417 };
2418
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002419 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2420 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2421 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002422
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002423 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2424 inputTensorInfo1,
2425 outputTensorInfo,
2426 descriptorName,
2427 "input_0",
2428 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002429}
2430
narpra01a6bf9122018-09-10 09:50:09 +01002431void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2432{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002433 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002434
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002435 ValidateNumInputs(workloadInfo, descriptorName, 1);
2436 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2437
2438 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2439 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002440
2441 std::vector<DataType> supportedTypes =
2442 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002443 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002444 DataType::Float32,
2445 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002446 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002447 DataType::QAsymmU8,
2448 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002449 };
narpra01eb061912018-09-10 17:35:27 +01002450
James Conroy4d1ff582019-06-10 17:06:39 +01002451 // First check if input tensor data type is supported, then
2452 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002453 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2454 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002455
narpra0132b90462018-09-13 11:07:48 +01002456 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002457 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002458 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002459 }
narpra0132b90462018-09-13 11:07:48 +01002460 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002461 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002462 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002463 }
2464 else
2465 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002466 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002467 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002468 ValidateTensorNumDimensions(outputTensorInfo,
2469 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002470 outputDim > 0 ? outputDim : 1,
2471 "output");
2472 }
narpra01a6bf9122018-09-10 09:50:09 +01002473}
2474
jimfly012c9322a2018-09-19 10:59:49 +01002475void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2476{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002477 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002478
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002479 ValidateNumInputs(workloadInfo, descriptorName, 1);
2480 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2481
2482 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2483 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002484
jimfly012c9322a2018-09-19 10:59:49 +01002485 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002486 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2487
jimfly012c9322a2018-09-19 10:59:49 +01002488 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002489 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2490 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2491 "as there are dimensions in the input tensor that is " +
2492 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2493 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002494 }
2495}
2496
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002497void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2498{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002499 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002500
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002501 ValidateNumInputs(workloadInfo, descriptorName, 1);
2502 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002503
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002504 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2505 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2506
Sadik Armagan2208b602019-07-31 16:36:27 +01002507 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002508 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002509 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002510 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002511 DataType::Float16,
2512 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002513 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002514 DataType::QAsymmU8,
2515 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002516 };
2517
2518 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002519
Keith Davis0c2eeac2020-02-11 16:51:50 +00002520 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002521 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002522 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002523 }
2524}
2525
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002526void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2527{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002528 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002529
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002530 ValidateNumInputs(workloadInfo, descriptorName, 1);
2531 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002532
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002533 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2534 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002535
Teresa Charlinf77cab52023-06-01 16:15:13 +01002536 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_Crops.size())
2537 {
2538 throw InvalidArgumentException(descriptorName + ": Crops must contain the same number of "
2539 "dimensions as Block Shape.");
2540 }
2541
2542 if (m_Parameters.m_BlockShape.size() == 2)
2543 {
2544 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2545 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2546 }
2547 else if (m_Parameters.m_BlockShape.size() == 1)
2548 {
2549 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
2550 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
2551 }
2552 else
2553 {
2554 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
2555 }
2556
2557 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
2558 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
2559 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
2560 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
2561 {
2562 unsigned int spatialDimension = firstSpatialDimension + i;
2563 unsigned int cropSize = m_Parameters.m_Crops[i].first + m_Parameters.m_Crops[i].second;
2564 unsigned int outputSize = inputTensorInfo.GetShape()[spatialDimension] * m_Parameters.m_BlockShape[i];
2565 if (cropSize > outputSize)
2566 {
2567 throw InvalidArgumentException(descriptorName + ": CropSize must be less than or equal to the uncropped"
2568 "outputSize in dimension: " + to_string(spatialDimension) + ".");
2569 }
2570 }
2571
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002572 std::vector<DataType> supportedTypes =
2573 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002574 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002575 DataType::Float32,
2576 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002577 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002578 DataType::QAsymmU8,
2579 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002580 };
2581
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002582 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2583 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002584}
2585
Conor Kennedy430b5d82018-11-14 15:28:28 +00002586void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2587{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002588 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002589
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002590 ValidateNumInputs(workloadInfo, descriptorName, 1);
2591 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2592
2593 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2594 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002595
2596 std::vector<DataType> supportedTypes =
2597 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002598 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002599 DataType::Float16,
2600 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002601 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002602 DataType::QAsymmU8,
2603 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002604 };
2605
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002606 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2607 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002609 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002611 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002612 if (rank > 4)
2613 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002614 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002615 }
2616
Conor Kennedy430b5d82018-11-14 15:28:28 +00002617 // Begin, End & Stride length must be of rank(input0)
2618 if (m_Parameters.m_Begin.size() != rank)
2619 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002620 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002621 }
2622
2623 if (m_Parameters.m_End.size() != rank)
2624 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002625 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002626 }
2627
2628 if (m_Parameters.m_Stride.size() != rank)
2629 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002630 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002631 }
2632
2633 // Stride entries must be non-zero
2634 for (auto& stride : m_Parameters.m_Stride)
2635 {
2636 if (stride == 0)
2637 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002638 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002639 }
2640 }
2641}
2642
kevmay0190539692018-11-29 08:40:19 +00002643void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2644{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002645 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002646
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002647 ValidateNumInputs(workloadInfo, descriptorName, 2);
2648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2649
2650 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2651 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2652 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2653
2654 std::vector<DataType> supportedTypes =
2655 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002656 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002657 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002658 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002659 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002660 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002661 DataType::QSymmS16,
2662 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002663 };
2664
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002665 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2666 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2667 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002668
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002669 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2670 inputTensorInfo1,
2671 outputTensorInfo,
2672 descriptorName,
2673 "input_0",
2674 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002675}
2676
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002677void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2678{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002679 const std::string descriptorName{"DebugQueueDescriptor"};
2680
2681 ValidateNumInputs(workloadInfo, descriptorName, 1);
2682 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002683}
2684
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002685void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2686{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002687 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002688
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002689 ValidateNumInputs(workloadInfo, descriptorName, 2);
2690 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002691
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002692 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2693 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2694 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2695
2696 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2697 inputTensorInfo1,
2698 outputTensorInfo,
2699 descriptorName,
2700 "input_0",
2701 "input_1");
2702
2703 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002704 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002705 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002706 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002707}
2708
FrancisMurtagh878f0232018-12-19 10:56:15 +00002709void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2710{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002711 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002712
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002713 ValidateNumInputs(workloadInfo, descriptorName, 2);
2714 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002715
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002716 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2717 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2718 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2719
2720 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2721 inputTensorInfo1,
2722 outputTensorInfo,
2723 descriptorName,
2724 "input_0",
2725 "input_1");
2726
2727 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002728 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002729 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002730 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002731}
2732
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002733void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2734{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002735 const std::string descriptorName{"RsqrtQueueDescriptor"};
2736
2737 ValidateNumInputs(workloadInfo, descriptorName, 1);
2738 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2739
2740 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2741 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2742
2743 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002744
2745 std::vector<DataType> supportedTypes =
2746 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002747 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002748 DataType::Float16,
2749 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002750 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002751 DataType::QAsymmU8,
2752 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002753 };
2754
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002755 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2756 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002757}
2758
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002759void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2760{
2761 const std::string descriptorName{"GatherNdQueueDescriptor"};
2762
2763 ValidateNumInputs(workloadInfo, descriptorName, 2);
2764 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2765
2766 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2767 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2768 {
2769 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2770 }
2771
2772 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2773 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2774
2775 std::vector<DataType> supportedTypes =
2776 {
2777 DataType::BFloat16,
2778 DataType::Float16,
2779 DataType::Float32,
2780 DataType::QAsymmS8,
2781 DataType::QAsymmU8,
2782 DataType::QSymmS16,
2783 DataType::Signed32,
2784 };
2785
2786 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2787
2788 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2789
2790 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2791 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2792}
2793
narpra01b89b05f2019-01-16 09:53:09 +00002794void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2795{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002796 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002797
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002798 ValidateNumInputs(workloadInfo, descriptorName, 2);
2799 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002800
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002801 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2802 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002803 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002804 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002805 }
2806
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002807 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2808 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2809
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002810 std::vector<DataType> supportedTypes =
2811 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002812 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002813 DataType::Float16,
2814 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002815 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002816 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002817 DataType::QSymmS16,
2818 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002819 };
2820
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002821 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002822
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002823 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002824
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002825 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2826 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002827}
2828
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002829void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2830{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002831 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2832
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002833 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002834
2835 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2836 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002837 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002838 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2839 }
2840
2841 if (m_Anchors == nullptr)
2842 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002843 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002844 }
2845
2846 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002847 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2848 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2849
2850 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002851 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002852 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2853 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002854
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002855 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2856 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2857 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002858
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002859 const std::vector<DataType> supportedInputTypes =
2860 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002861 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002862 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002863 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002864 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002865 DataType::QAsymmU8,
2866 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002867 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002868
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002869 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2870 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2871 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2872
2873 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2874 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2875 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2876 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2877
2878 // NOTE: Output is always Float32 regardless of input type
2879 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2880 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2881 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2882 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002883
2884 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2885 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002886 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002887 "must be positive and less than or equal to 1.");
2888 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002889
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002890 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2891 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002892 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002893 "should be equal to number of classes + 1.");
2894 }
2895}
2896
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002897void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2898{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002899 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002900
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002901 ValidateNumInputs(workloadInfo, descriptorName, 1);
2902 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2903
2904 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2905 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2906
Teresa Charlin07307f32022-05-15 14:07:05 +01002907 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002908 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002909 DataType::QAsymmS8,
2910 DataType::QAsymmU8,
2911 DataType::QSymmS8,
2912 DataType::QSymmS16,
2913 DataType::Float16
2914 };
2915 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002916
Teresa Charlin07307f32022-05-15 14:07:05 +01002917 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002918 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002919 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002920 DataType::Float32,
2921 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002922 };
2923
Teresa Charlin07307f32022-05-15 14:07:05 +01002924 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002925}
2926
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002927void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2928{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002929 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002930
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002931 ValidateNumInputs(workloadInfo, descriptorName, 2);
2932 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002933
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002934 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2935 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2936 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002937
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002938 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2939 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2940
2941 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2942 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002943}
2944
Keith Davis3ae3f972021-05-21 16:33:48 +01002945void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2946{
2947 const std::string& descriptorName{"ShapeQueueDescriptor"};
2948
2949 ValidateNumInputs(workloadInfo, descriptorName, 1);
2950 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2951
2952 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2953 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2954
2955 std::vector<DataType> supportedTypes =
2956 {
2957 DataType::BFloat16,
2958 DataType::Float16,
2959 DataType::Float32,
2960 DataType::QAsymmS8,
2961 DataType::QAsymmU8,
Keith Davis3ae3f972021-05-21 16:33:48 +01002962 DataType::QSymmS8,
2963 DataType::QSymmS16,
2964 DataType::Signed32
2965 };
2966
2967 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2968 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2969}
2970
Sadik Armaganeff363d2019-04-05 15:25:46 +01002971void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2972{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002973 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002974
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002975 ValidateNumInputs(workloadInfo, descriptorName, 2);
2976 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2977
2978 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2979 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2980
2981 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2982 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2983
2984 std::vector<DataType> supportedTypes =
2985 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002986 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002987 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002988 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002989 DataType::QAsymmU8,
2990 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002991 };
2992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002993 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2994 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002996 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2997 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002998
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002999 ValidateTensorShapesMatch(inputTensorInfo0,
3000 outputTensorInfo0,
3001 descriptorName,
3002 "input_0",
3003 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003004
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003005 ValidateTensorShapesMatch(inputTensorInfo0,
3006 outputTensorInfo1,
3007 descriptorName,
3008 "input_0",
3009 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003010}
3011
Derek Lamberti901ea112019-12-10 22:07:09 +00003012void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00003013{
Teresa Charlin9145e382023-08-17 18:44:58 +01003014 // This is internally generated, so it should not need validation.
Matteo Martincigh49124022019-01-11 13:25:59 +00003015}
3016
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003017void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3018{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003019 const std::string& descriptorName{"PreluQueueDescriptor"};
3020
3021 ValidateNumInputs(workloadInfo, descriptorName, 2);
3022 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3023
3024 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3025 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
3026 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003027
3028 std::vector<DataType> supportedTypes
3029 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003030 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003031 DataType::Float16,
3032 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003033 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003034 DataType::QAsymmU8,
3035 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003036 };
3037
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003038 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3039 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003040
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003041 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003042
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003043 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
3044 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003046 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3047 alphaTensorInfo,
3048 outputTensorInfo,
3049 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003050 "input",
3051 "alpha");
3052}
3053
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003054void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3055{
3056 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3057
3058 ValidateNumInputs(workloadInfo, descriptorName, 1);
3059 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3060
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003061 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3062 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3063
3064 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3065 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003066
3067 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003068
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003069 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3070 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003071
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003072 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3073
3074 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003075 if (m_Parameters.m_BiasEnabled)
3076 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003077 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003078
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003079 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3080 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003081
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003082 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01003083 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003084 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003085
3086 ValidatePerAxisQuantization(inputTensorInfo,
3087 outputTensorInfo,
3088 weightTensorInfo,
3089 optionalBiasTensorInfo,
3090 descriptorName);
3091
3092 std::vector<DataType> supportedTypes =
3093 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003094 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003095 DataType::Float32,
3096 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003097 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003098 DataType::QAsymmU8,
3099 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003100 };
3101
3102 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3103 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003104}
3105
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003106void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3107{
3108 const std::string descriptorName{"TransposeQueueDescriptor"};
3109
3110 ValidateNumInputs(workloadInfo, descriptorName, 1);
3111 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3112
3113 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3114
3115 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3116 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3117
3118 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3119 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3120
3121 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3122 {
3123 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3124 {
3125 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3126 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3127 "must match dst dimension " + to_string(i) +
3128 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3129 }
3130 }
3131
3132 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3133}
3134
Simon Obute51f67772021-09-03 15:50:13 +01003135void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3136{
3137 const std::string descriptorName{"TransposeQueueDescriptor"};
3138
3139 ValidateNumInputs(workloadInfo, descriptorName, 1);
3140 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3141
3142 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3143 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3144
3145 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3146}
3147
James Conroy4f1f8992020-04-29 20:01:10 +01003148void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3149{
3150 const std::string descriptorName{"QLstmQueueDescriptor"};
3151
3152 // Validate number of inputs/outputs
3153 ValidateNumInputs(workloadInfo, descriptorName, 3);
3154 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3155
3156 // Input/output tensor info
3157 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3158 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3159 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3160
3161 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3162 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3163 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3164
3165 // Supported types for various tensors in QLSTM
3166 std::vector<DataType> inputOutputSupportedTypes =
3167 {
3168 DataType::QAsymmS8
3169 };
3170
3171 std::vector<DataType> cellStateSupportedTypes =
3172 {
3173 DataType::QSymmS16
3174 };
3175
3176 std::vector<DataType> weightsSupportedTypes =
3177 {
3178 DataType::QSymmS8
3179 };
3180
3181 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3182 {
3183 DataType::QSymmS16
3184 };
3185
3186 std::vector<DataType> biasSupportedTypes =
3187 {
3188 DataType::Signed32
3189 };
3190
3191 // Validate types of input/output tensors
3192 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3193 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3194 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3195
3196 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3197 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3198 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3199
3200 // Validate matching types of input/output tensors
3201 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3202 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3203 "outputStateIn", "outputStateOut");
3204 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3205
3206 // Infer number of batches, number of units, input size and output size from tensor dimensions
3207 const uint32_t numBatches = inputInfo.GetShape()[0];
3208 const uint32_t inputSize = inputInfo.GetShape()[1];
3209 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3210 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3211
3212 // Validate number of dimensions and number of elements for input/output tensors
3213 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3214 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3215 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3216
3217 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3218 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3219 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3220
3221 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3222 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3223 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3224 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3225
3226 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3227 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3228 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3229
3230 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3231 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3232 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3233
3234 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3235 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3236 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3237 " RecurrentToForgetWeights");
3238
3239 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3240 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3241 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3242
3243 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3244 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3245 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3246
3247 // Validate data types for MANDATORY weights tensors (all should match each other)
3248 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3249
3250 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3251 "inputToForgetWeights", "inputToCellWeights");
3252 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3253 "inputToForgetWeights", "inputToOutputWeights");
3254
3255 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3256 "inputToForgetWeights", "recurrentToForgeteights");
3257 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3258 "inputToForgetWeights", "recurrentToCellWeights");
3259 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3260 "inputToForgetWeights", "recurrentToOutputWeights");
3261
3262 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3263 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3264 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3265 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3266
3267 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3268 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3269 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3270
3271 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3272 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3273 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3274
3275 // Validate data types for MANDATORY bias tensors
3276 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3277
3278 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3279 "forgetGateBias", "cellBias");
3280 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3281 "forgetGateBias", "outputGateBias");
3282
3283 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3284 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3285 !m_Parameters.m_CifgEnabled) ||
3286 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3287 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3288
3289 if (!allCifgParamsPresentOrNot)
3290 {
3291 throw InvalidArgumentException(descriptorName +
3292 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3293 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3294 "set appropriately.");
3295 }
3296
3297 if (!m_Parameters.m_CifgEnabled)
3298 {
3299 // Validate number of dimensions and number of elements
3300 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3301 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3302
3303 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3304 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3305 " RecurrentToInputWeights");
3306
3307 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3308 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3309
3310 // Validate data types
3311 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3312 "inputToForgetWeights", "inputToInputWeights");
3313 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3314 "inputToForgetWeights", "recurrentToInputWeights");
3315 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3316 "forgetGateBias", "inputGateBias");
3317 }
3318
3319 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3320 bool allPeepholeWeightsPresentOrNot =
3321 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3322 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3323 || (!m_CellToInputWeights && !m_CellToForgetWeights
3324 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3325
3326 if (!allPeepholeWeightsPresentOrNot)
3327 {
3328 throw InvalidArgumentException(descriptorName +
3329 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3330 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3331 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3332 "appropriately.");
3333 }
3334
3335 if (m_Parameters.m_PeepholeEnabled)
3336 {
3337 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3338 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3339 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3340
3341 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3342 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3343 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3344 "cellToForgetWeight", "cellToOutputWeights");
3345
3346 if (!m_Parameters.m_CifgEnabled)
3347 {
3348 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3349 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3350 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3351 "cellToForgetWeights", "cellToInputWeights");
3352 }
3353 }
3354
3355 // Validate OPTIONAL params: Layer Norm Weights
3356 bool allLayerNormWeightsPresentOrNot =
3357 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3358 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3359 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3360 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3361
3362 if (!allLayerNormWeightsPresentOrNot)
3363 {
3364 throw InvalidArgumentException(descriptorName +
3365 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3366 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3367 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3368 "only be present when Layer Norm is enabled and CIFG is disabled. "
3369 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3370 }
3371
3372 if (m_Parameters.m_LayerNormEnabled)
3373 {
3374 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3375 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3376 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3377
3378 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3379 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3380 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3381 "forgetLayerNormWeights", "cellLayerNormWeights");
3382
3383 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3384 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3385 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3386 "forgetLayerNormWeights", "outputLayerNormWeights");
3387
3388 if (!m_Parameters.m_CifgEnabled)
3389 {
3390 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3391 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3392 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3393 "forgetLayerNormWeights", "inputLayerNormWeights");
3394 }
3395 }
3396
3397 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3398 bool correctProjectionTensorsPresent =
3399 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3400 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3401 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3402
3403 if (!correctProjectionTensorsPresent)
3404 {
3405 throw InvalidArgumentException(descriptorName +
3406 ": If projection is enabled, ProjectionWeights should be present and "
3407 "ProjectionBias is optional. If projection is disabled, neither "
3408 "ProjectionWeights nor ProjectionBias should be present.");
3409 }
3410
3411 if (m_Parameters.m_ProjectionEnabled)
3412 {
3413 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3414 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3415 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3416
3417 if (m_ProjectionBias)
3418 {
3419 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003420 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003421 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3422 }
3423
3424 }
3425 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3426 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3427 throw InvalidArgumentException(descriptorName +
3428 ": If projection is disabled, output quantization info (scale, offset) "
3429 "should match HiddenStateScale and HiddenStateZeroPoint.");
3430 }
3431
3432}
3433
James Conroy9c3cae82019-08-01 16:01:48 +01003434void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3435{
3436 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3437
3438 // Validate number of inputs/outputs
3439 ValidateNumInputs(workloadInfo, descriptorName, 3);
3440 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3441
3442 // Input/output tensor infos
3443 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3444 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3445 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3446
3447 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3448 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3449
3450 std::vector<DataType> inputOutputSupportedTypes =
3451 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003452 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003453 };
3454
3455 std::vector<DataType> cellStateSupportedTypes =
3456 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003457 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003458 };
3459
3460 std::vector<DataType> weightsSupportedTypes =
3461 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003462 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003463 };
3464
3465 std::vector<DataType> biasSupportedTypes =
3466 {
3467 DataType::Signed32
3468 };
3469
3470 // Validate types of input/output tensors
3471 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3472 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3473 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3474
3475 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3476 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3477
3478 // Validate matching types of input/output tensors
3479 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3480 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3481 "outputStateIn", "outputStateOut");
3482 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3483
3484 // Validate matching quantization info for input/output tensors
3485 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3486 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3487 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003488
James Conroy9c3cae82019-08-01 16:01:48 +01003489 // Infer number of batches, input size and output size from tensor dimensions
3490 const uint32_t numBatches = inputInfo.GetShape()[0];
3491 const uint32_t inputSize = inputInfo.GetShape()[1];
3492 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3493
3494 // Validate number of dimensions and number of elements for input/output tensors
3495 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3496 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3497 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3498 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3499 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3500
3501 // Validate number of dimensions and number of elements for weights tensors
3502 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3503 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3504 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3505
3506 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3507 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3508 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3509
3510 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3511 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3512 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3513
3514 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3515 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3516 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3517
3518 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3519 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3520 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3521
3522 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3523 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3524 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3525 " RecurrentToForgetWeights");
3526
3527 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3528 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3529 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3530
3531 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3532 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3533 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3534
3535 // Validate data types for weights tensors (all should match each other)
3536 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3537
3538 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3539 "inputToInputWeights", "inputToForgetWeights");
3540 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3541 "inputToInputWeights", "inputToCellWeights");
3542 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3543 "inputToInputWeights", "inputToOutputWeights");
3544
3545 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3546 "inputToInputWeights", "recurrentToInputWeights");
3547 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3548 "inputToInputWeights", "recurrentToForgeteights");
3549 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3550 "inputToInputWeights", "recurrentToCellWeights");
3551 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3552 "inputToInputWeights", "recurrentToOutputWeights");
3553
3554 // Validate matching quantization info for weight tensors (all should match each other)
3555 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3556 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3557 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3558 descriptorName, "inputToInputWeights", "inputToCellWeights");
3559 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3560 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3561
3562 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3563 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3564 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3565 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3566 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3567 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3568 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3569 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3570
3571 // Validate number of dimensions and number of elements in bias tensors
3572 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3573 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3574 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3575
3576 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3577 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3578 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3579
3580 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3581 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3582 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3583
3584 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3585 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3586 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3587
3588 // Validate data types for bias tensors (all should match each other)
3589 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3590
3591 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3592 "inputGateBias", "forgetGateBias");
3593 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3594 "inputGateBias", "cellBias");
3595 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3596 "inputGateBias", "outputGateBias");
3597
3598 // Validate bias tensor quantization info
Ryan OSheaf183acd2023-07-06 11:41:25 +01003599 ValidateBiasTensorQuantization(inputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3600 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3601 ValidateBiasTensorQuantization(cellBiasInfo, inputToInputWeightsInfo, descriptorName);
3602 ValidateBiasTensorQuantization(outputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
James Conroy9c3cae82019-08-01 16:01:48 +01003603}
3604
Kevin May868eb142019-09-04 17:29:31 +01003605void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3606{
3607 const std::string descriptorName{"AbsQueueDescriptor"};
3608
3609 ValidateNumInputs(workloadInfo, descriptorName, 1);
3610 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3611
3612 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3613 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3614
3615 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3616
3617 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003618 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003619 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003620 DataType::Float16,
3621 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003623 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003624 DataType::QSymmS16,
3625 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003626 };
Kevin May868eb142019-09-04 17:29:31 +01003627
3628 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3629 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3630}
3631
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003632void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3633{
3634 const std::string descriptorName{"SliceQueueDescriptor"};
3635
3636 ValidateNumInputs(workloadInfo, descriptorName, 1);
3637 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3638
3639 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3640 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3641
3642 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3643
3644 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3645 if (rank > 4)
3646 {
3647 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3648 }
3649
3650 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3651
3652 // Check if m_Begin and m_Size have the expected length
3653 if (m_Parameters.m_Begin.size() != rank)
3654 {
3655 throw InvalidArgumentException(descriptorName +
3656 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3657 }
3658 if (m_Parameters.m_Size.size() != rank)
3659 {
3660 throw InvalidArgumentException(descriptorName +
3661 ": Length of size descriptor must equal rank " + std::to_string(rank));
3662 }
3663
3664 // Check if the shape of the output tensor matches m_Size
3665 const TensorShape& outputShape = outputTensorInfo.GetShape();
3666 for (unsigned int i = 0u; i < rank; ++i)
3667 {
3668 if (m_Parameters.m_Size[i] != outputShape[i])
3669 {
3670 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3671 }
3672 }
3673
3674 // Check if the sum of begin offset and size in a given dimension
3675 // does not exceed the size of corresponding input
3676 const TensorShape& inputShape = inputTensorInfo.GetShape();
3677 for(unsigned int i = 0u; i < rank; ++i)
3678 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003679 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003680 {
3681 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3682 std::to_string(i) + " exceeds input size.");
3683 }
3684 }
3685}
3686
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003687void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3688{
3689 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3690
3691 ValidateNumInputs(workloadInfo, descriptorName, 1);
3692 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3693
3694 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3695 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3696
3697 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3698 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3699
3700 std::vector<DataType> supportedTypes =
3701 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003702 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003703 DataType::Float32,
3704 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003705 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003706 DataType::QAsymmU8,
3707 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003708 };
3709
3710 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3711 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3712
3713 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3714
3715 if (m_Parameters.m_BlockSize == 0)
3716 {
3717 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3718 }
3719
3720 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3721 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3722 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3723 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3724
3725 const TensorShape& outputShape = outputInfo.GetShape();
3726 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3727 {
3728 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3729 "must be divisible by block size.");
3730 }
3731
3732 const TensorShape& inputShape = inputInfo.GetShape();
3733 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3734 {
3735 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3736 "must be divisible by the square of block size." );
3737 }
3738}
3739
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003740void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3741{
3742 const std::string descriptorName{"ComparisonQueueDescriptor"};
3743
3744 ValidateNumInputs(workloadInfo, descriptorName, 2);
3745 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3746
3747 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3748 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3749 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3750
3751 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3752 inputTensorInfo1,
3753 outputTensorInfo,
3754 descriptorName,
3755 "input_0",
3756 "input_1");
3757
3758 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3759 {
3760 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3761 }
3762}
3763
Mike Kelly3ec30772023-03-08 13:47:17 +00003764void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3765{
3766 const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3767
3768 ValidateNumInputs(workloadInfo, descriptorName, 2);
3769 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3770
3771 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3772 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3773 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3774
3775 std::vector<DataType> supportedTypes =
3776 {
3777 DataType::BFloat16,
3778 DataType::Float16,
3779 DataType::Float32,
3780 DataType::QAsymmS8,
3781 DataType::QAsymmU8,
3782 DataType::QSymmS16,
3783 DataType::Signed32
3784 };
3785
3786 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3787 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3788
3789 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3790 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3791}
3792
josh minor4a3c6102020-01-06 16:40:46 -06003793void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3794{
3795 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3796
3797 ValidateNumInputs(workloadInfo, descriptorName, 1);
3798 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3799
3800 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3801 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3802
3803 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3804
3805 std::vector<DataType> supportedTypes =
3806 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003807 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003808 DataType::Float16,
3809 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003810 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003811 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003812 DataType::QSymmS16,
3813 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003814 };
3815
James Conroyaba90cd2020-11-06 16:28:18 +00003816 std::vector<DataType> logicalSupportedTypes =
3817 {
3818 DataType::Boolean
3819 };
3820
3821 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3822 {
3823 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3824 }
3825 else
3826 {
3827 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3828 }
3829
3830
josh minor4a3c6102020-01-06 16:40:46 -06003831 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3832}
3833
Finn Williams2605b232020-06-10 15:53:46 +01003834void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3835{
3836 const std::string descriptorName{"RankQueueDescriptor"};
3837
3838 ValidateNumInputs(workloadInfo, descriptorName, 1);
3839 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3840
3841 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3842 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3843
3844 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3845 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3846
3847 std::vector<DataType> supportedTypes =
3848 {
3849 DataType::BFloat16,
3850 DataType::Float16,
3851 DataType::Float32,
3852 DataType::QAsymmS8,
3853 DataType::QAsymmU8,
3854 DataType::QSymmS8,
3855 DataType::QSymmS16,
3856 DataType::Signed32
3857 };
3858
3859 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3860 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3861}
3862
James Conroyaba90cd2020-11-06 16:28:18 +00003863void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3864{
3865 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3866
3867 ValidateNumInputs(workloadInfo, descriptorName, 2);
3868 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3869
3870 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3871 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3872 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3873
3874 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3875 inputTensorInfo1,
3876 outputTensorInfo,
3877 descriptorName,
3878 "input_0",
3879 "input_1");
3880
3881 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3882 {
3883 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3884 }
3885
3886 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3887 {
3888 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3889 }
3890
3891 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3892 {
3893 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3894 }
3895}
3896
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003897void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3898{
3899 const std::string descriptorName{"ReduceQueueDescriptor"};
3900
3901 ValidateNumInputs(workloadInfo, descriptorName, 1);
3902 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3903
3904 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3905 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3906
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003907 std::vector<DataType> supportedTypes =
3908 {
3909 DataType::BFloat16,
3910 DataType::Float16,
3911 DataType::Float32,
3912 DataType::QAsymmS8,
3913 DataType::QAsymmU8,
3914 DataType::QSymmS16,
3915 DataType::Signed32
3916 };
3917
3918 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3919 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3920}
3921
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003922void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3923{
3924 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3925
3926 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3927
3928 // check dimensions of all inputs and outputs
3929 if (workloadInfo.m_InputTensorInfos.size() != 3)
3930 {
3931 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3932 }
Mike Kelly12994962022-04-21 11:57:09 +01003933 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003934 {
3935 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3936 }
3937
3938 std::vector<DataType> supportedTypes =
3939 {
Mike Kelly12994962022-04-21 11:57:09 +01003940 DataType::Float32,
3941 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003942 };
3943
3944 // check for supported type of one input and match them with all the other input and output
3945 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3946
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003947 // Making sure clipping parameters have valid values.
3948 // == 0 means no clipping
3949 // > 0 means clipping
3950 if (m_Parameters.m_ClippingThresCell < 0.0f)
3951 {
3952 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3953 }
3954 if (m_Parameters.m_ClippingThresProj < 0.0f)
3955 {
3956 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3957 }
3958
3959 unsigned int batchIndx = 0;
3960 unsigned int inputIndx = 1;
3961 uint32_t timeStep = 1;
3962 unsigned int timeIndx = 1;
3963 inputIndx = 2;
3964 if (m_Parameters.m_TimeMajor)
3965 {
3966 batchIndx = 1;
3967 timeIndx = 0;
3968
3969 }
3970 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3971
3972 // Inferring batch size, number of outputs and number of cells from the inputs.
3973 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3974 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3975 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3976 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3977 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3978 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3979
3980 // input tensor
3981 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3982 descriptorName + " input_0");
3983 // outputStateInTensor
3984 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3985 descriptorName + " input_1");
3986 // outputStateInTensor
3987 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3988 descriptorName + " input_2");
3989
3990 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003991 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003992 descriptorName + " output_0");
3993
3994 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3995 if ( m_InputToInputWeights )
3996 {
3997 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3998 (n_cell * n_input), "InputLayerNormWeights");
3999 }
4000
4001 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
4002 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
4003 (n_cell * n_input), "InputToForgetWeights");
4004
4005 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
4006 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
4007 (n_cell * n_input), "InputToCellWeights");
4008
4009 if ( m_RecurrentToInputWeights )
4010 {
4011 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
4012 (n_cell * n_output), "RecurrentToInputWeights");
4013 }
4014
4015 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
4016 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
4017 (n_cell * n_output), "RecurrentToForgetWeights");
4018
4019 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
4020 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
4021 (n_cell * n_output), "RecurrentToCellWeights");
4022
4023 // Make sure the input-gate's parameters are either both present (regular
4024 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
4025 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
4026 !m_Parameters.m_CifgEnabled) ||
4027 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
4028 m_Parameters.m_CifgEnabled));
4029 if (!cifg_weights_all_or_none)
4030 {
4031 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
4032 "RecurrentToInputWeights must either both be present (regular LSTM) "
4033 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
4034 "accordingly.");
4035 }
4036
4037 if ( m_CellToInputWeights )
4038 {
4039 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
4040 n_cell, "CellToInputWeights");
4041 }
4042 if ( m_CellToForgetWeights )
4043 {
4044 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
4045 n_cell, "CellToForgetWeights");
4046 }
4047 if ( m_CellToOutputWeights )
4048 {
4049 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
4050 n_cell, "CellToOutputWeights");
4051 }
4052
4053 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
4054 bool peephole_weights_all_or_none =
4055 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
4056 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
4057 || ( !m_CellToInputWeights && !m_CellToForgetWeights
4058 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
4059 if (!peephole_weights_all_or_none)
4060 {
4061 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
4062 }
4063
4064 // Make sure the input gate bias is present only when not a CIFG-LSTM.
4065 if (m_Parameters.m_CifgEnabled)
4066 {
4067 if (m_InputGateBias)
4068 {
4069 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
4070 }
4071 }
4072 else
4073 {
4074 if (!m_InputGateBias)
4075 {
4076 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4077 "must be present.");
4078 }
4079 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4080 n_cell, "InputGateBias");
4081 }
4082
4083 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4084 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4085
4086 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4087 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4088
4089 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4090 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4091
4092 if (m_ProjectionWeights)
4093 {
4094 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4095 (n_cell * n_output), "ProjectionWeights");
4096 }
4097 if (m_ProjectionBias)
4098 {
4099 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4100 }
4101
4102 // Making sure the projection tensors are consistent:
4103 // 1) If projection weight is not present, then projection bias should not be
4104 // present.
4105 // 2) If projection weight is present, then projection bias is optional.
4106 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4107 !m_Parameters.m_ProjectionEnabled)
4108 || (m_ProjectionWeights && !m_ProjectionBias &&
4109 m_Parameters.m_ProjectionEnabled)
4110 || (m_ProjectionWeights && m_ProjectionBias &&
4111 m_Parameters.m_ProjectionEnabled));
4112 if (!projecton_tensors_consistent)
4113 {
4114 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4115 }
4116
4117 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4118 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4119 // either all have values or none of them have values. Layer normalization is used when the values of all the
4120 // layer normalization weights are present
4121 if (m_InputLayerNormWeights)
4122 {
4123 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4124 }
4125 if (m_ForgetLayerNormWeights)
4126 {
4127 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4128 }
4129 if (m_CellLayerNormWeights)
4130 {
4131 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4132 }
4133 if (m_OutputLayerNormWeights)
4134 {
4135 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4136 }
4137
4138 if (m_Parameters.m_LayerNormEnabled)
4139 {
4140 if (!m_Parameters.m_CifgEnabled)
4141 {
4142 if (!m_InputLayerNormWeights)
4143 {
4144 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4145 "disabled but InputLayerNormWeights are not present");
4146 }
4147 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4148 1, n_cell, "InputLayerNormWeights");
4149 }
4150 else if (m_InputLayerNormWeights)
4151 {
4152 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4153 "enabled");
4154 }
4155
4156 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4157 "ForgetLayerNormWeights");
4158 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4159
4160 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4161 "OutputLayerNormWeights");
4162 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4163
4164 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4165 "CellLayerNormWeights");
4166 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4167 }
4168 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4169 {
4170 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4171 "normalisation weights are present.");
4172 }
4173}
4174
Samuel Yap6b478092022-07-06 15:36:03 +01004175void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4176{
4177 const std::string descriptorName{"BatchMatMulDescriptor"};
4178
4179 ValidateNumInputs(workloadInfo, descriptorName, 2);
4180 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4181
4182 // Inputs must be: both 2D+
4183 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4184 // axes N and I must be the same size
4185
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004186 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4187 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4188 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4189 // Output info has already been inferred
Samuel Yap6b478092022-07-06 15:36:03 +01004190
4191 std::vector<DataType> supportedTypes =
4192 {
4193 DataType::BFloat16,
4194 DataType::Float16,
4195 DataType::Float32,
4196 DataType::QAsymmS8,
4197 DataType::QAsymmU8,
4198 DataType::QSymmS16
4199 };
4200
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004201 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4202 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4203 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
Samuel Yap6b478092022-07-06 15:36:03 +01004204
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004205 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4206 (inputYInfoBeforeParams.GetNumDimensions() < 2))
Samuel Yap6b478092022-07-06 15:36:03 +01004207 {
4208 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4209 }
4210
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004211 TensorInfo inputXInfoAfterParams;
4212 TensorInfo inputYInfoAfterParams;
4213
4214 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4215 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
Samuel Yap6b478092022-07-06 15:36:03 +01004216 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004217 throw InvalidArgumentException(descriptorName +
4218 ": Invalid descriptor parameters - Transpose and Adjoint "
4219 "cannot both be true for a given input tensor.");
4220 }
4221 if(m_Parameters.m_TransposeX)
4222 {
4223 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4224 BatchMatMulDescriptor::GetPermuteVec(
4225 m_Parameters.m_DataLayoutX,
4226 inputXInfoBeforeParams.GetShape()));
4227 }
4228 else if(m_Parameters.m_AdjointX)
4229 {
4230 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4231 inputXInfoBeforeParams.GetShape());
4232 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4233 inputXInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004234 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004235 throw InvalidArgumentException(descriptorName +
4236 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004237 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004238 // Shape remains the same as it's square
4239 inputXInfoAfterParams = inputXInfoBeforeParams;
4240 }
4241 else
4242 {
4243 inputXInfoAfterParams = inputXInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004244 }
4245
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004246 if(m_Parameters.m_TransposeY)
Samuel Yap6b478092022-07-06 15:36:03 +01004247 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004248 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4249 BatchMatMulDescriptor::GetPermuteVec(
4250 m_Parameters.m_DataLayoutY,
4251 inputYInfoBeforeParams.GetShape()));
4252 }
4253 else if(m_Parameters.m_AdjointY)
4254 {
4255 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4256 inputYInfoBeforeParams.GetShape());
4257 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4258 inputYInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004259 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004260 throw InvalidArgumentException(descriptorName +
4261 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004262 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004263 // Shape remains the same as it's square
4264 inputYInfoAfterParams = inputYInfoBeforeParams;
4265 }
4266 else
4267 {
4268 inputYInfoAfterParams = inputYInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004269 }
4270
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004271 switch(m_Parameters.m_DataLayoutX)
4272 {
4273 case DataLayout::NCDHW:
4274 case DataLayout::NDHWC:
4275 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4276 {
4277 throw InvalidArgumentException(descriptorName +
4278 ": Input tensor X does not have the correct "
4279 "number of dimensions for the Data Layout that it has been assigned.");
4280 }
4281 break;
4282 case DataLayout::NCHW:
4283 case DataLayout::NHWC:
4284 default:
4285 break;
4286 }
Samuel Yap6b478092022-07-06 15:36:03 +01004287
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004288 switch(m_Parameters.m_DataLayoutY)
4289 {
4290 case DataLayout::NCDHW:
4291 case DataLayout::NDHWC:
4292 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4293 {
4294 throw InvalidArgumentException(descriptorName +
4295 ": Input tensor Y does not have the correct "
4296 "number of dimensions for the Data Layout that it has been assigned.");
4297 }
4298 break;
4299 case DataLayout::NCHW:
4300 case DataLayout::NHWC:
4301 default:
4302 break;
4303 }
4304
4305 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4306 inputXInfoAfterParams.GetShape());
4307 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4308 inputXInfoBeforeParams.GetShape());
4309
4310 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4311 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
Samuel Yap6b478092022-07-06 15:36:03 +01004312 {
4313 throw InvalidArgumentException(descriptorName +
4314 ": The final axis of input tensor X must be the same size as "
4315 "the second last axis of input tensor Y.");
4316 }
4317
Samuel Yap6b478092022-07-06 15:36:03 +01004318 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4319 // e.g. NHWC isnt compatible with NCHW as of now
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004320 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4321 DataLayout yLayout = m_Parameters.m_DataLayoutY;
Samuel Yap6b478092022-07-06 15:36:03 +01004322
4323 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4324 {
4325 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4326 {
4327 throw InvalidArgumentException(descriptorName +
4328 ": Invalid input tensor data layout combination.");
4329 }
4330 }
4331 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4332 {
4333 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4334 {
4335 throw InvalidArgumentException(descriptorName +
4336 ": Invalid input tensor data layout combination.");
4337 }
4338 }
4339 }
4340
4341 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004342 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4343 inputYInfoAfterParams.GetNumDimensions());
Samuel Yap6b478092022-07-06 15:36:03 +01004344 if(outputTensorDimSize-2 > 0)
4345 {
4346 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4347 DataType::Float32);
4348 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4349 DataType::Float32);
4350 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4351 DataType::Float32);
4352
4353 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4354 {
4355 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4356
4357 for(unsigned int i = 0; i < sizeDiff; i++)
4358 {
4359 axisIndices.insert(axisIndices.begin(), 1);
4360 }
4361
4362 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4363 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004364 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
Samuel Yap6b478092022-07-06 15:36:03 +01004365 }
4366 };
4367
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004368 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4369 inputXInfoAfterParams.GetShape());
4370 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4371 inputYInfoAfterParams.GetShape());
4372
4373 doAxisExtension(axesXNotMul, tiXNotMul);
4374 doAxisExtension(axesYNotMul, tiYNotMul);
Samuel Yap6b478092022-07-06 15:36:03 +01004375
4376 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4377 {
4378 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4379 tiYNotMul.GetShape()[i]);
4380 }
4381
4382 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4383 tiYNotMul,
4384 tiOutNotMul,
4385 descriptorName,
4386 "input_X",
4387 "input_Y");
4388 }
Samuel Yap6b478092022-07-06 15:36:03 +01004389}
4390
Teresa Charlin79a06a52023-07-13 17:16:45 +01004391void TileQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4392{
4393 const std::string& descriptorName{"TileQueueDescriptor"};
4394
4395 ValidateNumInputs(workloadInfo, descriptorName, 1);
4396 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4397
4398 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
4399 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4400
4401 std::vector<DataType> supportedTypes =
4402 {
4403 DataType::Float32,
4404 DataType::Float16,
4405 DataType::QAsymmS8,
4406 DataType::QAsymmU8,
4407 DataType::QSymmS8,
4408 DataType::QSymmS16,
4409 DataType::Signed32
4410 };
4411
4412 // Multiples length must be the same as the number of dimensions in input.
4413 if (m_Parameters.m_Multiples.size() != inputTensorInfo.GetNumDimensions())
4414 {
4415 throw InvalidArgumentException(descriptorName +
4416 ": Multiples length is not same as the number of dimensions in Input.");
4417 }
4418
4419 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
4420 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4421}
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01004422
Idriss Chaouch98e383e2023-08-28 14:28:31 +01004423void BroadcastToQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4424{
4425 const std::string& descriptorName{"BroadcastToQueueDescriptor"};
4426
4427 ValidateNumInputs(workloadInfo, descriptorName, 1);
4428 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4429
4430 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
4431 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4432
4433 std::vector<DataType> supportedTypes =
4434 {
4435 DataType::Float32,
4436 DataType::Float16,
4437 DataType::QAsymmS8,
4438 DataType::QAsymmU8,
4439 DataType::QSymmS8,
4440 DataType::QSymmS16,
4441 DataType::Signed32,
4442 DataType::Signed64
4443 };
4444
4445 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
4446 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4447}
4448
Tianle Cheng28288182024-02-23 17:56:54 +00004449void ScatterNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4450{
4451 const std::string& descriptorName{"ScatterQueueDescriptor"};
4452
4453 ValidateNumInputs(workloadInfo, descriptorName, 3);
4454 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4455
4456 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
4457 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
4458 const TensorInfo& inputTensorInfo2 = workloadInfo.m_InputTensorInfos[2];
4459 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4460
4461 std::vector<DataType> supportedTypes =
4462 {
4463 DataType::Float32,
4464 DataType::Float16,
4465 DataType::QAsymmS8,
4466 DataType::QAsymmU8,
4467 DataType::QSymmS8,
4468 DataType::QSymmS16,
4469 DataType::Signed32
4470 };
4471
4472 std::vector<DataType> indicesSupportedTypes =
4473 {
4474 DataType::Signed32
4475 };
4476
4477 if (m_Parameters.m_InputEnabled)
4478 {
4479 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
4480 }
4481 else
4482 {
4483 ValidateDataTypes(inputTensorInfo0, indicesSupportedTypes, descriptorName);
4484 }
4485
4486 ValidateDataTypes(inputTensorInfo1, indicesSupportedTypes, descriptorName);
4487 ValidateDataTypes(inputTensorInfo2, supportedTypes, descriptorName);
4488 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4489}
4490
mathad01df9a3222021-04-28 11:42:57 +01004491} // namespace armnn