blob: 0ddb4291f19bb3ad918ebfd80168723cf13c55e9 [file] [log] [blame]
//
// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010011#include <armnnUtils/Permute.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010013#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000014
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000017#include <string>
18#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000019
James Ward47fce872020-09-10 11:57:28 +010020#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000021
Matteo Martincigh21350152018-11-28 16:22:22 +000022using namespace armnnUtils;
23
telsoa014fcda012018-03-09 14:13:49 +000024namespace armnn
25{
26
27//---------------------------------------------------------------
28DataType GetBiasDataType(DataType inputDataType)
29{
30 switch (inputDataType)
31 {
telsoa01c577f2c2018-08-31 09:22:23 +010032 case DataType::Float16:
33 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000034 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000035 case DataType::Float32:
36 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000037 case DataType::QAsymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000038 case DataType::QAsymmU8:
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Colm Donelanb4ef1632024-02-01 15:00:43 +000043 throw InvalidArgumentException("GetBiasDataType(): Unsupported data type.");
telsoa014fcda012018-03-09 14:13:49 +000044 }
45}
46
47namespace
48{
49
//---------------------------------------------------------------
// Local replacement for std::to_string, which the Android NDK does not provide.
// Formats any streamable value via std::ostringstream.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    std::string result = formatted.str();
    return result;
}
59
60//---------------------------------------------------------------
61void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
62{
63 if (!ptr)
64 {
65 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
66 paramName + " parameter must be set.");
67 }
68}
69
70//---------------------------------------------------------------
71void ValidateTensorShapesMatch(const TensorInfo& first,
72 const TensorInfo& second,
73 std::string const& descName,
74 std::string const& firstName,
75 std::string const& secondName)
76{
77 if (first.GetShape() != second.GetShape())
78 {
79 throw InvalidArgumentException(descName + ": "
80 + firstName + " & " + secondName + " must have identical shapes");
81 }
82}
83
84//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010085void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000086{
Sadik Armaganeff363d2019-04-05 15:25:46 +010087 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088 {
89 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010090 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000091 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
92 }
93}
94
95//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010096void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000097{
Sadik Armaganeff363d2019-04-05 15:25:46 +010098 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099 {
100 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100101 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000102 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
103 }
104}
105
106//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumElements(const TensorInfo& tensor,
110 std::string const& descName,
111 unsigned int numElements,
112 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100113{
114 if (tensor.GetNumElements() != numElements)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100117 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000123void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
124 const std::string& descName, std::string const& tensorName)
125{
126 if (tensor.GetDataType() != dataType)
127 {
128 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
129 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
130 }
131}
132
Derek Lambertid466a542020-01-22 15:37:29 +0000133void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
134{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100135 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000136 {
137 throw InvalidArgumentException(descName +
138 ": Expected data type which supports per-axis quantization scheme but got " +
139 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
140 }
Derek Lambertid466a542020-01-22 15:37:29 +0000141}
142
telsoa014fcda012018-03-09 14:13:49 +0000143//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100144void ValidateTensorQuantizationSpace(const TensorInfo& first,
145 const TensorInfo& second,
146 const std::string& descName,
147 std::string const& firstName,
148 std::string const& secondName)
149{
150 if (!first.IsQuantized() ||
151 !second.IsQuantized())
152 {
153 // Not a quantized type, ignore the validation
154 return;
155 }
156
157 DataType firstDataType = first.GetDataType();
158 DataType secondDataType = second.GetDataType();
159
160 if (firstDataType != secondDataType)
161 {
162 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
163 " must be of the same quantized type, " +
164 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
165 secondName + " is " + GetDataTypeName(secondDataType));
166 }
167
168 if (!first.IsTypeSpaceMatch(second))
169 {
170 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
171 " must have the same quantization space, " +
172 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
173 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
174 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
175 " and scale " + to_string(second.GetQuantizationScale()));
176 }
177}
178
179//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100180void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100181 const TensorInfo& weightsTensorInfo,
182 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000183{
184 if (biasTensor.GetQuantizationOffset() != 0)
185 {
186 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
187 to_string(biasTensor.GetQuantizationOffset()));
188 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000189
James Conroy8502ade2020-11-12 19:26:29 +0000190 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000191 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000192 // Validate per-axis quantization scales
193 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
194 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
195
196 if (weightScales.size() != biasScales.size())
197 {
198 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000199 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
200 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
201 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000202 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
203 }
telsoa014fcda012018-03-09 14:13:49 +0000204 }
205}
206
207//---------------------------------------------------------------
208void ValidateTensors(const std::vector<ITensorHandle*>& vec,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100209 unsigned int numExpected,
210 const std::string& descName,
211 const std::string& varName)
telsoa014fcda012018-03-09 14:13:49 +0000212{
213 if (vec.empty() && numExpected > 0)
214 {
215 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
216 }
217
218 for (unsigned int i = 0; i < numExpected; ++i)
219 {
220 if (!vec[i])
221 {
222 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
223 }
224 }
225}
226
227//---------------------------------------------------------------
228void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
229 const TensorInfo& second,
230 const TensorInfo& output,
231 std::string const& descName,
232 std::string const& firstName,
233 std::string const& secondName)
234{
235 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
236 // broadcasted.
237 if (first.GetNumDimensions() != second.GetNumDimensions())
238 {
239 throw InvalidArgumentException(descName + ": Tensors "
240 + firstName + " & " + secondName
241 + " must have the same number of dimensions in order to be broadcasted");
242 }
243 uint32_t numDims = first.GetNumDimensions();
244 std::vector<uint32_t> outputDims(numDims, 0u);
245 for (uint32_t i = 0; i < numDims; i++)
246 {
247 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
248 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
249 if (dimsNotEqual && dimsNotOne)
250 {
251 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
252 }
253 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
254 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100255 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000256 if (broadcastShape != output.GetShape())
257 {
258 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
259 + firstName + " & " + secondName
260 + " does not match the output shape");
261 }
262}
263
264//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100265void ValidateDataTypes(const TensorInfo& info,
266 const std::vector<armnn::DataType>& supportedTypes,
267 std::string const& descName)
268{
269 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
270 if (iterator == supportedTypes.end())
271 {
272 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
273 }
274}
275
James Conroy4d1ff582019-06-10 17:06:39 +0100276//---------------------------------------------------------------
277void ValidateTensorDataTypesMatch(const TensorInfo& first,
278 const TensorInfo& second,
279 std::string const& descName,
280 std::string const& firstName,
281 std::string const& secondName)
282{
283 if (first.GetDataType() != second.GetDataType())
284 {
285 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
286 " must have identical data types.");
287 }
288}
289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100290//---------------------------------------------------------------
291void ValidateTensorNumElementsMatch(const TensorInfo& first,
292 const TensorInfo& second,
293 std::string const& descName,
294 std::string const& firstName,
295 std::string const& secondName)
296{
297 if (first.GetNumElements() != second.GetNumElements())
298 {
299 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
300 " must have the same number of elements.");
301 }
302}
303
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000304void ValidateWeightDataType(const TensorInfo& inputInfo,
305 const TensorInfo& weightInfo,
306 const std::string& descName)
307{
308 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000309 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000310 {
311 const std::vector<DataType> validTypes =
312 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000313 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100314 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100315 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000316 };
317
318 ValidateDataTypes(weightInfo, validTypes, descName);
319 }
320 else
321 {
322 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
323 }
324}
325
326void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
327 const std::string& descName,
328 const std::string& tensorName)
329{
330 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
331 if (!quantizationDim.has_value())
332 {
James Ward47fce872020-09-10 11:57:28 +0100333 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
334 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000335 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000336}
337
338void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
339 const std::string& descName,
340 const std::string& tensorName)
341{
342 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
343 if (quantizationOffset != 0)
344 {
James Ward47fce872020-09-10 11:57:28 +0100345 throw InvalidArgumentException(fmt::format(
346 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
347 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000348 }
349}
350
351void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
352 const TensorInfo& outputInfo,
353 const TensorInfo& weightInfo,
354 const Optional<TensorInfo>& optionalBiasInfo,
355 const std::string& descName)
356{
357 if (weightInfo.HasPerAxisQuantization())
358 {
359 const DataType inputDataType = inputInfo.GetDataType();
360 const DataType outputDataType = outputInfo.GetDataType();
361
Keith Davis0c2eeac2020-02-11 16:51:50 +0000362 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000363
364 if (!canHavePerAxisQuantization)
365 {
James Ward47fce872020-09-10 11:57:28 +0100366 throw InvalidArgumentException(fmt::format(
367 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
368 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000369 }
370
Derek Lambertid466a542020-01-22 15:37:29 +0000371
372 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000373 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
374 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
375
376 if (optionalBiasInfo.has_value())
377 {
378 const TensorInfo& biasInfo = optionalBiasInfo.value();
379 if (!biasInfo.HasPerAxisQuantization())
380 {
James Ward47fce872020-09-10 11:57:28 +0100381 throw InvalidArgumentException(fmt::format(
382 "{}: Per-axis quantization parameters not set on bias tensor, "
383 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000384 }
385
386 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
387 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
388 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
389 }
390 }
391}
392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100393} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000394
Mike Kelly80512b02022-05-16 23:10:42 +0100395//---------------------------------------------------------------
396void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
397 std::string const& descName,
398 unsigned int numDimensions,
399 std::string const& tensorName) const
400{
401 // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
402 // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
403 // numDimensions.
404 if (m_AllowExpandedDims)
405 {
406 unsigned int squeezedDims = 0;
407
408 for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
409 {
410 if (tensor.GetShape()[i] != 1)
411 {
412 ++squeezedDims;
413 }
414 }
415 if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
416 {
417 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
418 to_string(tensor.GetNumDimensions()) + " dimensions for " +
419 tensorName + " tensor.");
420 }
421 }
422 else
423 {
424 if (tensor.GetNumDimensions() != numDimensions)
425 {
426 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
427 to_string(tensor.GetNumDimensions()) + " dimensions for " +
428 tensorName + " tensor.");
429 }
430 }
431}
432
433//---------------------------------------------------------------
434void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Teresa Charlin79a06a52023-07-13 17:16:45 +0100435 unsigned int numDimension,
436 unsigned int numElements,
437 std::string const& tensorName) const
Mike Kelly80512b02022-05-16 23:10:42 +0100438{
439 const std::string functionName{"ValidateTensorNumDimNumElem"};
440 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
441 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
442}
443
444//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000445void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
446 unsigned int numExpectedIn, unsigned int numExpectedOut) const
447{
448 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
449 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
450}
451
452//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100453void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
454{
455 const std::string descriptorName{"MapQueueDescriptor"};
456
457 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100458 ValidateNumOutputs(workloadInfo, descriptorName, 0);
459
460 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
461 {
462 if (!m_Inputs[i])
463 {
464 throw InvalidArgumentException(
465 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
466 }
467 }
468}
469
470//---------------------------------------------------------------
471void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
472{
473 const std::string descriptorName{"UnmapQueueDescriptor"};
474
475 ValidateNumInputs(workloadInfo, descriptorName, 1);
476 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100477
478 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
479 {
480 if (!m_Inputs[i])
481 {
482 throw InvalidArgumentException(
483 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
484 }
485 }
486}
487
488//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000489void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
490{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100491 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000492
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100493 ValidateNumInputs(workloadInfo, descriptorName, 1);
494 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000495
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100496 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
497 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
498
499 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
500 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000501
502 if (m_Inputs.size() != m_Outputs.size())
503 {
James Ward47fce872020-09-10 11:57:28 +0100504 throw InvalidArgumentException(fmt::format(
505 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
506 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000507 }
508
509 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
510 {
511 if (!m_Inputs[i])
512 {
James Ward47fce872020-09-10 11:57:28 +0100513 throw InvalidArgumentException(fmt::format(
514 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000515 }
516
517 if (!m_Outputs[i])
518 {
James Ward47fce872020-09-10 11:57:28 +0100519 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000520 }
521 }
522}
523
Derek Lambertif674aa02019-08-01 15:56:25 +0100524//---------------------------------------------------------------
525void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
526{
527 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
528 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
529
530 if (workloadInfo.m_InputTensorInfos.size() != 1)
531 {
James Ward47fce872020-09-10 11:57:28 +0100532 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
533 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100534
535 }
536
537 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
538 {
James Ward47fce872020-09-10 11:57:28 +0100539 throw InvalidArgumentException(fmt::format(
540 "Number of input infos ({0}) does not match the number of output infos ({1})",
541 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100542 }
543
544 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
545 {
546 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
547 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
548 {
James Ward47fce872020-09-10 11:57:28 +0100549 throw InvalidArgumentException(fmt::format(
550 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100551 }
552 }
553
554 if (m_Inputs.size() != 1)
555 {
James Ward47fce872020-09-10 11:57:28 +0100556 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100557 }
558
559 if (m_Inputs.size() != m_Outputs.size())
560 {
James Ward47fce872020-09-10 11:57:28 +0100561 throw InvalidArgumentException(fmt::format(
562 "Number of inputs ({0}) does not match the number of outputs ({1})",
563 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100564 }
565
566 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
567 {
568 if (!m_Inputs[i])
569 {
James Ward47fce872020-09-10 11:57:28 +0100570 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 }
572
573 if (!m_Outputs[i])
574 {
James Ward47fce872020-09-10 11:57:28 +0100575 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100576 }
577 }
578}
579
580//---------------------------------------------------------------
581void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
582{
583 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100584
Derek Lambertif674aa02019-08-01 15:56:25 +0100585 if (m_Inputs.size() != 1)
586 {
James Ward47fce872020-09-10 11:57:28 +0100587 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100588 }
589
590 if (m_Outputs.size() != 0)
591 {
James Ward47fce872020-09-10 11:57:28 +0100592 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100593 }
594
595 if (!m_Inputs[0])
596 {
James Ward47fce872020-09-10 11:57:28 +0100597 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100598 }
599}
600
601//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000602void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
603{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100604 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100605
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 ValidateNumInputs(workloadInfo, descriptorName, 1);
607 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100609 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
610 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100611
612 std::vector<DataType> supportedTypes =
613 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000614 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100615 DataType::Float16,
616 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000617 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000618 DataType::QAsymmU8,
619 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100620 };
621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100622 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
623 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
624 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000625}
626
Nikhil Rajee391d52019-09-05 17:50:44 +0100627void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
628{
629 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
630
631 ValidateNumInputs(workloadInfo, descriptorName, 1);
632 ValidateNumOutputs(workloadInfo, descriptorName, 1);
633
634 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
635 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
636
Inki Daed4619e22020-09-10 15:33:54 +0900637 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
638 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100639 {
Inki Daed4619e22020-09-10 15:33:54 +0900640 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100641 }
642
James Conroyd47a0642019-09-17 14:22:06 +0100643 std::vector<DataType> supportedInputTypes =
644 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000645 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100646 DataType::Float16,
647 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100648 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000649 DataType::QAsymmU8,
650 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900651 DataType::Signed32,
652 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100653 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100654
James Conroyd47a0642019-09-17 14:22:06 +0100655 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100656
657 auto inputShape = inputTensorInfo.GetShape();
658 auto outputShape = outputTensorInfo.GetShape();
659
660 auto inputNumDimensions = inputShape.GetNumDimensions();
661 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
662
663 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
664
665 // 1D input shape results in scalar output shape
666 if (inputShape.GetNumDimensions() == 1)
667 {
668 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
669 {
670 throw InvalidArgumentException(descriptorName + outputShapeError);
671 }
672 }
673 else
674 {
675 for (unsigned int i = 0; i < unsignedAxis; ++i)
676 {
677 if (outputShape[i] != inputShape[i])
678 {
679 throw InvalidArgumentException(descriptorName + outputShapeError);
680 }
681 }
682
683 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
684 {
685 if (outputShape[i - 1] != inputShape[i])
686 {
687 throw InvalidArgumentException(descriptorName + outputShapeError);
688 }
689 }
690 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100691}
692
mathad01b392e982021-04-07 12:07:30 +0100693void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
694{
695 const std::string descriptorName{"CastQueueDescriptor"};
696
697 ValidateNumInputs(workloadInfo, descriptorName, 1);
698 ValidateNumOutputs(workloadInfo, descriptorName, 1);
699
700 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
701 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
702
703 std::vector<DataType> supportedTypes =
704 {
705 DataType::BFloat16,
706 DataType::Float16,
707 DataType::Float32,
708 DataType::QAsymmS8,
709 DataType::QAsymmU8,
710 DataType::QSymmS8,
711 DataType::QSymmS16,
712 DataType::Signed32,
713 DataType::Signed64
714 };
715
716 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
717 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
718}
719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100720void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
721{
722 const std::string descriptorName{"SoftmaxQueueDescriptor"};
723
724 ValidateNumInputs(workloadInfo, descriptorName, 1);
725 ValidateNumOutputs(workloadInfo, descriptorName, 1);
726
727 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
728 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
729
730 std::vector<DataType> supportedTypes =
731 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000732 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100733 DataType::Float16,
734 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000735 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000736 DataType::QAsymmU8,
737 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100738 };
739
740 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
741 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
742 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
743}
744
telsoa014fcda012018-03-09 14:13:49 +0000745void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
746{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100747 const std::string descriptorName{"SplitterQueueDescriptor"};
748
749 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000750
Ruomei Yan25339c32019-05-28 16:48:20 +0100751 // Check the supported data types
752 std::vector<DataType> supportedTypes =
753 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000754 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100755 DataType::Float32,
756 DataType::Float16,
757 DataType::Boolean,
758 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100759 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000760 DataType::QAsymmU8,
761 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100762 };
763
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100764 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
765 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100766 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100767 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
768 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
769
770 const std::string outputName = "output_" + std::to_string(i);
771 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100772 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100773
telsoa014fcda012018-03-09 14:13:49 +0000774 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
775 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100776 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000777 }
778
779 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
780 {
781 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100782 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000783 "has to match number of workloadInfo.m_OutputTensorInfos. "
784 "Number of windows: " +
785 to_string(m_ViewOrigins.size()) +
786 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
787 }
788
telsoa01c577f2c2018-08-31 09:22:23 +0100789 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000790 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
791 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
792 {
telsoa01c577f2c2018-08-31 09:22:23 +0100793 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000794 ViewOrigin const& e = m_ViewOrigins[w];
795 if (e.m_Origin.size() != inputDims)
796 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100797 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000798 "have the same dimensionality as the input tensor. "
799 "Window origin (index: " +
800 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
801 " dimensions, the input "
802 "tensor has " +
803 to_string(inputDims) + " dimensions.");
804 }
805 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
806 {
807 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
808 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
809 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100810 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000811 "be smaller or equal than the size of the input in that coord.");
812 }
813 }
814 }
815}
816
Jim Flynne242f2d2019-05-22 14:24:13 +0100817void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000818{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100819 const std::string descriptorName{"ConcatQueueDescriptor"};
820
821 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000822
823 if (m_Inputs.size() <= 0)
824 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100825 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000826 }
827 if (m_Outputs.size() <= 0)
828 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100829 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000830 }
831
832 if (workloadInfo.m_InputTensorInfos.size() <= 0)
833 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100834 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000835 }
836 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
837 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100838 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000839 }
840
Nikhil Raj8599a412018-11-19 14:51:07 +0000841 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
842 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000844 }
845
846 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
847 {
848 return;
849 }
850
telsoa014fcda012018-03-09 14:13:49 +0000851 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
852 {
853 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100854 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000855 "has to match number of workloadInfo.m_InputTensorInfos. "
856 "Number of windows: " +
857 to_string(m_ViewOrigins.size()) +
858 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
859 }
860
telsoa01c577f2c2018-08-31 09:22:23 +0100861 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000862 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
863 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
864 {
telsoa01c577f2c2018-08-31 09:22:23 +0100865 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000866 ViewOrigin const& e = m_ViewOrigins[w];
867 if (e.m_Origin.size() != outputDims)
868 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100869 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000870 "have the same dimensionality as the output tensor. "
871 "Window origin (index: " +
872 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
873 " dimensions, the output "
874 "tensor has " +
875 to_string(outputDims) + " dimensions.");
876 }
telsoa01c577f2c2018-08-31 09:22:23 +0100877 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000878 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
879 {
880 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
881 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
882 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100883 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000884 "be smaller or equal than the size of the output in that coord.");
885 }
886 }
887 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100888
889 // Check the supported data types
890 std::vector<DataType> supportedTypes =
891 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000892 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100893 DataType::Float32,
894 DataType::Float16,
895 DataType::Boolean,
896 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100897 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000898 DataType::QAsymmU8,
899 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100900 };
901
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100902 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
903 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100904 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100905 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
906 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
907
908 const std::string inputName = "input_" + std::to_string(i);
909 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100910 }
telsoa014fcda012018-03-09 14:13:49 +0000911}
912
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100913void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
914{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100915 const std::string descriptorName{"StackQueueDescriptor"};
916
917 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100918
919 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
920 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100921 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100922 }
923
924 // All inputs must have the same shape, which is defined in parameters
925 const TensorShape& inputShape = m_Parameters.m_InputShape;
926 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
927 {
928 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
929 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100931 }
932 }
933
Matthew Jacksondba634f2019-08-15 15:14:18 +0100934 if (inputShape.GetNumDimensions() > 4)
935 {
936 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
937 }
938
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100939 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
940 // since the output tensor has an additional dimension.
941 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
942 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100943 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100944 "than the number of input dimensions.");
945 }
946
947 // Output shape must be as inferred from the input shape
948 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
949 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
950 {
951 if (outputShape[i] != inputShape[i])
952 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100954 "match shape inferred from input tensor.");
955 }
956 }
957
958 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
959 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100960 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100961 "match shape inferred from input tensor.");
962 }
963
964 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
965 {
966 if (outputShape[i] != inputShape[i-1])
967 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100969 "match shape inferred from input tensor.");
970 }
971 }
972
Matthew Jacksondba634f2019-08-15 15:14:18 +0100973 if (outputShape.GetNumDimensions() > 5)
974 {
975 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
976 }
977
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100978 // Check the supported data types
979 std::vector<DataType> supportedTypes =
980 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000981 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100982 DataType::Float32,
983 DataType::Float16,
984 DataType::Boolean,
985 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100986 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000987 DataType::QAsymmU8,
988 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100989 };
990
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100991 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100993 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100994 {
995 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
996 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100997 descriptorName,
998 "input_0",
999 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001000 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001001
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001002 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1003 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001004 descriptorName,
1005 "input_0",
1006 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001007}
1008
Ryan OSheaec6c6802020-06-05 17:17:06 +01001009void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1010{
1011 const std::string descriptorName{"FillQueueDescriptor"};
1012
1013 ValidateNumInputs(workloadInfo, descriptorName, 1);
1014 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1015
1016 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1017 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1018
1019 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1020
1021 std::vector<DataType> supportedTypes =
1022 {
1023 DataType::BFloat16,
1024 DataType::Float32,
1025 DataType::Float16,
1026 DataType::Signed32
1027 };
1028
1029 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1030}
1031
telsoa014fcda012018-03-09 14:13:49 +00001032void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1033{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001035
Matthew Sloyan81beae32021-07-13 19:46:11 +01001036 uint32_t numInputs = 2;
1037 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001038 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001039 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001040 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001041
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001042 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001043 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1044
1045 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1046 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1047
1048 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1049
1050 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001051 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001052 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001053 }
1054
Matthew Sloyan81beae32021-07-13 19:46:11 +01001055 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001056 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001057
1058 if (m_Parameters.m_BiasEnabled)
1059 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001060 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001061 // Validates type and quantization values.
Ryan OSheaf183acd2023-07-06 11:41:25 +01001062 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001063 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1064 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001065 }
1066
Francis Murtagh46c09d02019-05-28 08:15:28 +01001067 // Check the supported data types
1068 std::vector<DataType> supportedTypes =
1069 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001070 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001071 DataType::Float32,
1072 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001073 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001074 DataType::QAsymmU8,
1075 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001076 };
1077
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001078 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001079
1080 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1081 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1082 {
1083 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1084 {
1085 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1086 "for BFloat16 input.");
1087 }
1088 }
1089 else
1090 {
1091 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1092 }
telsoa014fcda012018-03-09 14:13:49 +00001093}
1094
Teresa Charlin9145e382023-08-17 18:44:58 +01001095void FusedQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
1096{
1097 // This is internally generated, so it should not need validation.
1098}
1099
telsoa014fcda012018-03-09 14:13:49 +00001100void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1101{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001102 const std::string descriptorName{"NormalizationQueueDescriptor"};
1103
1104 ValidateNumInputs(workloadInfo, descriptorName, 1);
1105 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1106
1107 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1108 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001109
1110 // Check the supported data types
1111 std::vector<DataType> supportedTypes =
1112 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001113 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001114 DataType::Float16,
1115 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001116 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001117 DataType::QAsymmU8,
1118 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001119 };
1120
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001121 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001122
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001125 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001126}
1127
1128void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1129{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001130 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001131
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001132 ValidateNumInputs(workloadInfo, descriptorName, 2);
1133 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1134
1135 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1136 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1137 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1138
1139 std::vector<DataType> supportedTypes =
1140 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001141 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001142 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001143 DataType::Float16,
1144 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001145 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001146 DataType::QSymmS16,
1147 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001148 };
1149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1151 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1152 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1155 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1158 inputTensorInfo1,
1159 outputTensorInfo,
1160 descriptorName,
1161 "input_0",
1162 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001163}
1164
telsoa014fcda012018-03-09 14:13:49 +00001165void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1166{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001167 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001168
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169 ValidateNumInputs(workloadInfo, descriptorName, 2);
1170 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1171
1172 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1173 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1174 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1175
1176 std::vector<DataType> supportedTypes =
1177 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001178 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001179 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001180 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001181 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001182 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001183 DataType::QSymmS16,
1184 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001185 };
1186
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001187 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1188 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1189 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001190
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001191 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1192 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1195 inputTensorInfo1,
1196 outputTensorInfo,
1197 descriptorName,
1198 "input_0",
1199 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001200}
1201
1202void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1203{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001204 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001205
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001206 ValidateNumInputs(workloadInfo, descriptorName, 1);
1207 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1208
1209 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1210 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001211
1212 std::vector<DataType> supportedTypes =
1213 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001214 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001215 DataType::Float16,
1216 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001217 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001218 DataType::QAsymmU8,
1219 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001220 };
1221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1223 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001226 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001227
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001228 ValidatePointer(m_Mean, descriptorName, "mean");
1229 ValidatePointer(m_Variance, descriptorName, "variance");
1230 ValidatePointer(m_Beta, descriptorName, "beta");
1231 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001232
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001233 const TensorInfo& mean = m_Mean->GetTensorInfo();
1234 const TensorInfo& variance = m_Variance->GetTensorInfo();
1235 const TensorInfo& beta = m_Beta->GetTensorInfo();
1236 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001237
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001238 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1239 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1240 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1241 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001242
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001243 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1244 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1245 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001246}
1247
1248void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1249{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001251
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001252 uint32_t numInputs = 2;
1253 if (m_Parameters.m_BiasEnabled)
1254 {
1255 numInputs = 3;
1256 }
1257
1258 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001259 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001260
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001261 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1262 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001263
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001264 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1265 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001266
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001267 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
telsoa014fcda012018-03-09 14:13:49 +00001268
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001269 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001270
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001271 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001272
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001273 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001274 if (m_Parameters.m_BiasEnabled)
1275 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001276 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001277 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001278
1279 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001280 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001281 }
1282
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001283 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1284 {
1285 throw InvalidArgumentException(
1286 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1287 "cannot be either negative or 0.",
1288 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1289 }
1290
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001291 ValidatePerAxisQuantization(inputTensorInfo,
1292 outputTensorInfo,
1293 weightTensorInfo,
1294 optionalBiasTensorInfo,
1295 descriptorName);
1296
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297 std::vector<DataType> supportedTypes =
1298 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001299 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001300 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001301 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001302 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001303 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001304 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001305 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001306 };
1307
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001308 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001309
1310 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1311 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1312 {
1313 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1314 {
1315 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1316 "for BFloat16 input.");
1317 }
1318 }
1319 else
1320 {
1321 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1322 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001323}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001324
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001325void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1326{
1327 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1328
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001329 uint32_t numInputs = 2;
1330 if (m_Parameters.m_BiasEnabled)
1331 {
1332 numInputs = 3;
1333 }
1334 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001335 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1336
1337 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1338 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1339
1340 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1341 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1342
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001343 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001344 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1345
1346 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1347
1348 Optional<TensorInfo> optionalBiasTensorInfo;
1349 if (m_Parameters.m_BiasEnabled)
1350 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001351 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001352 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1353
1354 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01001355 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001356 }
1357
1358 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1359 {
1360 throw InvalidArgumentException(
1361 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1362 "cannot be either negative or 0.",
1363 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1364 }
1365
1366 ValidatePerAxisQuantization(inputTensorInfo,
1367 outputTensorInfo,
1368 weightTensorInfo,
1369 optionalBiasTensorInfo,
1370 descriptorName);
1371
1372 std::vector<DataType> supportedTypes =
1373 {
1374 DataType::BFloat16,
1375 DataType::Float16,
1376 DataType::Float32,
1377 DataType::QAsymmS8,
1378 DataType::QAsymmU8,
1379 DataType::QSymmS16,
1380 DataType::QSymmS8
1381 };
1382
1383 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1384 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1385}
1386
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001387void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1388{
1389 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1390
Cathal Corbett06902652022-04-14 17:55:11 +01001391 uint32_t numInputs = 2;
1392 if (m_Parameters.m_BiasEnabled)
1393 {
1394 numInputs = 3;
1395 }
1396
1397 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001398 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1399
1400 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1401 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1402
1403 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1404 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1405
Cathal Corbett06902652022-04-14 17:55:11 +01001406 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001407 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1408
1409 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1410 {
1411 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001412 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1413 "cannot be smaller than 1.",
1414 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001415 }
1416
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001417 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1418 {
1419 throw InvalidArgumentException(
1420 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1421 "cannot be either negative or 0.",
1422 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1423 }
1424
Jan Eilers53ef7952021-06-02 12:01:25 +01001425 if (weightTensorInfo.GetShape()[0] != 1)
1426 {
1427 throw InvalidArgumentException(fmt::format(
1428 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1429 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1430 descriptorName,
1431 weightTensorInfo.GetShape()[0],
1432 weightTensorInfo.GetShape()[1],
1433 weightTensorInfo.GetShape()[2],
1434 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001435 }
1436
Cathal Corbett4b19d222022-05-11 20:12:17 +01001437 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1438 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1439 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1440 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1441
1442 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1443 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1444 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1445
1446 if (!(validRefFormat || validAclFormat))
1447 {
1448 throw InvalidArgumentException(fmt::format(
1449 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1450 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1451 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1452 descriptorName,
1453 numOutputChannels,
1454 weightTensorInfo.GetShape()[0],
1455 weightTensorInfo.GetShape()[1],
1456 weightTensorInfo.GetShape()[2],
1457 weightTensorInfo.GetShape()[3]));
1458 }
1459
Teresa Charlind8df0262019-11-11 12:28:15 +00001460 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001461
Teresa Charlind8df0262019-11-11 12:28:15 +00001462 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001463 if (m_Parameters.m_BiasEnabled)
1464 {
Cathal Corbett06902652022-04-14 17:55:11 +01001465 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Teresa Charlind8df0262019-11-11 12:28:15 +00001466 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001467
Ryan OSheaf183acd2023-07-06 11:41:25 +01001468 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001469 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1470 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001471 ValidatePerAxisQuantization(inputTensorInfo,
1472 outputTensorInfo,
1473 weightTensorInfo,
1474 optionalBiasTensorInfo,
1475 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001476
1477 std::vector<DataType> supportedTypes =
1478 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001479 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001480 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001481 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001482 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001483 DataType::QAsymmU8,
1484 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001485 };
1486
1487 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1488 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001489}
1490
1491void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1492{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001493 const std::string descriptorName{"PermuteQueueDescriptor"};
1494
1495 ValidateNumInputs(workloadInfo, descriptorName, 1);
1496 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001497
1498 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001500 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1501 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1504 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001505
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001507 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001508 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001509 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1511 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1512 "must match dst dimension " + to_string(mapping[i]) +
1513 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001514 }
1515 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001516
1517 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001518}
1519
1520void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1521{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001522 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001523
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001524 ValidateNumInputs(workloadInfo, descriptorName, 1);
1525 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1526
1527 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1528 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1529
1530 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1531 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001532
1533 std::vector<DataType> supportedTypes =
1534 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001536 DataType::Float32,
1537 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001538 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001539 DataType::QAsymmU8,
1540 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001541 };
1542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001543 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1544 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001545}
1546
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001547void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1548{
1549 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1550
1551 ValidateNumInputs(workloadInfo, descriptorName, 1);
1552 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1553
1554 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1555 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1556
1557 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1558 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1559
1560 std::vector<DataType> supportedTypes =
1561 {
1562 DataType::BFloat16,
1563 DataType::Float32,
1564 DataType::Float16,
1565 DataType::QAsymmS8,
1566 DataType::QAsymmU8,
1567 DataType::QSymmS16
1568 };
1569
1570 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1571 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1572}
1573
Teresa Charlin970f43b2019-07-01 13:51:07 +01001574void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1575{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001576 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001577
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001578 ValidateNumInputs(workloadInfo, descriptorName, 1);
1579 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1580
1581 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1582 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1583
1584 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1585 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001586
1587 std::vector<DataType> supportedTypes =
1588 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001589 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001590 DataType::Float16,
1591 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001592 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001593 DataType::QAsymmU8,
Teresa Charlince655882023-11-21 15:44:13 +00001594 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001595 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001596 };
1597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001598 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1599 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001600
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 // Resize only changes width and height: batch and channel count must match.
1602 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1603 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001604 if (inputBatchSize != outputBatchSize)
1605 {
1606 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001607 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1608 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001609 }
1610
1611 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001612 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1613 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001614 if (inputChannelCount != outputChannelCount)
1615 {
1616 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001617 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1618 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001619 }
1620}
1621
Teresa Charlin79a06a52023-07-13 17:16:45 +01001622void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
1623{
Tianle Cheng988354d2023-06-28 13:20:47 +01001624 const std::string descriptorName{"ReverseV2QueueDescriptor"};
1625
Tracy Narinebb8d7592023-07-13 16:50:54 +01001626 // Backend restriction
1627 const unsigned int maxDimensions = 4;
1628
1629 ValidateNumInputs(workloadInfo, descriptorName, 2);
Tianle Cheng988354d2023-06-28 13:20:47 +01001630 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1631
1632 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
Tracy Narinebb8d7592023-07-13 16:50:54 +01001633 const TensorInfo& axisTensorInfo = workloadInfo.m_InputTensorInfos[1];
Tianle Cheng988354d2023-06-28 13:20:47 +01001634 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1635
Tracy Narinebb8d7592023-07-13 16:50:54 +01001636 const auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
1637 if (inputTensorNumDimensions > maxDimensions)
Tianle Cheng988354d2023-06-28 13:20:47 +01001638 {
1639 throw InvalidArgumentException(descriptorName +
1640 ": Input tensors with rank greater than " +
Tracy Narinebb8d7592023-07-13 16:50:54 +01001641 std::to_string(maxDimensions) + " are not supported.");
1642 }
1643
1644 const auto axisTensorNumDimensions = axisTensorInfo.GetNumDimensions();
1645 if (axisTensorNumDimensions > maxDimensions)
1646 {
1647 throw InvalidArgumentException(descriptorName +
1648 ": More than " + std::to_string(maxDimensions) + " axes cannot be specified.");
1649 }
1650
1651 if (axisTensorNumDimensions > inputTensorNumDimensions)
1652 {
1653 throw InvalidArgumentException(descriptorName +
1654 ": More axes specified than the number of axes on the input tensor.");
Tianle Cheng988354d2023-06-28 13:20:47 +01001655 }
1656
1657 std::vector<DataType> supportedTypes =
1658 {
1659 DataType::BFloat16,
1660 DataType::Float16,
1661 DataType::Float32,
1662 DataType::QAsymmS8,
1663 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01001664 DataType::QSymmS8,
1665 DataType::QSymmS16,
1666 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01001667 };
1668
1669 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Tracy Narinebb8d7592023-07-13 16:50:54 +01001670
1671 std::vector<DataType> axisSupportedTypes =
1672 {
1673 DataType::Signed32,
1674 };
1675
1676 ValidateDataTypes(axisTensorInfo, axisSupportedTypes, descriptorName);
1677
Tianle Cheng988354d2023-06-28 13:20:47 +01001678 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1679 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Tianle Cheng988354d2023-06-28 13:20:47 +01001680}
1681
telsoa014fcda012018-03-09 14:13:49 +00001682void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1683{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001684 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001685
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001686 ValidateNumInputs(workloadInfo, descriptorName, 1);
1687 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1688
1689 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1690 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1691
1692 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1693 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1694
1695 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1696
telsoa014fcda012018-03-09 14:13:49 +00001697 if (m_Parameters.m_Min > m_Parameters.m_Max)
1698 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001699 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001700 }
telsoa014fcda012018-03-09 14:13:49 +00001701}
1702
Kevin Mayce5045a2019-10-02 14:07:47 +01001703void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1704{
1705 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1706
1707 ValidateNumInputs(workloadInfo, descriptorName, 1);
1708 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1709
1710 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1711 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1712
1713 if (inputTensorInfo.GetNumDimensions() > 4)
1714 {
1715 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1716 }
1717
1718 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1719
1720 // Check the supported data types
1721 std::vector<DataType> supportedTypes =
1722 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001723 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001724 DataType::Float32,
1725 DataType::Float16
1726 };
1727
1728 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001729 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001730}
1731
telsoa014fcda012018-03-09 14:13:49 +00001732void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1733{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001735
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001736 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001737 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1740 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1741
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001742 if (inputTensorInfo.GetNumDimensions() > 4)
1743 {
1744 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1745 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001746
1747 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001748
1749 // Check the supported data types
1750 std::vector<DataType> supportedTypes =
1751 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001752 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001753 DataType::Float32,
1754 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001755 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001756 DataType::QAsymmU8,
1757 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001758 };
1759
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001760 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001761 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1762}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001763
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001764void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1765{
1766 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1767
1768 ValidateNumInputs(workloadInfo, descriptorName, 1);
1769 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1770
1771 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1772 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1773
1774 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1775
1776 std::vector<DataType> supportedTypes =
1777 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001778 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001779 DataType::Float32,
1780 DataType::Float16,
1781 };
1782
1783 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001785}
1786
1787void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1788{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001789 const std::string descriptorName{"ConstantQueueDescriptor"};
1790
1791 ValidateNumInputs(workloadInfo, descriptorName, 0);
1792 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001793
1794 if (!m_LayerOutput)
1795 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001796 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001797 }
1798
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001799 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1800 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001801
1802 // Check the supported data types
1803 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001804 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001805 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001806 DataType::Float32,
1807 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001808 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001809 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001810 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001811 DataType::QSymmS16,
1812 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001813 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001814
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001816}
1817
1818void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1819{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001820 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001822 ValidateNumInputs(workloadInfo, descriptorName, 1);
1823 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1824
1825 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1826 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1827
1828 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001829
1830 // Check the supported data types
1831 std::vector<DataType> supportedTypes =
1832 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001833 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001834 DataType::Float32,
1835 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001836 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001837 DataType::QAsymmU8,
1838 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001839 DataType::Signed32,
1840 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001841 };
1842
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001843 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1844 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001845}
1846
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001847void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1848{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001849 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001850
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001851 ValidateNumInputs(workloadInfo, descriptorName, 1);
1852 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1853
1854 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1855 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1856
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001857 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1858 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001859 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1860 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001861 }
1862
Teresa Charlinf77cab52023-06-01 16:15:13 +01001863 if (m_Parameters.m_BlockShape.size() == 2)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001864 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001865 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1866 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1867 }
1868 else if (m_Parameters.m_BlockShape.size() == 1)
1869 {
1870 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
1871 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
1872 }
1873 else
1874 {
1875 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001876 }
1877
Teresa Charlinf77cab52023-06-01 16:15:13 +01001878 // Check input + padding and output have the same number of elements
1879 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1880 const unsigned int inputHeight = inputTensorInfo.GetShape()[dimensionIndices.GetHeightIndex()] +
1881 m_Parameters.m_PadList[0].first + m_Parameters.m_PadList[0].second;
1882 const unsigned int inputWidth = (inputTensorInfo.GetNumDimensions() == 3) ? 1 :
1883 inputTensorInfo.GetShape()[dimensionIndices.GetWidthIndex()] +
1884 m_Parameters.m_PadList[1].first + m_Parameters.m_PadList[1].second;
1885
1886 const int channelsIndex_int = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : -1;
1887 const unsigned int channelsIndex = channelsIndex_int < 0 ?
1888 static_cast<unsigned int>(channelsIndex_int) + inputTensorInfo.GetNumDimensions()
1889 : static_cast<unsigned int>(channelsIndex_int);
1890
1891 const unsigned int numInputElements = inputTensorInfo.GetShape()[0] *
1892 inputHeight *
1893 inputWidth *
1894 inputTensorInfo.GetShape()[channelsIndex];
1895
1896 if (outputTensorInfo.GetNumElements() != numInputElements)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001897 {
Teresa Charlinf77cab52023-06-01 16:15:13 +01001898 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1899 to_string(numInputElements) + " after padding but output tensor has " +
1900 to_string(outputTensorInfo.GetNumElements()) + " elements.");
1901 }
1902
1903 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
1904 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
1905 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
1906 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
1907 {
1908 unsigned int spatialDimension = firstSpatialDimension + i;
1909 auto inputSize = inputTensorInfo.GetShape()[spatialDimension] +
1910 m_Parameters.m_PadList[i].first +
1911 m_Parameters.m_PadList[i].second;
1912 if (inputSize % m_Parameters.m_BlockShape[i] != 0)
1913 {
1914 throw InvalidArgumentException(descriptorName + ": Input dimension size after padding must be "
1915 "divisible by Block Shape in dimension: " + to_string(spatialDimension) + ".");
1916 }
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001917 }
nikraj01120522a2019-05-31 11:33:07 +01001918
1919 std::vector<DataType> supportedTypes =
1920 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001921 DataType::BFloat16,
1922 DataType::Float16,
1923 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001924 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001925 DataType::QAsymmU8,
1926 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001927 };
1928
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001929 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1930 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001931}
1932
Keith Davisa57eccb2019-06-14 17:33:22 +01001933void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1934{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001935 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001936
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001937 ValidateNumInputs(workloadInfo, descriptorName, 1);
1938 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001940 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1941 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1942
1943 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1944 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001945
1946 std::vector<DataType> supportedTypes =
1947 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001948 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001949 DataType::Float32,
1950 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001951 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001952 DataType::QAsymmU8,
1953 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001954 };
1955
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001956 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1957 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001958
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001959 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1960
1961 if (m_Parameters.m_BlockSize == 0)
1962 {
1963 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1964 }
1965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001966 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1967 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1968 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1969 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001970
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001971 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001972 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001973 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1975 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001976 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001977
1978 const TensorShape& outputShape = outputTensorInfo.GetShape();
1979 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1980 {
1981 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1982 "must be divisible by the square of block size." );
1983 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001984}
1985
telsoa014fcda012018-03-09 14:13:49 +00001986void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1987{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001989
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001990 ValidateNumInputs(workloadInfo, descriptorName, 1);
1991 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1992
1993 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1994 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001995
1996 std::vector<DataType> supportedTypes =
1997 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001998 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001999 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002000 DataType::Float16,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01002001 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01002002 };
2003
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01002005 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2006 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2007 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00002008}
2009
telsoa01c577f2c2018-08-31 09:22:23 +01002010void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2011{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002012 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
2013
2014 const std::string descriptorName{"LstmQueueDescriptor"};
2015
2016 // check dimensions of all inputs and outputs
2017 if (workloadInfo.m_InputTensorInfos.size() != 3)
2018 {
2019 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
2020 }
2021 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2022 {
2023 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
2024 }
2025
2026 std::vector<DataType> supportedTypes =
2027 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002028 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01002029 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002030 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002031 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002032 };
2033
Jan Eilers38e05bd2019-06-26 13:10:09 +01002034 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002035 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
2036
Jan Eilers38e05bd2019-06-26 13:10:09 +01002037 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002038 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002039 {
2040 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2041 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002042 descriptorName,
2043 "input_0",
2044 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002045 }
2046 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002047 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002048 {
2049 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2050 workloadInfo.m_OutputTensorInfos[i],
2051 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002052 "input_0",
2053 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002054 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002055
janeil0117d8d852019-11-15 15:00:16 +00002056 // Making sure clipping parameters have valid values.
2057 // == 0 means no clipping
2058 // > 0 means clipping
2059 if (m_Parameters.m_ClippingThresCell < 0.0f)
2060 {
2061 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2062 }
2063 if (m_Parameters.m_ClippingThresProj < 0.0f)
2064 {
2065 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2066 }
2067
Jan Eilers38e05bd2019-06-26 13:10:09 +01002068 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002069 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2070 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2071 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2072 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2073 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2074 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2075
Jan Eilers38e05bd2019-06-26 13:10:09 +01002076 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002077 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2078 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002079 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002080 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2081 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002082 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002083 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2084 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002085 // scratchBufferTensor
2086 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002087 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2088 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002089 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2091 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002092 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002093 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2094 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002095 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2097 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002098
Jan Eilers38e05bd2019-06-26 13:10:09 +01002099 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2100 if ( m_InputToInputWeights )
2101 {
2102 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2103 (n_cell * n_input), "InputLayerNormWeights");
2104 }
2105
2106 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2107 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2108 (n_cell * n_input), "InputToForgetWeights");
2109
2110 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2111 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2112 (n_cell * n_input), "InputToCellWeights");
2113
2114 if ( m_RecurrentToInputWeights )
2115 {
2116 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2117 (n_cell * n_output), "RecurrentToInputWeights");
2118 }
2119
2120 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2121 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2122 (n_cell * n_output), "RecurrentToForgetWeights");
2123
2124 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2125 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2126 (n_cell * n_output), "RecurrentToCellWeights");
2127
2128 // Make sure the input-gate's parameters are either both present (regular
2129 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2130 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2131 !m_Parameters.m_CifgEnabled) ||
2132 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2133 m_Parameters.m_CifgEnabled));
2134 if (!cifg_weights_all_or_none)
2135 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002136 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2137 "RecurrentToInputWeights must either both be present (regular LSTM) "
2138 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2139 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002140 }
2141
2142 if ( m_CellToInputWeights )
2143 {
2144 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2145 n_cell, "CellToInputWeights");
2146 }
2147 if ( m_CellToForgetWeights )
2148 {
2149 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2150 n_cell, "CellToForgetWeights");
2151 }
2152 if ( m_CellToOutputWeights )
2153 {
2154 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2155 n_cell, "CellToOutputWeights");
2156 }
2157
2158 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2159 bool peephole_weights_all_or_none =
2160 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2161 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2162 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2163 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2164 if (!peephole_weights_all_or_none)
2165 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002166 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002167 }
2168
2169 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2170 if (m_Parameters.m_CifgEnabled)
2171 {
2172 if (m_InputGateBias)
2173 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002174 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002175 }
2176 }
2177 else
2178 {
2179 if (!m_InputGateBias)
2180 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002181 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2182 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002183 }
2184 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2185 n_cell, "InputGateBias");
2186 }
2187
2188 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2189 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2190
2191 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2192 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2193
2194 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2195 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2196
2197 if (m_ProjectionWeights)
2198 {
2199 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2200 (n_cell * n_output), "ProjectionWeights");
2201 }
2202 if (m_ProjectionBias)
2203 {
2204 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2205 }
2206
2207 // Making sure the projection tensors are consistent:
2208 // 1) If projection weight is not present, then projection bias should not be
2209 // present.
2210 // 2) If projection weight is present, then projection bias is optional.
2211 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2212 !m_Parameters.m_ProjectionEnabled)
2213 || (m_ProjectionWeights && !m_ProjectionBias &&
2214 m_Parameters.m_ProjectionEnabled)
2215 || (m_ProjectionWeights && m_ProjectionBias &&
2216 m_Parameters.m_ProjectionEnabled));
2217 if (!projecton_tensors_consistent)
2218 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002219 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002220 }
2221
2222 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2223 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2224 // either all have values or none of them have values. Layer normalization is used when the values of all the
2225 // layer normalization weights are present
2226 if (m_InputLayerNormWeights)
2227 {
2228 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2229 }
2230 if (m_ForgetLayerNormWeights)
2231 {
2232 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2233 }
2234 if (m_CellLayerNormWeights)
2235 {
2236 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2237 }
2238 if (m_OutputLayerNormWeights)
2239 {
2240 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2241 }
2242
Jan Eilers38e05bd2019-06-26 13:10:09 +01002243 if (m_Parameters.m_LayerNormEnabled)
2244 {
2245 if (!m_Parameters.m_CifgEnabled)
2246 {
2247 if (!m_InputLayerNormWeights)
2248 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002249 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2250 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002251 }
2252 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2253 1, n_cell, "InputLayerNormWeights");
2254 }
2255 else if (m_InputLayerNormWeights)
2256 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002257 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2258 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002259 }
2260
2261 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2262 "ForgetLayerNormWeights");
2263 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2264
2265 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2266 "OutputLayerNormWeights");
2267 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2268
2269 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2270 "CellLayerNormWeights");
2271 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2272 }
2273 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2274 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002275 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2276 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002277 }
telsoa01c577f2c2018-08-31 09:22:23 +01002278}
2279
2280void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2281{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002282 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 ValidateNumInputs(workloadInfo, descriptorName, 1);
2285 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2286
2287 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2288 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2289
2290 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002291 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002292 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002293 }
2294
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002295 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002296 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002297 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002298 }
2299
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002301}
2302
2303void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2304{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002305 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002306
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002307 ValidateNumInputs(workloadInfo, descriptorName, 1);
2308 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2309
2310 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2312
2313 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002314 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002316 }
2317
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002318 if (outputTensorInfo.GetDataType() != DataType::Float32)
2319 {
2320 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2321 }
2322
2323 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002324}
2325
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002326void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2327{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002328 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002329
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 ValidateNumInputs(workloadInfo, descriptorName, 2);
2331 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2332
2333 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2334 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2335 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2336
2337 std::vector<DataType> supportedTypes =
2338 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002339 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002340 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002341 DataType::Float32,
2342 DataType::QAsymmS8,
2343 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002344 DataType::QSymmS16,
2345 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002346 };
2347
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002348 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2349 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2350 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002351
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002352 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2353 inputTensorInfo1,
2354 outputTensorInfo,
2355 descriptorName,
2356 "input_0",
2357 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002358}
2359
David Beckc2044fe2018-09-05 15:00:38 +01002360void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2361{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002362 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002363
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 ValidateNumInputs(workloadInfo, descriptorName, 2);
2365 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2366
2367 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2368 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2369 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2370
2371 std::vector<DataType> supportedTypes =
2372 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002373 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002374 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002375 DataType::Float32,
2376 DataType::QAsymmS8,
2377 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002378 DataType::QSymmS16,
2379 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002380 };
2381
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002382 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2383 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2384 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002385
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002386 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2387 inputTensorInfo1,
2388 outputTensorInfo,
2389 descriptorName,
2390 "input_0",
2391 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002392}
2393
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002394void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2395{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002396 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002397
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002398 ValidateNumInputs(workloadInfo, descriptorName, 2);
2399 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2400
2401 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2402 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2403 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2404
2405 std::vector<DataType> supportedTypes =
2406 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002407 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002408 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002409 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002410 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002411 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002412 DataType::QSymmS16,
2413 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002414 };
2415
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002416 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2417 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2418 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002419
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2421 inputTensorInfo1,
2422 outputTensorInfo,
2423 descriptorName,
2424 "input_0",
2425 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002426}
2427
narpra01a6bf9122018-09-10 09:50:09 +01002428void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2429{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002430 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002431
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002432 ValidateNumInputs(workloadInfo, descriptorName, 1);
2433 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2434
2435 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2436 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002437
2438 std::vector<DataType> supportedTypes =
2439 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002440 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002441 DataType::Float32,
2442 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002443 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002444 DataType::QAsymmU8,
2445 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002446 };
narpra01eb061912018-09-10 17:35:27 +01002447
James Conroy4d1ff582019-06-10 17:06:39 +01002448 // First check if input tensor data type is supported, then
2449 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002450 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2451 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002452
narpra0132b90462018-09-13 11:07:48 +01002453 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002454 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002455 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002456 }
narpra0132b90462018-09-13 11:07:48 +01002457 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002458 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002459 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002460 }
2461 else
2462 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002463 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002464 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002465 ValidateTensorNumDimensions(outputTensorInfo,
2466 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002467 outputDim > 0 ? outputDim : 1,
2468 "output");
2469 }
narpra01a6bf9122018-09-10 09:50:09 +01002470}
2471
jimfly012c9322a2018-09-19 10:59:49 +01002472void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2473{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002474 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002476 ValidateNumInputs(workloadInfo, descriptorName, 1);
2477 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2478
2479 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2480 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002481
jimfly012c9322a2018-09-19 10:59:49 +01002482 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002483 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2484
jimfly012c9322a2018-09-19 10:59:49 +01002485 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002486 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2487 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2488 "as there are dimensions in the input tensor that is " +
2489 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2490 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002491 }
2492}
2493
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002494void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2495{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002496 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 ValidateNumInputs(workloadInfo, descriptorName, 1);
2499 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002500
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002501 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2502 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2503
Sadik Armagan2208b602019-07-31 16:36:27 +01002504 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002505 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002506 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002507 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002508 DataType::Float16,
2509 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002510 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002511 DataType::QAsymmU8,
2512 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002513 };
2514
2515 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002516
Keith Davis0c2eeac2020-02-11 16:51:50 +00002517 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002518 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002519 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002520 }
2521}
2522
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002523void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2524{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002525 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002526
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002527 ValidateNumInputs(workloadInfo, descriptorName, 1);
2528 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002529
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002530 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2531 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002532
Teresa Charlinf77cab52023-06-01 16:15:13 +01002533 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_Crops.size())
2534 {
2535 throw InvalidArgumentException(descriptorName + ": Crops must contain the same number of "
2536 "dimensions as Block Shape.");
2537 }
2538
2539 if (m_Parameters.m_BlockShape.size() == 2)
2540 {
2541 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2542 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2543 }
2544 else if (m_Parameters.m_BlockShape.size() == 1)
2545 {
2546 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 3, "input");
2547 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 3, "output");
2548 }
2549 else
2550 {
2551 throw InvalidArgumentException(descriptorName + ": Invalid Block and Crops size.");
2552 }
2553
2554 // In a 4D tensor, there will be 2 spatialDimensions (H and W), and the for loop will run twice.
2555 // In a 3D tensor, there will be 1 spatialDimensions, and the for loop will run once.
2556 unsigned int firstSpatialDimension = m_Parameters.m_DataLayout == DataLayout::NCHW ? 2 : 1;
2557 for (unsigned int i = 0; i < m_Parameters.m_BlockShape.size(); ++i)
2558 {
2559 unsigned int spatialDimension = firstSpatialDimension + i;
2560 unsigned int cropSize = m_Parameters.m_Crops[i].first + m_Parameters.m_Crops[i].second;
2561 unsigned int outputSize = inputTensorInfo.GetShape()[spatialDimension] * m_Parameters.m_BlockShape[i];
2562 if (cropSize > outputSize)
2563 {
2564 throw InvalidArgumentException(descriptorName + ": CropSize must be less than or equal to the uncropped"
2565 "outputSize in dimension: " + to_string(spatialDimension) + ".");
2566 }
2567 }
2568
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002569 std::vector<DataType> supportedTypes =
2570 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002571 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002572 DataType::Float32,
2573 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002574 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002575 DataType::QAsymmU8,
2576 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002577 };
2578
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002579 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2580 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002581}
2582
Conor Kennedy430b5d82018-11-14 15:28:28 +00002583void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2584{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002585 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002586
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002587 ValidateNumInputs(workloadInfo, descriptorName, 1);
2588 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2589
2590 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2591 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002592
2593 std::vector<DataType> supportedTypes =
2594 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002595 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002596 DataType::Float16,
2597 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002598 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002599 DataType::QAsymmU8,
2600 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002601 };
2602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002603 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2604 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002605
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002606 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002608 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002609 if (rank > 4)
2610 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002611 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002612 }
2613
Conor Kennedy430b5d82018-11-14 15:28:28 +00002614 // Begin, End & Stride length must be of rank(input0)
2615 if (m_Parameters.m_Begin.size() != rank)
2616 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002617 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002618 }
2619
2620 if (m_Parameters.m_End.size() != rank)
2621 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002622 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002623 }
2624
2625 if (m_Parameters.m_Stride.size() != rank)
2626 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002627 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002628 }
2629
2630 // Stride entries must be non-zero
2631 for (auto& stride : m_Parameters.m_Stride)
2632 {
2633 if (stride == 0)
2634 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002635 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002636 }
2637 }
2638}
2639
kevmay0190539692018-11-29 08:40:19 +00002640void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2641{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002642 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002643
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002644 ValidateNumInputs(workloadInfo, descriptorName, 2);
2645 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2646
2647 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2648 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2649 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2650
2651 std::vector<DataType> supportedTypes =
2652 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002653 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002654 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002655 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002656 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002657 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002658 DataType::QSymmS16,
2659 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002660 };
2661
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002662 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2663 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2664 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002665
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002666 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2667 inputTensorInfo1,
2668 outputTensorInfo,
2669 descriptorName,
2670 "input_0",
2671 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002672}
2673
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002674void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2675{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002676 const std::string descriptorName{"DebugQueueDescriptor"};
2677
2678 ValidateNumInputs(workloadInfo, descriptorName, 1);
2679 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002680}
2681
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002682void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2683{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002684 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002685
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002686 ValidateNumInputs(workloadInfo, descriptorName, 2);
2687 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002688
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002689 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2690 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2691 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2692
2693 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2694 inputTensorInfo1,
2695 outputTensorInfo,
2696 descriptorName,
2697 "input_0",
2698 "input_1");
2699
2700 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002701 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002702 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002703 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002704}
2705
FrancisMurtagh878f0232018-12-19 10:56:15 +00002706void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2707{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002708 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002709
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002710 ValidateNumInputs(workloadInfo, descriptorName, 2);
2711 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002712
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002713 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2714 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2715 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2716
2717 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2718 inputTensorInfo1,
2719 outputTensorInfo,
2720 descriptorName,
2721 "input_0",
2722 "input_1");
2723
2724 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002725 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002726 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002727 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002728}
2729
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002730void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2731{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002732 const std::string descriptorName{"RsqrtQueueDescriptor"};
2733
2734 ValidateNumInputs(workloadInfo, descriptorName, 1);
2735 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2736
2737 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2738 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2739
2740 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002741
2742 std::vector<DataType> supportedTypes =
2743 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002744 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002745 DataType::Float16,
2746 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002747 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002748 DataType::QAsymmU8,
2749 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002750 };
2751
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002752 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2753 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002754}
2755
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01002756void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2757{
2758 const std::string descriptorName{"GatherNdQueueDescriptor"};
2759
2760 ValidateNumInputs(workloadInfo, descriptorName, 2);
2761 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2762
2763 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2764 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2765 {
2766 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2767 }
2768
2769 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2770 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2771
2772 std::vector<DataType> supportedTypes =
2773 {
2774 DataType::BFloat16,
2775 DataType::Float16,
2776 DataType::Float32,
2777 DataType::QAsymmS8,
2778 DataType::QAsymmU8,
2779 DataType::QSymmS16,
2780 DataType::Signed32,
2781 };
2782
2783 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2784
2785 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2786
2787 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2788 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2789}
2790
narpra01b89b05f2019-01-16 09:53:09 +00002791void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2792{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002793 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002794
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002795 ValidateNumInputs(workloadInfo, descriptorName, 2);
2796 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002797
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002798 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2799 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002800 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002801 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002802 }
2803
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002804 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2805 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2806
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002807 std::vector<DataType> supportedTypes =
2808 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002809 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002810 DataType::Float16,
2811 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002812 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002813 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002814 DataType::QSymmS16,
2815 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002816 };
2817
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002818 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002819
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002820 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002822 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2823 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002824}
2825
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002826void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2827{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002828 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2829
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002830 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002831
2832 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2833 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002834 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002835 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2836 }
2837
2838 if (m_Anchors == nullptr)
2839 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002840 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002841 }
2842
2843 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002844 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2845 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2846
2847 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002848 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002849 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2850 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002851
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002852 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2853 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2854 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002855
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002856 const std::vector<DataType> supportedInputTypes =
2857 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002858 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002859 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002860 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002861 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002862 DataType::QAsymmU8,
2863 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002864 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002865
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002866 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2867 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2868 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2869
2870 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2871 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2872 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2873 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2874
2875 // NOTE: Output is always Float32 regardless of input type
2876 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2877 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2878 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2879 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002880
2881 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2882 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002883 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002884 "must be positive and less than or equal to 1.");
2885 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002886
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002887 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2888 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002889 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002890 "should be equal to number of classes + 1.");
2891 }
2892}
2893
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002894void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2895{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002896 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002898 ValidateNumInputs(workloadInfo, descriptorName, 1);
2899 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2900
2901 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2902 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2903
Teresa Charlin07307f32022-05-15 14:07:05 +01002904 std::vector<DataType> inputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002905 {
Teresa Charlin07307f32022-05-15 14:07:05 +01002906 DataType::QAsymmS8,
2907 DataType::QAsymmU8,
2908 DataType::QSymmS8,
2909 DataType::QSymmS16,
2910 DataType::Float16
2911 };
2912 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002913
Teresa Charlin07307f32022-05-15 14:07:05 +01002914 std::vector<DataType> outputSupportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002915 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002916 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002917 DataType::Float32,
2918 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002919 };
2920
Teresa Charlin07307f32022-05-15 14:07:05 +01002921 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002922}
2923
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002924void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2925{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002926 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002927
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002928 ValidateNumInputs(workloadInfo, descriptorName, 2);
2929 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002930
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002931 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2932 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2933 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002934
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002935 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2936 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2937
2938 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2939 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002940}
2941
Keith Davis3ae3f972021-05-21 16:33:48 +01002942void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2943{
2944 const std::string& descriptorName{"ShapeQueueDescriptor"};
2945
2946 ValidateNumInputs(workloadInfo, descriptorName, 1);
2947 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2948
2949 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2950 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2951
2952 std::vector<DataType> supportedTypes =
2953 {
2954 DataType::BFloat16,
2955 DataType::Float16,
2956 DataType::Float32,
2957 DataType::QAsymmS8,
2958 DataType::QAsymmU8,
Keith Davis3ae3f972021-05-21 16:33:48 +01002959 DataType::QSymmS8,
2960 DataType::QSymmS16,
2961 DataType::Signed32
2962 };
2963
2964 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2965 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2966}
2967
Sadik Armaganeff363d2019-04-05 15:25:46 +01002968void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2969{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002970 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002972 ValidateNumInputs(workloadInfo, descriptorName, 2);
2973 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2974
2975 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2976 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2977
2978 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2979 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2980
2981 std::vector<DataType> supportedTypes =
2982 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002983 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002984 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002985 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002986 DataType::QAsymmU8,
2987 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002988 };
2989
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002990 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2991 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002993 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2994 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002996 ValidateTensorShapesMatch(inputTensorInfo0,
2997 outputTensorInfo0,
2998 descriptorName,
2999 "input_0",
3000 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003001
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003002 ValidateTensorShapesMatch(inputTensorInfo0,
3003 outputTensorInfo1,
3004 descriptorName,
3005 "input_0",
3006 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01003007}
3008
Derek Lamberti901ea112019-12-10 22:07:09 +00003009void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00003010{
Teresa Charlin9145e382023-08-17 18:44:58 +01003011 // This is internally generated, so it should not need validation.
Matteo Martincigh49124022019-01-11 13:25:59 +00003012}
3013
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003014void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3015{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003016 const std::string& descriptorName{"PreluQueueDescriptor"};
3017
3018 ValidateNumInputs(workloadInfo, descriptorName, 2);
3019 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3020
3021 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3022 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
3023 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003024
3025 std::vector<DataType> supportedTypes
3026 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003027 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003028 DataType::Float16,
3029 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003030 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003031 DataType::QAsymmU8,
3032 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003033 };
3034
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003035 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3036 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003037
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003038 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003039
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003040 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
3041 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003042
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003043 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
3044 alphaTensorInfo,
3045 outputTensorInfo,
3046 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01003047 "input",
3048 "alpha");
3049}
3050
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003051void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3052{
3053 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
3054
3055 ValidateNumInputs(workloadInfo, descriptorName, 1);
3056 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3057
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003058 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3059 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3060
3061 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
3062 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003063
3064 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003065
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003066 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
3067 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003068
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003069 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
3070
3071 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003072 if (m_Parameters.m_BiasEnabled)
3073 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003074 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003075
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003076 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3077 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003078
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003079 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Ryan OSheaf183acd2023-07-06 11:41:25 +01003080 ValidateBiasTensorQuantization(biasTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003081 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003082
3083 ValidatePerAxisQuantization(inputTensorInfo,
3084 outputTensorInfo,
3085 weightTensorInfo,
3086 optionalBiasTensorInfo,
3087 descriptorName);
3088
3089 std::vector<DataType> supportedTypes =
3090 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003091 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003092 DataType::Float32,
3093 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003094 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003095 DataType::QAsymmU8,
3096 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003097 };
3098
3099 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3100 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003101}
3102
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003103void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3104{
3105 const std::string descriptorName{"TransposeQueueDescriptor"};
3106
3107 ValidateNumInputs(workloadInfo, descriptorName, 1);
3108 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3109
3110 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3111
3112 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3113 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3114
3115 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3116 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3117
3118 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3119 {
3120 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3121 {
3122 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3123 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3124 "must match dst dimension " + to_string(i) +
3125 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3126 }
3127 }
3128
3129 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3130}
3131
Simon Obute51f67772021-09-03 15:50:13 +01003132void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3133{
3134 const std::string descriptorName{"TransposeQueueDescriptor"};
3135
3136 ValidateNumInputs(workloadInfo, descriptorName, 1);
3137 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3138
3139 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3140 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3141
3142 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3143}
3144
James Conroy4f1f8992020-04-29 20:01:10 +01003145void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3146{
3147 const std::string descriptorName{"QLstmQueueDescriptor"};
3148
3149 // Validate number of inputs/outputs
3150 ValidateNumInputs(workloadInfo, descriptorName, 3);
3151 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3152
3153 // Input/output tensor info
3154 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3155 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3156 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3157
3158 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3159 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3160 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3161
3162 // Supported types for various tensors in QLSTM
3163 std::vector<DataType> inputOutputSupportedTypes =
3164 {
3165 DataType::QAsymmS8
3166 };
3167
3168 std::vector<DataType> cellStateSupportedTypes =
3169 {
3170 DataType::QSymmS16
3171 };
3172
3173 std::vector<DataType> weightsSupportedTypes =
3174 {
3175 DataType::QSymmS8
3176 };
3177
3178 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3179 {
3180 DataType::QSymmS16
3181 };
3182
3183 std::vector<DataType> biasSupportedTypes =
3184 {
3185 DataType::Signed32
3186 };
3187
3188 // Validate types of input/output tensors
3189 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3190 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3191 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3192
3193 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3194 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3195 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3196
3197 // Validate matching types of input/output tensors
3198 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3199 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3200 "outputStateIn", "outputStateOut");
3201 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3202
3203 // Infer number of batches, number of units, input size and output size from tensor dimensions
3204 const uint32_t numBatches = inputInfo.GetShape()[0];
3205 const uint32_t inputSize = inputInfo.GetShape()[1];
3206 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3207 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3208
3209 // Validate number of dimensions and number of elements for input/output tensors
3210 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3211 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3212 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3213
3214 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3215 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3216 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3217
3218 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3219 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3220 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3221 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3222
3223 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3224 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3225 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3226
3227 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3228 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3229 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3230
3231 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3232 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3233 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3234 " RecurrentToForgetWeights");
3235
3236 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3237 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3238 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3239
3240 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3241 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3242 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3243
3244 // Validate data types for MANDATORY weights tensors (all should match each other)
3245 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3246
3247 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3248 "inputToForgetWeights", "inputToCellWeights");
3249 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3250 "inputToForgetWeights", "inputToOutputWeights");
3251
3252 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3253 "inputToForgetWeights", "recurrentToForgeteights");
3254 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3255 "inputToForgetWeights", "recurrentToCellWeights");
3256 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3257 "inputToForgetWeights", "recurrentToOutputWeights");
3258
3259 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3260 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3261 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3262 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3263
3264 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3265 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3266 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3267
3268 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3269 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3270 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3271
3272 // Validate data types for MANDATORY bias tensors
3273 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3274
3275 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3276 "forgetGateBias", "cellBias");
3277 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3278 "forgetGateBias", "outputGateBias");
3279
3280 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3281 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3282 !m_Parameters.m_CifgEnabled) ||
3283 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3284 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3285
3286 if (!allCifgParamsPresentOrNot)
3287 {
3288 throw InvalidArgumentException(descriptorName +
3289 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3290 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3291 "set appropriately.");
3292 }
3293
3294 if (!m_Parameters.m_CifgEnabled)
3295 {
3296 // Validate number of dimensions and number of elements
3297 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3298 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3299
3300 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3301 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3302 " RecurrentToInputWeights");
3303
3304 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3305 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3306
3307 // Validate data types
3308 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3309 "inputToForgetWeights", "inputToInputWeights");
3310 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3311 "inputToForgetWeights", "recurrentToInputWeights");
3312 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3313 "forgetGateBias", "inputGateBias");
3314 }
3315
3316 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3317 bool allPeepholeWeightsPresentOrNot =
3318 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3319 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3320 || (!m_CellToInputWeights && !m_CellToForgetWeights
3321 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3322
3323 if (!allPeepholeWeightsPresentOrNot)
3324 {
3325 throw InvalidArgumentException(descriptorName +
3326 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3327 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3328 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3329 "appropriately.");
3330 }
3331
3332 if (m_Parameters.m_PeepholeEnabled)
3333 {
3334 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3335 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3336 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3337
3338 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3339 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3340 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3341 "cellToForgetWeight", "cellToOutputWeights");
3342
3343 if (!m_Parameters.m_CifgEnabled)
3344 {
3345 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3346 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3347 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3348 "cellToForgetWeights", "cellToInputWeights");
3349 }
3350 }
3351
3352 // Validate OPTIONAL params: Layer Norm Weights
3353 bool allLayerNormWeightsPresentOrNot =
3354 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3355 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3356 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3357 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3358
3359 if (!allLayerNormWeightsPresentOrNot)
3360 {
3361 throw InvalidArgumentException(descriptorName +
3362 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3363 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3364 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3365 "only be present when Layer Norm is enabled and CIFG is disabled. "
3366 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3367 }
3368
3369 if (m_Parameters.m_LayerNormEnabled)
3370 {
3371 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3372 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3373 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3374
3375 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3376 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3377 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3378 "forgetLayerNormWeights", "cellLayerNormWeights");
3379
3380 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3381 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3382 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3383 "forgetLayerNormWeights", "outputLayerNormWeights");
3384
3385 if (!m_Parameters.m_CifgEnabled)
3386 {
3387 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3388 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3389 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3390 "forgetLayerNormWeights", "inputLayerNormWeights");
3391 }
3392 }
3393
3394 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3395 bool correctProjectionTensorsPresent =
3396 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3397 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3398 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3399
3400 if (!correctProjectionTensorsPresent)
3401 {
3402 throw InvalidArgumentException(descriptorName +
3403 ": If projection is enabled, ProjectionWeights should be present and "
3404 "ProjectionBias is optional. If projection is disabled, neither "
3405 "ProjectionWeights nor ProjectionBias should be present.");
3406 }
3407
3408 if (m_Parameters.m_ProjectionEnabled)
3409 {
3410 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3411 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3412 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3413
3414 if (m_ProjectionBias)
3415 {
3416 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003417 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003418 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3419 }
3420
3421 }
3422 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3423 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3424 throw InvalidArgumentException(descriptorName +
3425 ": If projection is disabled, output quantization info (scale, offset) "
3426 "should match HiddenStateScale and HiddenStateZeroPoint.");
3427 }
3428
3429}
3430
James Conroy9c3cae82019-08-01 16:01:48 +01003431void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3432{
3433 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3434
3435 // Validate number of inputs/outputs
3436 ValidateNumInputs(workloadInfo, descriptorName, 3);
3437 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3438
3439 // Input/output tensor infos
3440 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3441 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3442 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3443
3444 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3445 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3446
3447 std::vector<DataType> inputOutputSupportedTypes =
3448 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003449 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003450 };
3451
3452 std::vector<DataType> cellStateSupportedTypes =
3453 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003454 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003455 };
3456
3457 std::vector<DataType> weightsSupportedTypes =
3458 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003459 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003460 };
3461
3462 std::vector<DataType> biasSupportedTypes =
3463 {
3464 DataType::Signed32
3465 };
3466
3467 // Validate types of input/output tensors
3468 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3469 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3470 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3471
3472 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3473 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3474
3475 // Validate matching types of input/output tensors
3476 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3477 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3478 "outputStateIn", "outputStateOut");
3479 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3480
3481 // Validate matching quantization info for input/output tensors
3482 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3483 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3484 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003485
James Conroy9c3cae82019-08-01 16:01:48 +01003486 // Infer number of batches, input size and output size from tensor dimensions
3487 const uint32_t numBatches = inputInfo.GetShape()[0];
3488 const uint32_t inputSize = inputInfo.GetShape()[1];
3489 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3490
3491 // Validate number of dimensions and number of elements for input/output tensors
3492 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3493 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3494 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3495 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3496 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3497
3498 // Validate number of dimensions and number of elements for weights tensors
3499 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3500 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3501 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3502
3503 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3504 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3505 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3506
3507 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3508 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3509 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3510
3511 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3512 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3513 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3514
3515 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3516 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3517 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3518
3519 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3520 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3521 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3522 " RecurrentToForgetWeights");
3523
3524 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3525 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3526 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3527
3528 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3529 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3530 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3531
3532 // Validate data types for weights tensors (all should match each other)
3533 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3534
3535 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3536 "inputToInputWeights", "inputToForgetWeights");
3537 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3538 "inputToInputWeights", "inputToCellWeights");
3539 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3540 "inputToInputWeights", "inputToOutputWeights");
3541
3542 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3543 "inputToInputWeights", "recurrentToInputWeights");
3544 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3545 "inputToInputWeights", "recurrentToForgeteights");
3546 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3547 "inputToInputWeights", "recurrentToCellWeights");
3548 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3549 "inputToInputWeights", "recurrentToOutputWeights");
3550
3551 // Validate matching quantization info for weight tensors (all should match each other)
3552 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3553 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3554 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3555 descriptorName, "inputToInputWeights", "inputToCellWeights");
3556 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3557 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3558
3559 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3560 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3561 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3562 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3563 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3564 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3565 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3566 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3567
3568 // Validate number of dimensions and number of elements in bias tensors
3569 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3570 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3571 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3572
3573 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3574 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3575 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3576
3577 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3578 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3579 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3580
3581 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3582 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3583 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3584
3585 // Validate data types for bias tensors (all should match each other)
3586 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3587
3588 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3589 "inputGateBias", "forgetGateBias");
3590 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3591 "inputGateBias", "cellBias");
3592 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3593 "inputGateBias", "outputGateBias");
3594
3595 // Validate bias tensor quantization info
Ryan OSheaf183acd2023-07-06 11:41:25 +01003596 ValidateBiasTensorQuantization(inputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3597 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3598 ValidateBiasTensorQuantization(cellBiasInfo, inputToInputWeightsInfo, descriptorName);
3599 ValidateBiasTensorQuantization(outputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
James Conroy9c3cae82019-08-01 16:01:48 +01003600}
3601
Kevin May868eb142019-09-04 17:29:31 +01003602void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3603{
3604 const std::string descriptorName{"AbsQueueDescriptor"};
3605
3606 ValidateNumInputs(workloadInfo, descriptorName, 1);
3607 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3608
3609 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3610 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3611
3612 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3613
3614 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003615 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003616 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003617 DataType::Float16,
3618 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003620 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003621 DataType::QSymmS16,
3622 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003623 };
Kevin May868eb142019-09-04 17:29:31 +01003624
3625 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3626 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3627}
3628
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003629void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3630{
3631 const std::string descriptorName{"SliceQueueDescriptor"};
3632
3633 ValidateNumInputs(workloadInfo, descriptorName, 1);
3634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3635
3636 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3638
3639 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3640
3641 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3642 if (rank > 4)
3643 {
3644 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3645 }
3646
3647 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3648
3649 // Check if m_Begin and m_Size have the expected length
3650 if (m_Parameters.m_Begin.size() != rank)
3651 {
3652 throw InvalidArgumentException(descriptorName +
3653 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3654 }
3655 if (m_Parameters.m_Size.size() != rank)
3656 {
3657 throw InvalidArgumentException(descriptorName +
3658 ": Length of size descriptor must equal rank " + std::to_string(rank));
3659 }
3660
3661 // Check if the shape of the output tensor matches m_Size
3662 const TensorShape& outputShape = outputTensorInfo.GetShape();
3663 for (unsigned int i = 0u; i < rank; ++i)
3664 {
3665 if (m_Parameters.m_Size[i] != outputShape[i])
3666 {
3667 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3668 }
3669 }
3670
3671 // Check if the sum of begin offset and size in a given dimension
3672 // does not exceed the size of corresponding input
3673 const TensorShape& inputShape = inputTensorInfo.GetShape();
3674 for(unsigned int i = 0u; i < rank; ++i)
3675 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003676 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003677 {
3678 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3679 std::to_string(i) + " exceeds input size.");
3680 }
3681 }
3682}
3683
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003684void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3685{
3686 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3687
3688 ValidateNumInputs(workloadInfo, descriptorName, 1);
3689 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3690
3691 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3692 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3693
3694 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3695 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3696
3697 std::vector<DataType> supportedTypes =
3698 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003699 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003700 DataType::Float32,
3701 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003702 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003703 DataType::QAsymmU8,
3704 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003705 };
3706
3707 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3708 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3709
3710 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3711
3712 if (m_Parameters.m_BlockSize == 0)
3713 {
3714 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3715 }
3716
3717 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3718 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3719 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3720 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3721
3722 const TensorShape& outputShape = outputInfo.GetShape();
3723 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3724 {
3725 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3726 "must be divisible by block size.");
3727 }
3728
3729 const TensorShape& inputShape = inputInfo.GetShape();
3730 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3731 {
3732 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3733 "must be divisible by the square of block size." );
3734 }
3735}
3736
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003737void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3738{
3739 const std::string descriptorName{"ComparisonQueueDescriptor"};
3740
3741 ValidateNumInputs(workloadInfo, descriptorName, 2);
3742 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3743
3744 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3745 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3746 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3747
3748 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3749 inputTensorInfo1,
3750 outputTensorInfo,
3751 descriptorName,
3752 "input_0",
3753 "input_1");
3754
3755 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3756 {
3757 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3758 }
3759}
3760
Mike Kelly3ec30772023-03-08 13:47:17 +00003761void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3762{
3763 const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3764
3765 ValidateNumInputs(workloadInfo, descriptorName, 2);
3766 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3767
3768 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3769 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3770 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3771
3772 std::vector<DataType> supportedTypes =
3773 {
3774 DataType::BFloat16,
3775 DataType::Float16,
3776 DataType::Float32,
3777 DataType::QAsymmS8,
3778 DataType::QAsymmU8,
3779 DataType::QSymmS16,
3780 DataType::Signed32
3781 };
3782
3783 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3784 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3785
3786 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3787 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3788}
3789
josh minor4a3c6102020-01-06 16:40:46 -06003790void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3791{
3792 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3793
3794 ValidateNumInputs(workloadInfo, descriptorName, 1);
3795 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3796
3797 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3798 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3799
3800 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3801
3802 std::vector<DataType> supportedTypes =
3803 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003804 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003805 DataType::Float16,
3806 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003807 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003808 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003809 DataType::QSymmS16,
3810 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003811 };
3812
James Conroyaba90cd2020-11-06 16:28:18 +00003813 std::vector<DataType> logicalSupportedTypes =
3814 {
3815 DataType::Boolean
3816 };
3817
3818 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3819 {
3820 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3821 }
3822 else
3823 {
3824 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3825 }
3826
3827
josh minor4a3c6102020-01-06 16:40:46 -06003828 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3829}
3830
Finn Williams2605b232020-06-10 15:53:46 +01003831void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3832{
3833 const std::string descriptorName{"RankQueueDescriptor"};
3834
3835 ValidateNumInputs(workloadInfo, descriptorName, 1);
3836 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3837
3838 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3839 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3840
3841 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3842 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3843
3844 std::vector<DataType> supportedTypes =
3845 {
3846 DataType::BFloat16,
3847 DataType::Float16,
3848 DataType::Float32,
3849 DataType::QAsymmS8,
3850 DataType::QAsymmU8,
3851 DataType::QSymmS8,
3852 DataType::QSymmS16,
3853 DataType::Signed32
3854 };
3855
3856 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3857 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3858}
3859
James Conroyaba90cd2020-11-06 16:28:18 +00003860void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3861{
3862 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3863
3864 ValidateNumInputs(workloadInfo, descriptorName, 2);
3865 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3866
3867 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3868 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3869 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3870
3871 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3872 inputTensorInfo1,
3873 outputTensorInfo,
3874 descriptorName,
3875 "input_0",
3876 "input_1");
3877
3878 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3879 {
3880 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3881 }
3882
3883 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3884 {
3885 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3886 }
3887
3888 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3889 {
3890 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3891 }
3892}
3893
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003894void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3895{
3896 const std::string descriptorName{"ReduceQueueDescriptor"};
3897
3898 ValidateNumInputs(workloadInfo, descriptorName, 1);
3899 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3900
3901 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3902 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3903
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003904 std::vector<DataType> supportedTypes =
3905 {
3906 DataType::BFloat16,
3907 DataType::Float16,
3908 DataType::Float32,
3909 DataType::QAsymmS8,
3910 DataType::QAsymmU8,
3911 DataType::QSymmS16,
3912 DataType::Signed32
3913 };
3914
3915 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3916 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3917}
3918
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003919void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3920{
3921 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3922
3923 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3924
3925 // check dimensions of all inputs and outputs
3926 if (workloadInfo.m_InputTensorInfos.size() != 3)
3927 {
3928 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3929 }
Mike Kelly12994962022-04-21 11:57:09 +01003930 if (workloadInfo.m_OutputTensorInfos.size() != 3)
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003931 {
3932 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3933 }
3934
3935 std::vector<DataType> supportedTypes =
3936 {
Mike Kelly12994962022-04-21 11:57:09 +01003937 DataType::Float32,
3938 DataType::QAsymmS8
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003939 };
3940
3941 // check for supported type of one input and match them with all the other input and output
3942 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3943
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003944 // Making sure clipping parameters have valid values.
3945 // == 0 means no clipping
3946 // > 0 means clipping
3947 if (m_Parameters.m_ClippingThresCell < 0.0f)
3948 {
3949 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3950 }
3951 if (m_Parameters.m_ClippingThresProj < 0.0f)
3952 {
3953 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3954 }
3955
3956 unsigned int batchIndx = 0;
3957 unsigned int inputIndx = 1;
3958 uint32_t timeStep = 1;
3959 unsigned int timeIndx = 1;
3960 inputIndx = 2;
3961 if (m_Parameters.m_TimeMajor)
3962 {
3963 batchIndx = 1;
3964 timeIndx = 0;
3965
3966 }
3967 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3968
3969 // Inferring batch size, number of outputs and number of cells from the inputs.
3970 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3971 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3972 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3973 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3974 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3975 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3976
3977 // input tensor
3978 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3979 descriptorName + " input_0");
3980 // outputStateInTensor
3981 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3982 descriptorName + " input_1");
3983 // outputStateInTensor
3984 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3985 descriptorName + " input_2");
3986
3987 // outputTensor
Mike Kelly12994962022-04-21 11:57:09 +01003988 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003989 descriptorName + " output_0");
3990
3991 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3992 if ( m_InputToInputWeights )
3993 {
3994 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3995 (n_cell * n_input), "InputLayerNormWeights");
3996 }
3997
3998 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3999 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
4000 (n_cell * n_input), "InputToForgetWeights");
4001
4002 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
4003 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
4004 (n_cell * n_input), "InputToCellWeights");
4005
4006 if ( m_RecurrentToInputWeights )
4007 {
4008 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
4009 (n_cell * n_output), "RecurrentToInputWeights");
4010 }
4011
4012 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
4013 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
4014 (n_cell * n_output), "RecurrentToForgetWeights");
4015
4016 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
4017 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
4018 (n_cell * n_output), "RecurrentToCellWeights");
4019
4020 // Make sure the input-gate's parameters are either both present (regular
4021 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
4022 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
4023 !m_Parameters.m_CifgEnabled) ||
4024 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
4025 m_Parameters.m_CifgEnabled));
4026 if (!cifg_weights_all_or_none)
4027 {
4028 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
4029 "RecurrentToInputWeights must either both be present (regular LSTM) "
4030 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
4031 "accordingly.");
4032 }
4033
4034 if ( m_CellToInputWeights )
4035 {
4036 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
4037 n_cell, "CellToInputWeights");
4038 }
4039 if ( m_CellToForgetWeights )
4040 {
4041 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
4042 n_cell, "CellToForgetWeights");
4043 }
4044 if ( m_CellToOutputWeights )
4045 {
4046 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
4047 n_cell, "CellToOutputWeights");
4048 }
4049
4050 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
4051 bool peephole_weights_all_or_none =
4052 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
4053 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
4054 || ( !m_CellToInputWeights && !m_CellToForgetWeights
4055 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
4056 if (!peephole_weights_all_or_none)
4057 {
4058 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
4059 }
4060
4061 // Make sure the input gate bias is present only when not a CIFG-LSTM.
4062 if (m_Parameters.m_CifgEnabled)
4063 {
4064 if (m_InputGateBias)
4065 {
4066 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
4067 }
4068 }
4069 else
4070 {
4071 if (!m_InputGateBias)
4072 {
4073 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
4074 "must be present.");
4075 }
4076 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
4077 n_cell, "InputGateBias");
4078 }
4079
4080 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
4081 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
4082
4083 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4084 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4085
4086 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4087 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4088
4089 if (m_ProjectionWeights)
4090 {
4091 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4092 (n_cell * n_output), "ProjectionWeights");
4093 }
4094 if (m_ProjectionBias)
4095 {
4096 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4097 }
4098
4099 // Making sure the projection tensors are consistent:
4100 // 1) If projection weight is not present, then projection bias should not be
4101 // present.
4102 // 2) If projection weight is present, then projection bias is optional.
4103 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4104 !m_Parameters.m_ProjectionEnabled)
4105 || (m_ProjectionWeights && !m_ProjectionBias &&
4106 m_Parameters.m_ProjectionEnabled)
4107 || (m_ProjectionWeights && m_ProjectionBias &&
4108 m_Parameters.m_ProjectionEnabled));
4109 if (!projecton_tensors_consistent)
4110 {
4111 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4112 }
4113
4114 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4115 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4116 // either all have values or none of them have values. Layer normalization is used when the values of all the
4117 // layer normalization weights are present
4118 if (m_InputLayerNormWeights)
4119 {
4120 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4121 }
4122 if (m_ForgetLayerNormWeights)
4123 {
4124 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4125 }
4126 if (m_CellLayerNormWeights)
4127 {
4128 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4129 }
4130 if (m_OutputLayerNormWeights)
4131 {
4132 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4133 }
4134
4135 if (m_Parameters.m_LayerNormEnabled)
4136 {
4137 if (!m_Parameters.m_CifgEnabled)
4138 {
4139 if (!m_InputLayerNormWeights)
4140 {
4141 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4142 "disabled but InputLayerNormWeights are not present");
4143 }
4144 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4145 1, n_cell, "InputLayerNormWeights");
4146 }
4147 else if (m_InputLayerNormWeights)
4148 {
4149 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4150 "enabled");
4151 }
4152
4153 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4154 "ForgetLayerNormWeights");
4155 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4156
4157 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4158 "OutputLayerNormWeights");
4159 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4160
4161 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4162 "CellLayerNormWeights");
4163 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4164 }
4165 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4166 {
4167 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4168 "normalisation weights are present.");
4169 }
4170}
4171
Samuel Yap6b478092022-07-06 15:36:03 +01004172void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4173{
4174 const std::string descriptorName{"BatchMatMulDescriptor"};
4175
4176 ValidateNumInputs(workloadInfo, descriptorName, 2);
4177 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4178
4179 // Inputs must be: both 2D+
4180 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4181 // axes N and I must be the same size
4182
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004183 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4184 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4185 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4186 // Output info has already been inferred
Samuel Yap6b478092022-07-06 15:36:03 +01004187
4188 std::vector<DataType> supportedTypes =
4189 {
4190 DataType::BFloat16,
4191 DataType::Float16,
4192 DataType::Float32,
4193 DataType::QAsymmS8,
4194 DataType::QAsymmU8,
4195 DataType::QSymmS16
4196 };
4197
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004198 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4199 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4200 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
Samuel Yap6b478092022-07-06 15:36:03 +01004201
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004202 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4203 (inputYInfoBeforeParams.GetNumDimensions() < 2))
Samuel Yap6b478092022-07-06 15:36:03 +01004204 {
4205 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4206 }
4207
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004208 TensorInfo inputXInfoAfterParams;
4209 TensorInfo inputYInfoAfterParams;
4210
4211 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4212 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
Samuel Yap6b478092022-07-06 15:36:03 +01004213 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004214 throw InvalidArgumentException(descriptorName +
4215 ": Invalid descriptor parameters - Transpose and Adjoint "
4216 "cannot both be true for a given input tensor.");
4217 }
4218 if(m_Parameters.m_TransposeX)
4219 {
4220 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4221 BatchMatMulDescriptor::GetPermuteVec(
4222 m_Parameters.m_DataLayoutX,
4223 inputXInfoBeforeParams.GetShape()));
4224 }
4225 else if(m_Parameters.m_AdjointX)
4226 {
4227 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4228 inputXInfoBeforeParams.GetShape());
4229 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4230 inputXInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004231 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004232 throw InvalidArgumentException(descriptorName +
4233 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004234 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004235 // Shape remains the same as it's square
4236 inputXInfoAfterParams = inputXInfoBeforeParams;
4237 }
4238 else
4239 {
4240 inputXInfoAfterParams = inputXInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004241 }
4242
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004243 if(m_Parameters.m_TransposeY)
Samuel Yap6b478092022-07-06 15:36:03 +01004244 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004245 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4246 BatchMatMulDescriptor::GetPermuteVec(
4247 m_Parameters.m_DataLayoutY,
4248 inputYInfoBeforeParams.GetShape()));
4249 }
4250 else if(m_Parameters.m_AdjointY)
4251 {
4252 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4253 inputYInfoBeforeParams.GetShape());
4254 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4255 inputYInfoBeforeParams.GetShape()[axesToMul.second])
Samuel Yap6b478092022-07-06 15:36:03 +01004256 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004257 throw InvalidArgumentException(descriptorName +
4258 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
Samuel Yap6b478092022-07-06 15:36:03 +01004259 }
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004260 // Shape remains the same as it's square
4261 inputYInfoAfterParams = inputYInfoBeforeParams;
4262 }
4263 else
4264 {
4265 inputYInfoAfterParams = inputYInfoBeforeParams;
Samuel Yap6b478092022-07-06 15:36:03 +01004266 }
4267
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004268 switch(m_Parameters.m_DataLayoutX)
4269 {
4270 case DataLayout::NCDHW:
4271 case DataLayout::NDHWC:
4272 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4273 {
4274 throw InvalidArgumentException(descriptorName +
4275 ": Input tensor X does not have the correct "
4276 "number of dimensions for the Data Layout that it has been assigned.");
4277 }
4278 break;
4279 case DataLayout::NCHW:
4280 case DataLayout::NHWC:
4281 default:
4282 break;
4283 }
Samuel Yap6b478092022-07-06 15:36:03 +01004284
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004285 switch(m_Parameters.m_DataLayoutY)
4286 {
4287 case DataLayout::NCDHW:
4288 case DataLayout::NDHWC:
4289 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4290 {
4291 throw InvalidArgumentException(descriptorName +
4292 ": Input tensor Y does not have the correct "
4293 "number of dimensions for the Data Layout that it has been assigned.");
4294 }
4295 break;
4296 case DataLayout::NCHW:
4297 case DataLayout::NHWC:
4298 default:
4299 break;
4300 }
4301
4302 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4303 inputXInfoAfterParams.GetShape());
4304 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4305 inputXInfoBeforeParams.GetShape());
4306
4307 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4308 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
Samuel Yap6b478092022-07-06 15:36:03 +01004309 {
4310 throw InvalidArgumentException(descriptorName +
4311 ": The final axis of input tensor X must be the same size as "
4312 "the second last axis of input tensor Y.");
4313 }
4314
Samuel Yap6b478092022-07-06 15:36:03 +01004315 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4316 // e.g. NHWC isnt compatible with NCHW as of now
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004317 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4318 DataLayout yLayout = m_Parameters.m_DataLayoutY;
Samuel Yap6b478092022-07-06 15:36:03 +01004319
4320 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4321 {
4322 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4323 {
4324 throw InvalidArgumentException(descriptorName +
4325 ": Invalid input tensor data layout combination.");
4326 }
4327 }
4328 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4329 {
4330 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4331 {
4332 throw InvalidArgumentException(descriptorName +
4333 ": Invalid input tensor data layout combination.");
4334 }
4335 }
4336 }
4337
4338 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004339 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4340 inputYInfoAfterParams.GetNumDimensions());
Samuel Yap6b478092022-07-06 15:36:03 +01004341 if(outputTensorDimSize-2 > 0)
4342 {
4343 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4344 DataType::Float32);
4345 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4346 DataType::Float32);
4347 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4348 DataType::Float32);
4349
4350 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4351 {
4352 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4353
4354 for(unsigned int i = 0; i < sizeDiff; i++)
4355 {
4356 axisIndices.insert(axisIndices.begin(), 1);
4357 }
4358
4359 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4360 {
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004361 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
Samuel Yap6b478092022-07-06 15:36:03 +01004362 }
4363 };
4364
Samuel Yapdc8ed9d2022-08-08 14:07:42 +01004365 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4366 inputXInfoAfterParams.GetShape());
4367 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4368 inputYInfoAfterParams.GetShape());
4369
4370 doAxisExtension(axesXNotMul, tiXNotMul);
4371 doAxisExtension(axesYNotMul, tiYNotMul);
Samuel Yap6b478092022-07-06 15:36:03 +01004372
4373 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4374 {
4375 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4376 tiYNotMul.GetShape()[i]);
4377 }
4378
4379 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4380 tiYNotMul,
4381 tiOutNotMul,
4382 descriptorName,
4383 "input_X",
4384 "input_Y");
4385 }
Samuel Yap6b478092022-07-06 15:36:03 +01004386}
4387
Teresa Charlin79a06a52023-07-13 17:16:45 +01004388void TileQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4389{
4390 const std::string& descriptorName{"TileQueueDescriptor"};
4391
4392 ValidateNumInputs(workloadInfo, descriptorName, 1);
4393 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4394
4395 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
4396 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4397
4398 std::vector<DataType> supportedTypes =
4399 {
4400 DataType::Float32,
4401 DataType::Float16,
4402 DataType::QAsymmS8,
4403 DataType::QAsymmU8,
4404 DataType::QSymmS8,
4405 DataType::QSymmS16,
4406 DataType::Signed32
4407 };
4408
4409 // Multiples length must be the same as the number of dimensions in input.
4410 if (m_Parameters.m_Multiples.size() != inputTensorInfo.GetNumDimensions())
4411 {
4412 throw InvalidArgumentException(descriptorName +
4413 ": Multiples length is not same as the number of dimensions in Input.");
4414 }
4415
4416 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
4417 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4418}
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01004419
Idriss Chaouch98e383e2023-08-28 14:28:31 +01004420void BroadcastToQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4421{
4422 const std::string& descriptorName{"BroadcastToQueueDescriptor"};
4423
4424 ValidateNumInputs(workloadInfo, descriptorName, 1);
4425 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4426
4427 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
4428 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4429
4430 std::vector<DataType> supportedTypes =
4431 {
4432 DataType::Float32,
4433 DataType::Float16,
4434 DataType::QAsymmS8,
4435 DataType::QAsymmU8,
4436 DataType::QSymmS8,
4437 DataType::QSymmS16,
4438 DataType::Signed32,
4439 DataType::Signed64
4440 };
4441
4442 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
4443 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4444}
4445
mathad01df9a3222021-04-28 11:42:57 +01004446} // namespace armnn