blob: 385affa5fa3b0f510703583694f95722583153bb [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Colm Donelan0c479742021-12-10 12:43:54 +00006#include <armnn/backends/TensorHandle.hpp>
7#include <armnn/backends/WorkloadData.hpp>
8#include <armnn/backends/WorkloadInfo.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00009#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010012#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000013
telsoa014fcda012018-03-09 14:13:49 +000014#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000016#include <string>
17#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000018
James Ward47fce872020-09-10 11:57:28 +010019#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000020
Matteo Martincigh21350152018-11-28 16:22:22 +000021using namespace armnnUtils;
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
26//---------------------------------------------------------------
27DataType GetBiasDataType(DataType inputDataType)
28{
29 switch (inputDataType)
30 {
telsoa01c577f2c2018-08-31 09:22:23 +010031 case DataType::Float16:
32 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000033 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000034 case DataType::Float32:
35 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000036 case DataType::QAsymmS8:
37 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000038 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000039 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000040 case DataType::QSymmS8:
41 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000042 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010043 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000044 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010045 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000046 return DataType::Float32;
47 }
48}
49
50namespace
51{
52
53//---------------------------------------------------------------
54//android ndk does not support std::to_string function.
template <typename T>
std::string to_string(T value)
{
    // Stream-based conversion: the result is whatever operator<< writes for T
    // (default stream formatting, e.g. 6 significant digits for floating point).
    std::ostringstream converter;
    converter << value;
    return converter.str();
}
62
63//---------------------------------------------------------------
64void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
65{
66 if (!ptr)
67 {
68 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
69 paramName + " parameter must be set.");
70 }
71}
72
73//---------------------------------------------------------------
74void ValidateTensorShapesMatch(const TensorInfo& first,
75 const TensorInfo& second,
76 std::string const& descName,
77 std::string const& firstName,
78 std::string const& secondName)
79{
80 if (first.GetShape() != second.GetShape())
81 {
82 throw InvalidArgumentException(descName + ": "
83 + firstName + " & " + secondName + " must have identical shapes");
84 }
85}
86
87//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010088void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089{
Sadik Armaganeff363d2019-04-05 15:25:46 +010090 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000091 {
92 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010093 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000094 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
95 }
96}
97
98//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010099void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100101 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000102 {
103 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100104 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000105 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
106 }
107}
108
109//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100112 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000113 std::string const& tensorName)
114{
115 if (tensor.GetNumDimensions() != numDimensions)
116 {
117 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
118 to_string(tensor.GetNumDimensions()) + " dimensions for " +
119 tensorName + " tensor.");
120 }
121}
122
123//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100124void ValidateTensorNumElements(const TensorInfo& tensor,
125 std::string const& descName,
126 unsigned int numElements,
127 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100128{
129 if (tensor.GetNumElements() != numElements)
130 {
131 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100132 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100133 tensorName + " tensor.");
134 }
135}
136
137//---------------------------------------------------------------
138void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100139 unsigned int numDimension,
140 unsigned int numElements,
141 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100142{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100143 const std::string functionName{"ValidateTensorNumDimNumElem"};
144 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
145 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100146}
147
148//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000149void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
150 const std::string& descName, std::string const& tensorName)
151{
152 if (tensor.GetDataType() != dataType)
153 {
154 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
155 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
156 }
157}
158
Derek Lambertid466a542020-01-22 15:37:29 +0000159void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
160{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100161 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
Derek Lambertid466a542020-01-22 15:37:29 +0000167}
168
telsoa014fcda012018-03-09 14:13:49 +0000169//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100170void ValidateTensorQuantizationSpace(const TensorInfo& first,
171 const TensorInfo& second,
172 const std::string& descName,
173 std::string const& firstName,
174 std::string const& secondName)
175{
176 if (!first.IsQuantized() ||
177 !second.IsQuantized())
178 {
179 // Not a quantized type, ignore the validation
180 return;
181 }
182
183 DataType firstDataType = first.GetDataType();
184 DataType secondDataType = second.GetDataType();
185
186 if (firstDataType != secondDataType)
187 {
188 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
189 " must be of the same quantized type, " +
190 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
191 secondName + " is " + GetDataTypeName(secondDataType));
192 }
193
194 if (!first.IsTypeSpaceMatch(second))
195 {
196 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
197 " must have the same quantization space, " +
198 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
199 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
200 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
201 " and scale " + to_string(second.GetQuantizationScale()));
202 }
203}
204
205//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100206void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
207 const TensorInfo& inputTensorInfo,
208 const TensorInfo& weightsTensorInfo,
209 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000210{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000211 // Helper lambda function to validate a single bias quantization scale value
212 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
213 {
mathad01df9a3222021-04-28 11:42:57 +0100214 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000215 if (std::abs(biasScale - expectedScale) > tolerance)
216 {
217 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100218 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
219 " for bias quantization scale (product of input and weight scales), but got " <<
220 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000221 }
222 };
223
telsoa014fcda012018-03-09 14:13:49 +0000224 if (biasTensor.GetQuantizationOffset() != 0)
225 {
226 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
227 to_string(biasTensor.GetQuantizationOffset()));
228 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000229
James Conroy8502ade2020-11-12 19:26:29 +0000230 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000231 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232 // Validate per-axis quantization scales
233 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
234 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
235
236 if (weightScales.size() != biasScales.size())
237 {
238 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000239 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
240 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
241 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000242 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
243 }
244
245 for (size_t i = 0ul; i < biasScales.size(); ++i)
246 {
247 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
248 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
249 }
250 }
251 else
252 {
253 // Validate per-tensor quantization scale
254 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
255 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000256 }
257}
258
259//---------------------------------------------------------------
260void ValidateTensors(const std::vector<ITensorHandle*>& vec,
261 unsigned int numExpected,
262 const std::string& descName,
263 const std::string& varName)
264{
265 if (vec.empty() && numExpected > 0)
266 {
267 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
268 }
269
270 for (unsigned int i = 0; i < numExpected; ++i)
271 {
272 if (!vec[i])
273 {
274 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
275 }
276 }
277}
278
279//---------------------------------------------------------------
280void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
281 const TensorInfo& second,
282 const TensorInfo& output,
283 std::string const& descName,
284 std::string const& firstName,
285 std::string const& secondName)
286{
287 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
288 // broadcasted.
289 if (first.GetNumDimensions() != second.GetNumDimensions())
290 {
291 throw InvalidArgumentException(descName + ": Tensors "
292 + firstName + " & " + secondName
293 + " must have the same number of dimensions in order to be broadcasted");
294 }
295 uint32_t numDims = first.GetNumDimensions();
296 std::vector<uint32_t> outputDims(numDims, 0u);
297 for (uint32_t i = 0; i < numDims; i++)
298 {
299 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
300 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
301 if (dimsNotEqual && dimsNotOne)
302 {
303 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
304 }
305 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
306 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100307 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000308 if (broadcastShape != output.GetShape())
309 {
310 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
311 + firstName + " & " + secondName
312 + " does not match the output shape");
313 }
314}
315
316//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100317void ValidateDataTypes(const TensorInfo& info,
318 const std::vector<armnn::DataType>& supportedTypes,
319 std::string const& descName)
320{
321 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
322 if (iterator == supportedTypes.end())
323 {
324 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
325 }
326}
327
James Conroy4d1ff582019-06-10 17:06:39 +0100328//---------------------------------------------------------------
329void ValidateTensorDataTypesMatch(const TensorInfo& first,
330 const TensorInfo& second,
331 std::string const& descName,
332 std::string const& firstName,
333 std::string const& secondName)
334{
335 if (first.GetDataType() != second.GetDataType())
336 {
337 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
338 " must have identical data types.");
339 }
340}
341
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100342//---------------------------------------------------------------
343void ValidateTensorNumElementsMatch(const TensorInfo& first,
344 const TensorInfo& second,
345 std::string const& descName,
346 std::string const& firstName,
347 std::string const& secondName)
348{
349 if (first.GetNumElements() != second.GetNumElements())
350 {
351 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
352 " must have the same number of elements.");
353 }
354}
355
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000356void ValidateWeightDataType(const TensorInfo& inputInfo,
357 const TensorInfo& weightInfo,
358 const std::string& descName)
359{
360 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000361 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000362 {
363 const std::vector<DataType> validTypes =
364 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000365 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100366 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100367 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000368 };
369
370 ValidateDataTypes(weightInfo, validTypes, descName);
371 }
372 else
373 {
374 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
375 }
376}
377
378void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
379 const std::string& descName,
380 const std::string& tensorName)
381{
382 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
383 if (!quantizationDim.has_value())
384 {
James Ward47fce872020-09-10 11:57:28 +0100385 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
386 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000387 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000388}
389
390void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
391 const std::string& descName,
392 const std::string& tensorName)
393{
394 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
395 if (quantizationOffset != 0)
396 {
James Ward47fce872020-09-10 11:57:28 +0100397 throw InvalidArgumentException(fmt::format(
398 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
399 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000400 }
401}
402
403void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
404 const TensorInfo& outputInfo,
405 const TensorInfo& weightInfo,
406 const Optional<TensorInfo>& optionalBiasInfo,
407 const std::string& descName)
408{
409 if (weightInfo.HasPerAxisQuantization())
410 {
411 const DataType inputDataType = inputInfo.GetDataType();
412 const DataType outputDataType = outputInfo.GetDataType();
413
Keith Davis0c2eeac2020-02-11 16:51:50 +0000414 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000415
416 if (!canHavePerAxisQuantization)
417 {
James Ward47fce872020-09-10 11:57:28 +0100418 throw InvalidArgumentException(fmt::format(
419 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
420 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000421 }
422
Derek Lambertid466a542020-01-22 15:37:29 +0000423
424 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000425 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
426 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
427
428 if (optionalBiasInfo.has_value())
429 {
430 const TensorInfo& biasInfo = optionalBiasInfo.value();
431 if (!biasInfo.HasPerAxisQuantization())
432 {
James Ward47fce872020-09-10 11:57:28 +0100433 throw InvalidArgumentException(fmt::format(
434 "{}: Per-axis quantization parameters not set on bias tensor, "
435 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000436 }
437
438 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
439 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
440 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
441 }
442 }
443}
444
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100445} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000446
447void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
448 unsigned int numExpectedIn, unsigned int numExpectedOut) const
449{
450 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
451 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
452}
453
454//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100455void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
456{
457 const std::string descriptorName{"MapQueueDescriptor"};
458
459 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100460 ValidateNumOutputs(workloadInfo, descriptorName, 0);
461
462 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
463 {
464 if (!m_Inputs[i])
465 {
466 throw InvalidArgumentException(
467 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
468 }
469 }
470}
471
472//---------------------------------------------------------------
473void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
474{
475 const std::string descriptorName{"UnmapQueueDescriptor"};
476
477 ValidateNumInputs(workloadInfo, descriptorName, 1);
478 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100479
480 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
481 {
482 if (!m_Inputs[i])
483 {
484 throw InvalidArgumentException(
485 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
486 }
487 }
488}
489
490//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000491void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
492{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100493 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100495 ValidateNumInputs(workloadInfo, descriptorName, 1);
496 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
499 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
500
501 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
502 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000503
504 if (m_Inputs.size() != m_Outputs.size())
505 {
James Ward47fce872020-09-10 11:57:28 +0100506 throw InvalidArgumentException(fmt::format(
507 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
508 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000509 }
510
511 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
512 {
513 if (!m_Inputs[i])
514 {
James Ward47fce872020-09-10 11:57:28 +0100515 throw InvalidArgumentException(fmt::format(
516 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000517 }
518
519 if (!m_Outputs[i])
520 {
James Ward47fce872020-09-10 11:57:28 +0100521 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000522 }
523 }
524}
525
Derek Lambertif674aa02019-08-01 15:56:25 +0100526//---------------------------------------------------------------
527void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
528{
529 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
530 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
531
532 if (workloadInfo.m_InputTensorInfos.size() != 1)
533 {
James Ward47fce872020-09-10 11:57:28 +0100534 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
535 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100536
537 }
538
539 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
540 {
James Ward47fce872020-09-10 11:57:28 +0100541 throw InvalidArgumentException(fmt::format(
542 "Number of input infos ({0}) does not match the number of output infos ({1})",
543 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100544 }
545
546 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
547 {
548 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
549 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
550 {
James Ward47fce872020-09-10 11:57:28 +0100551 throw InvalidArgumentException(fmt::format(
552 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100553 }
554 }
555
556 if (m_Inputs.size() != 1)
557 {
James Ward47fce872020-09-10 11:57:28 +0100558 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100559 }
560
561 if (m_Inputs.size() != m_Outputs.size())
562 {
James Ward47fce872020-09-10 11:57:28 +0100563 throw InvalidArgumentException(fmt::format(
564 "Number of inputs ({0}) does not match the number of outputs ({1})",
565 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100566 }
567
568 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
569 {
570 if (!m_Inputs[i])
571 {
James Ward47fce872020-09-10 11:57:28 +0100572 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100573 }
574
575 if (!m_Outputs[i])
576 {
James Ward47fce872020-09-10 11:57:28 +0100577 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100578 }
579 }
580}
581
582//---------------------------------------------------------------
583void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
584{
585 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
586 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
587
Derek Lambertif674aa02019-08-01 15:56:25 +0100588 if (m_Inputs.size() != 1)
589 {
James Ward47fce872020-09-10 11:57:28 +0100590 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100591 }
592
593 if (m_Outputs.size() != 0)
594 {
James Ward47fce872020-09-10 11:57:28 +0100595 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100596 }
597
598 if (!m_Inputs[0])
599 {
James Ward47fce872020-09-10 11:57:28 +0100600 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100601 }
602}
603
604//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000605void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
606{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100607 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100609 ValidateNumInputs(workloadInfo, descriptorName, 1);
610 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100611
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100612 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
613 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100614
615 std::vector<DataType> supportedTypes =
616 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000617 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100618 DataType::Float16,
619 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000620 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000621 DataType::QAsymmU8,
622 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100623 };
624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100625 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
626 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
627 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000628}
629
Nikhil Rajee391d52019-09-05 17:50:44 +0100630void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
631{
632 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
633
634 ValidateNumInputs(workloadInfo, descriptorName, 1);
635 ValidateNumOutputs(workloadInfo, descriptorName, 1);
636
637 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
638 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
639
Inki Daed4619e22020-09-10 15:33:54 +0900640 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
641 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100642 {
Inki Daed4619e22020-09-10 15:33:54 +0900643 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100644 }
645
James Conroyd47a0642019-09-17 14:22:06 +0100646 std::vector<DataType> supportedInputTypes =
647 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000648 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100649 DataType::Float16,
650 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100651 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000652 DataType::QAsymmU8,
653 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900654 DataType::Signed32,
655 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100656 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100657
James Conroyd47a0642019-09-17 14:22:06 +0100658 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100659
660 auto inputShape = inputTensorInfo.GetShape();
661 auto outputShape = outputTensorInfo.GetShape();
662
663 auto inputNumDimensions = inputShape.GetNumDimensions();
664 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
665
666 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
667
668 // 1D input shape results in scalar output shape
669 if (inputShape.GetNumDimensions() == 1)
670 {
671 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
672 {
673 throw InvalidArgumentException(descriptorName + outputShapeError);
674 }
675 }
676 else
677 {
678 for (unsigned int i = 0; i < unsignedAxis; ++i)
679 {
680 if (outputShape[i] != inputShape[i])
681 {
682 throw InvalidArgumentException(descriptorName + outputShapeError);
683 }
684 }
685
686 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
687 {
688 if (outputShape[i - 1] != inputShape[i])
689 {
690 throw InvalidArgumentException(descriptorName + outputShapeError);
691 }
692 }
693 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100694}
695
mathad01b392e982021-04-07 12:07:30 +0100696void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
697{
698 const std::string descriptorName{"CastQueueDescriptor"};
699
700 ValidateNumInputs(workloadInfo, descriptorName, 1);
701 ValidateNumOutputs(workloadInfo, descriptorName, 1);
702
703 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
704 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
705
706 std::vector<DataType> supportedTypes =
707 {
708 DataType::BFloat16,
709 DataType::Float16,
710 DataType::Float32,
711 DataType::QAsymmS8,
712 DataType::QAsymmU8,
713 DataType::QSymmS8,
714 DataType::QSymmS16,
715 DataType::Signed32,
716 DataType::Signed64
717 };
718
719 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
720 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
721}
722
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100723void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
724{
725 const std::string descriptorName{"SoftmaxQueueDescriptor"};
726
727 ValidateNumInputs(workloadInfo, descriptorName, 1);
728 ValidateNumOutputs(workloadInfo, descriptorName, 1);
729
730 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
731 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
732
733 std::vector<DataType> supportedTypes =
734 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000735 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100736 DataType::Float16,
737 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000738 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000739 DataType::QAsymmU8,
740 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100741 };
742
743 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
744 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
745 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
746}
747
telsoa014fcda012018-03-09 14:13:49 +0000748void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
749{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100750 const std::string descriptorName{"SplitterQueueDescriptor"};
751
752 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000753
Ruomei Yan25339c32019-05-28 16:48:20 +0100754 // Check the supported data types
755 std::vector<DataType> supportedTypes =
756 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000757 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100758 DataType::Float32,
759 DataType::Float16,
760 DataType::Boolean,
761 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100762 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000763 DataType::QAsymmU8,
764 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100765 };
766
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100767 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
768 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100769 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100770 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
771 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
772
773 const std::string outputName = "output_" + std::to_string(i);
774 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100775 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100776
telsoa014fcda012018-03-09 14:13:49 +0000777 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
778 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100779 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000780 }
781
782 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
783 {
784 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000786 "has to match number of workloadInfo.m_OutputTensorInfos. "
787 "Number of windows: " +
788 to_string(m_ViewOrigins.size()) +
789 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
790 }
791
telsoa01c577f2c2018-08-31 09:22:23 +0100792 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000793 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
794 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
795 {
telsoa01c577f2c2018-08-31 09:22:23 +0100796 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000797 ViewOrigin const& e = m_ViewOrigins[w];
798 if (e.m_Origin.size() != inputDims)
799 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100800 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000801 "have the same dimensionality as the input tensor. "
802 "Window origin (index: " +
803 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
804 " dimensions, the input "
805 "tensor has " +
806 to_string(inputDims) + " dimensions.");
807 }
808 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
809 {
810 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
811 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000814 "be smaller or equal than the size of the input in that coord.");
815 }
816 }
817 }
818}
819
Jim Flynne242f2d2019-05-22 14:24:13 +0100820void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000821{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100822 const std::string descriptorName{"ConcatQueueDescriptor"};
823
824 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000825
826 if (m_Inputs.size() <= 0)
827 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100828 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000829 }
830 if (m_Outputs.size() <= 0)
831 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100832 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000833 }
834
835 if (workloadInfo.m_InputTensorInfos.size() <= 0)
836 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100837 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000838 }
839 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
840 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000842 }
843
Nikhil Raj8599a412018-11-19 14:51:07 +0000844 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
845 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100846 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000847 }
848
849 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
850 {
851 return;
852 }
853
telsoa014fcda012018-03-09 14:13:49 +0000854 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
855 {
856 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100857 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000858 "has to match number of workloadInfo.m_InputTensorInfos. "
859 "Number of windows: " +
860 to_string(m_ViewOrigins.size()) +
861 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
862 }
863
telsoa01c577f2c2018-08-31 09:22:23 +0100864 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000865 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
866 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
867 {
telsoa01c577f2c2018-08-31 09:22:23 +0100868 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000869 ViewOrigin const& e = m_ViewOrigins[w];
870 if (e.m_Origin.size() != outputDims)
871 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000873 "have the same dimensionality as the output tensor. "
874 "Window origin (index: " +
875 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
876 " dimensions, the output "
877 "tensor has " +
878 to_string(outputDims) + " dimensions.");
879 }
telsoa01c577f2c2018-08-31 09:22:23 +0100880 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000881 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
882 {
883 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
884 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
885 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100886 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000887 "be smaller or equal than the size of the output in that coord.");
888 }
889 }
890 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100891
892 // Check the supported data types
893 std::vector<DataType> supportedTypes =
894 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000895 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100896 DataType::Float32,
897 DataType::Float16,
898 DataType::Boolean,
899 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100900 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000901 DataType::QAsymmU8,
902 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100903 };
904
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100905 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
906 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100907 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100908 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
909 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
910
911 const std::string inputName = "input_" + std::to_string(i);
912 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100913 }
telsoa014fcda012018-03-09 14:13:49 +0000914}
915
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100916void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
917{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 const std::string descriptorName{"StackQueueDescriptor"};
919
920 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100921
922 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
923 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100924 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100925 }
926
927 // All inputs must have the same shape, which is defined in parameters
928 const TensorShape& inputShape = m_Parameters.m_InputShape;
929 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
930 {
931 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
932 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100933 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100934 }
935 }
936
Matthew Jacksondba634f2019-08-15 15:14:18 +0100937 if (inputShape.GetNumDimensions() > 4)
938 {
939 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
940 }
941
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100942 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
943 // since the output tensor has an additional dimension.
944 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
945 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100946 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100947 "than the number of input dimensions.");
948 }
949
950 // Output shape must be as inferred from the input shape
951 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
952 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
953 {
954 if (outputShape[i] != inputShape[i])
955 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100956 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100957 "match shape inferred from input tensor.");
958 }
959 }
960
961 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
962 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100963 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100964 "match shape inferred from input tensor.");
965 }
966
967 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
968 {
969 if (outputShape[i] != inputShape[i-1])
970 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100971 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100972 "match shape inferred from input tensor.");
973 }
974 }
975
Matthew Jacksondba634f2019-08-15 15:14:18 +0100976 if (outputShape.GetNumDimensions() > 5)
977 {
978 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
979 }
980
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100981 // Check the supported data types
982 std::vector<DataType> supportedTypes =
983 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000984 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100985 DataType::Float32,
986 DataType::Float16,
987 DataType::Boolean,
988 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100989 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000990 DataType::QAsymmU8,
991 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100992 };
993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100994 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100996 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100997 {
998 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
999 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001000 descriptorName,
1001 "input_0",
1002 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001003 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001004
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001005 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1006 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001007 descriptorName,
1008 "input_0",
1009 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001010}
1011
Ryan OSheaec6c6802020-06-05 17:17:06 +01001012void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1013{
1014 const std::string descriptorName{"FillQueueDescriptor"};
1015
1016 ValidateNumInputs(workloadInfo, descriptorName, 1);
1017 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1018
1019 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1020 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1021
1022 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1023
1024 std::vector<DataType> supportedTypes =
1025 {
1026 DataType::BFloat16,
1027 DataType::Float32,
1028 DataType::Float16,
1029 DataType::Signed32
1030 };
1031
1032 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1033}
1034
telsoa014fcda012018-03-09 14:13:49 +00001035void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1036{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001037 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001038
Matthew Sloyan81beae32021-07-13 19:46:11 +01001039 uint32_t numInputs = 2;
1040 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001041 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001042 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001043 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001044
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001045 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001046 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1047
1048 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1049 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1050
1051 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1052
1053 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001054 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001055 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001056 }
1057
Matthew Sloyan81beae32021-07-13 19:46:11 +01001058 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001059 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001060
1061 if (m_Parameters.m_BiasEnabled)
1062 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001063 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001064 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001065 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1067 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001068 }
1069
Francis Murtagh46c09d02019-05-28 08:15:28 +01001070 // Check the supported data types
1071 std::vector<DataType> supportedTypes =
1072 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001073 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001074 DataType::Float32,
1075 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001076 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001077 DataType::QAsymmU8,
1078 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001079 };
1080
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001081 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001082
1083 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1084 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1085 {
1086 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1087 {
1088 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1089 "for BFloat16 input.");
1090 }
1091 }
1092 else
1093 {
1094 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1095 }
telsoa014fcda012018-03-09 14:13:49 +00001096}
1097
telsoa014fcda012018-03-09 14:13:49 +00001098void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1099{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001100 const std::string descriptorName{"NormalizationQueueDescriptor"};
1101
1102 ValidateNumInputs(workloadInfo, descriptorName, 1);
1103 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1104
1105 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1106 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001107
1108 // Check the supported data types
1109 std::vector<DataType> supportedTypes =
1110 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001111 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001112 DataType::Float16,
1113 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001114 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001115 DataType::QAsymmU8,
1116 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001117 };
1118
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001119 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001120
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001121 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001122
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001124}
1125
1126void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1127{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001129
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001130 ValidateNumInputs(workloadInfo, descriptorName, 2);
1131 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1132
1133 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1134 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1135 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1136
1137 std::vector<DataType> supportedTypes =
1138 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001139 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001140 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001141 DataType::Float16,
1142 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001143 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001144 DataType::QSymmS16,
1145 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001146 };
1147
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1149 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1150 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001151
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001152 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1153 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001154
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1156 inputTensorInfo1,
1157 outputTensorInfo,
1158 descriptorName,
1159 "input_0",
1160 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001161}
1162
telsoa014fcda012018-03-09 14:13:49 +00001163void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1164{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001165 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001166
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001167 ValidateNumInputs(workloadInfo, descriptorName, 2);
1168 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1169
1170 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1171 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1172 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1173
1174 std::vector<DataType> supportedTypes =
1175 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001176 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001177 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001178 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001179 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001180 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001181 DataType::QSymmS16,
1182 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001183 };
1184
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001185 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1186 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1187 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001188
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001189 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1190 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1193 inputTensorInfo1,
1194 outputTensorInfo,
1195 descriptorName,
1196 "input_0",
1197 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001198}
1199
1200void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1201{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001202 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001203
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001204 ValidateNumInputs(workloadInfo, descriptorName, 1);
1205 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1206
1207 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1208 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001209
1210 std::vector<DataType> supportedTypes =
1211 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001212 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001213 DataType::Float16,
1214 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001215 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001216 DataType::QAsymmU8,
1217 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001218 };
1219
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001220 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1221 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001222
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001223 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001224 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001225
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001226 ValidatePointer(m_Mean, descriptorName, "mean");
1227 ValidatePointer(m_Variance, descriptorName, "variance");
1228 ValidatePointer(m_Beta, descriptorName, "beta");
1229 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001230
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001231 const TensorInfo& mean = m_Mean->GetTensorInfo();
1232 const TensorInfo& variance = m_Variance->GetTensorInfo();
1233 const TensorInfo& beta = m_Beta->GetTensorInfo();
1234 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001235
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001236 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1237 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1238 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1239 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001240
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001241 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1242 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1243 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001244}
1245
1246void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1247{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001248 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001249
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 ValidateNumInputs(workloadInfo, descriptorName, 1);
1251 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1254 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001255
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001256 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1257 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001258
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001259 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001260
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001261 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1262 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001263
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001264 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001265
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001266 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001267 if (m_Parameters.m_BiasEnabled)
1268 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001269 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001270
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001271 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1272 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001273
1274 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1275 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001276 }
1277
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001278 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1279 {
1280 throw InvalidArgumentException(
1281 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1282 "cannot be either negative or 0.",
1283 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1284 }
1285
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001286 ValidatePerAxisQuantization(inputTensorInfo,
1287 outputTensorInfo,
1288 weightTensorInfo,
1289 optionalBiasTensorInfo,
1290 descriptorName);
1291
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001292 std::vector<DataType> supportedTypes =
1293 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001294 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001295 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001296 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001297 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001298 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001299 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001300 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001301 };
1302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001303 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001304
1305 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1306 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1307 {
1308 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1309 {
1310 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1311 "for BFloat16 input.");
1312 }
1313 }
1314 else
1315 {
1316 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1317 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001318}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001319
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001320void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1321{
1322 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1323
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001324 uint32_t numInputs = 2;
1325 if (m_Parameters.m_BiasEnabled)
1326 {
1327 numInputs = 3;
1328 }
1329 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001330 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1331
1332 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1333 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1334
1335 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1336 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1337
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001338 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001339 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1340
1341 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1342
1343 Optional<TensorInfo> optionalBiasTensorInfo;
1344 if (m_Parameters.m_BiasEnabled)
1345 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001346 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001347 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1348
1349 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1350 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1351 }
1352
1353 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1354 {
1355 throw InvalidArgumentException(
1356 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1357 "cannot be either negative or 0.",
1358 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1359 }
1360
1361 ValidatePerAxisQuantization(inputTensorInfo,
1362 outputTensorInfo,
1363 weightTensorInfo,
1364 optionalBiasTensorInfo,
1365 descriptorName);
1366
1367 std::vector<DataType> supportedTypes =
1368 {
1369 DataType::BFloat16,
1370 DataType::Float16,
1371 DataType::Float32,
1372 DataType::QAsymmS8,
1373 DataType::QAsymmU8,
1374 DataType::QSymmS16,
1375 DataType::QSymmS8
1376 };
1377
1378 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1379 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1380}
1381
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001382void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1383{
1384 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1385
1386 ValidateNumInputs(workloadInfo, descriptorName, 1);
1387 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1388
1389 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1390 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1391
1392 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1393 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1394
1395 ValidatePointer(m_Weight, descriptorName, "weight");
1396
1397 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1398 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1399
1400 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1401 {
1402 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001403 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1404 "cannot be smaller than 1.",
1405 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001406 }
1407
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001408 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1409 {
1410 throw InvalidArgumentException(
1411 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1412 "cannot be either negative or 0.",
1413 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1414 }
1415
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001416 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1417
Jan Eilers53ef7952021-06-02 12:01:25 +01001418 // Expected weight shape: [ 1, H, W, I*M ] - This shape does NOT depend on the data layout
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001419 // inputChannels * channelMultiplier should be equal to outputChannels.
Jan Eilers53ef7952021-06-02 12:01:25 +01001420 const unsigned int numWeightOutputChannels = weightTensorInfo.GetShape()[3]; // I*M=Cout
1421 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1422 if (numWeightOutputChannels != numOutputChannels)
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001423 {
James Ward47fce872020-09-10 11:57:28 +01001424 throw InvalidArgumentException(fmt::format(
Jan Eilers53ef7952021-06-02 12:01:25 +01001425 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1426 "But 4th dimension is not equal to Cout. Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1427 descriptorName,
1428 numOutputChannels,
1429 weightTensorInfo.GetShape()[0],
1430 weightTensorInfo.GetShape()[1],
1431 weightTensorInfo.GetShape()[2],
1432 weightTensorInfo.GetShape()[3]));
1433 }
1434 if (weightTensorInfo.GetShape()[0] != 1)
1435 {
1436 throw InvalidArgumentException(fmt::format(
1437 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1438 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1439 descriptorName,
1440 weightTensorInfo.GetShape()[0],
1441 weightTensorInfo.GetShape()[1],
1442 weightTensorInfo.GetShape()[2],
1443 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001444 }
1445
Teresa Charlind8df0262019-11-11 12:28:15 +00001446 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001447
Teresa Charlind8df0262019-11-11 12:28:15 +00001448 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001449 if (m_Parameters.m_BiasEnabled)
1450 {
1451 ValidatePointer(m_Bias, descriptorName, "bias");
1452
Teresa Charlind8df0262019-11-11 12:28:15 +00001453 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1454 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455
1456 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1457 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1458 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001459 ValidatePerAxisQuantization(inputTensorInfo,
1460 outputTensorInfo,
1461 weightTensorInfo,
1462 optionalBiasTensorInfo,
1463 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001464
1465 std::vector<DataType> supportedTypes =
1466 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001467 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001468 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001469 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001470 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001471 DataType::QAsymmU8,
1472 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001473 };
1474
1475 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1476 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001477}
1478
1479void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1480{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001481 const std::string descriptorName{"PermuteQueueDescriptor"};
1482
1483 ValidateNumInputs(workloadInfo, descriptorName, 1);
1484 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001485
1486 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1487
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001488 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1489 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001490
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001491 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1492 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001495 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001497 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001498 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1499 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1500 "must match dst dimension " + to_string(mapping[i]) +
1501 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001502 }
1503 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504
1505 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001506}
1507
1508void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1509{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001511
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001512 ValidateNumInputs(workloadInfo, descriptorName, 1);
1513 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1514
1515 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1516 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1517
1518 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1519 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001520
1521 std::vector<DataType> supportedTypes =
1522 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001523 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001524 DataType::Float32,
1525 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001526 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001527 DataType::QAsymmU8,
1528 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001529 };
1530
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001531 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1532 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001533}
1534
Tamás Nyíri7b885b32021-10-26 14:47:57 +01001535void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1536{
1537 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1538
1539 ValidateNumInputs(workloadInfo, descriptorName, 1);
1540 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1541
1542 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1543 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1544
1545 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1546 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1547
1548 std::vector<DataType> supportedTypes =
1549 {
1550 DataType::BFloat16,
1551 DataType::Float32,
1552 DataType::Float16,
1553 DataType::QAsymmS8,
1554 DataType::QAsymmU8,
1555 DataType::QSymmS16
1556 };
1557
1558 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1559 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1560}
1561
1562
telsoa014fcda012018-03-09 14:13:49 +00001563void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1564{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001565 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001566
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001567 ValidateNumInputs(workloadInfo, descriptorName, 1);
1568 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1569
1570 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1571 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1572
1573 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1574 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001575
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001576 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001577 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001578 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001579 DataType::Float16,
1580 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001581 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001582 DataType::QAsymmU8,
1583 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001584 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001585
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001586 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1587 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001588
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001589 // ResizeBilinear only changes width and height: batch and channel count must match.
1590 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1591 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001592 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001593 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001594 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001595 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1596 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001597 }
1598
Teresa Charlin970f43b2019-07-01 13:51:07 +01001599 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001600 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1601 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001602 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001603 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001604 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001605 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1606 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001607 }
1608}
1609
1610void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1611{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001612 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001613
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001614 ValidateNumInputs(workloadInfo, descriptorName, 1);
1615 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1616
1617 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1618 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1619
1620 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1621 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001622
1623 std::vector<DataType> supportedTypes =
1624 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001625 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001626 DataType::Float16,
1627 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001628 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001629 DataType::QAsymmU8,
1630 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001631 };
1632
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001633 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1634 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001635
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001636 // Resize only changes width and height: batch and channel count must match.
1637 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1638 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001639 if (inputBatchSize != outputBatchSize)
1640 {
1641 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001642 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1643 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001644 }
1645
1646 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001647 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1648 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001649 if (inputChannelCount != outputChannelCount)
1650 {
1651 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001652 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1653 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001654 }
1655}
1656
1657void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1658{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001659 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001660
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001661 ValidateNumInputs(workloadInfo, descriptorName, 1);
1662 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1663
1664 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1665 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1666
1667 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1668 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1669
1670 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1671
telsoa014fcda012018-03-09 14:13:49 +00001672 if (m_Parameters.m_Min > m_Parameters.m_Max)
1673 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001674 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001675 }
telsoa014fcda012018-03-09 14:13:49 +00001676}
1677
Kevin Mayce5045a2019-10-02 14:07:47 +01001678void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1679{
1680 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1681
1682 ValidateNumInputs(workloadInfo, descriptorName, 1);
1683 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1684
1685 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1686 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1687
1688 if (inputTensorInfo.GetNumDimensions() > 4)
1689 {
1690 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1691 }
1692
1693 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1694
1695 // Check the supported data types
1696 std::vector<DataType> supportedTypes =
1697 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001698 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001699 DataType::Float32,
1700 DataType::Float16
1701 };
1702
1703 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001704 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001705}
1706
telsoa014fcda012018-03-09 14:13:49 +00001707void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1708{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001710
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001711 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001712 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1713
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001714 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1715 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1716
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001717 if (inputTensorInfo.GetNumDimensions() > 4)
1718 {
1719 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1720 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001721
1722 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001723
1724 // Check the supported data types
1725 std::vector<DataType> supportedTypes =
1726 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001727 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001728 DataType::Float32,
1729 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001730 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001731 DataType::QAsymmU8,
1732 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001733 };
1734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001735 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001736 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1737}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001738
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001739void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1740{
1741 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1742
1743 ValidateNumInputs(workloadInfo, descriptorName, 1);
1744 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1745
1746 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1747 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1748
1749 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1750
1751 std::vector<DataType> supportedTypes =
1752 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001753 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001754 DataType::Float32,
1755 DataType::Float16,
1756 };
1757
1758 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001759 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001760}
1761
1762void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1763{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001764 const std::string descriptorName{"ConstantQueueDescriptor"};
1765
1766 ValidateNumInputs(workloadInfo, descriptorName, 0);
1767 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001768
1769 if (!m_LayerOutput)
1770 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001771 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001772 }
1773
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1775 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001776
1777 // Check the supported data types
1778 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001779 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001780 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001781 DataType::Float32,
1782 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001783 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001784 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001785 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001786 DataType::QSymmS16,
1787 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001788 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001790 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001791}
1792
1793void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1794{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001795 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001796
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001797 ValidateNumInputs(workloadInfo, descriptorName, 1);
1798 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1799
1800 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1801 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1802
1803 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001804
1805 // Check the supported data types
1806 std::vector<DataType> supportedTypes =
1807 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001808 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001809 DataType::Float32,
1810 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001811 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001812 DataType::QAsymmU8,
1813 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001814 DataType::Signed32,
1815 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001816 };
1817
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001818 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1819 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001820}
1821
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001822void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1823{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001824 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001825
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001826 ValidateNumInputs(workloadInfo, descriptorName, 1);
1827 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1828
1829 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1830 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1831
1832 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1833 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001834
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001835 if (m_Parameters.m_BlockShape.size() != 2)
1836 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001837 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001838 }
1839
1840 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1841 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001842 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1843 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001844 }
1845
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001847
1848 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001849 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001850
Matthew Bentham8800c002018-11-19 13:19:28 +00001851 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001852
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001853 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1854 widthPad.first + widthPad.second;
1855 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1856 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001857
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001858 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1859 inputShape[dimensionIndices.GetChannelsIndex()];
1860 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001861
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001862 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001863 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001864 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001865 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001866 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001867 }
1868
1869 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001870 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001871 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1872 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001873 }
nikraj01120522a2019-05-31 11:33:07 +01001874
1875 std::vector<DataType> supportedTypes =
1876 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001877 DataType::BFloat16,
1878 DataType::Float16,
1879 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001880 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001881 DataType::QAsymmU8,
1882 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001883 };
1884
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001885 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1886 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001887}
1888
Keith Davisa57eccb2019-06-14 17:33:22 +01001889void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1890{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001891 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001892
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 ValidateNumInputs(workloadInfo, descriptorName, 1);
1894 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001895
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001896 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1897 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1898
1899 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1900 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001901
1902 std::vector<DataType> supportedTypes =
1903 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001904 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001905 DataType::Float32,
1906 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001907 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001908 DataType::QAsymmU8,
1909 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001910 };
1911
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001912 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1913 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001914
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001915 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1916
1917 if (m_Parameters.m_BlockSize == 0)
1918 {
1919 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1920 }
1921
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001922 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1923 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1924 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1925 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001926
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001927 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001928 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001929 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001930 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1931 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001932 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001933
1934 const TensorShape& outputShape = outputTensorInfo.GetShape();
1935 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1936 {
1937 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1938 "must be divisible by the square of block size." );
1939 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001940}
1941
telsoa014fcda012018-03-09 14:13:49 +00001942void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1943{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001944 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001945
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001946 ValidateNumInputs(workloadInfo, descriptorName, 1);
1947 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1948
1949 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1950 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001951
1952 std::vector<DataType> supportedTypes =
1953 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001954 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001955 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001956 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001957 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001958 };
1959
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001960 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001961 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1962 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1963 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001964}
1965
telsoa01c577f2c2018-08-31 09:22:23 +01001966void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1967{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001968 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1969
1970 const std::string descriptorName{"LstmQueueDescriptor"};
1971
1972 // check dimensions of all inputs and outputs
1973 if (workloadInfo.m_InputTensorInfos.size() != 3)
1974 {
1975 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1976 }
1977 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1978 {
1979 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1980 }
1981
1982 std::vector<DataType> supportedTypes =
1983 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001984 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001985 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001986 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001987 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001988 };
1989
Jan Eilers38e05bd2019-06-26 13:10:09 +01001990 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001991 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1992
Jan Eilers38e05bd2019-06-26 13:10:09 +01001993 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001994 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001995 {
1996 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1997 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001998 descriptorName,
1999 "input_0",
2000 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002001 }
2002 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002003 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01002004 {
2005 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2006 workloadInfo.m_OutputTensorInfos[i],
2007 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002008 "input_0",
2009 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01002010 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01002011
janeil0117d8d852019-11-15 15:00:16 +00002012 // Making sure clipping parameters have valid values.
2013 // == 0 means no clipping
2014 // > 0 means clipping
2015 if (m_Parameters.m_ClippingThresCell < 0.0f)
2016 {
2017 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2018 }
2019 if (m_Parameters.m_ClippingThresProj < 0.0f)
2020 {
2021 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2022 }
2023
Jan Eilers38e05bd2019-06-26 13:10:09 +01002024 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01002025 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2026 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2027 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2028 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2029 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2030 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2031
Jan Eilers38e05bd2019-06-26 13:10:09 +01002032 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002033 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2034 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002035 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002036 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2037 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002038 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002039 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2040 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002041 // scratchBufferTensor
2042 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002043 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2044 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002045 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002046 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2047 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002048 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002049 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2050 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002051 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002052 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2053 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002054
Jan Eilers38e05bd2019-06-26 13:10:09 +01002055 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2056 if ( m_InputToInputWeights )
2057 {
2058 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2059 (n_cell * n_input), "InputLayerNormWeights");
2060 }
2061
2062 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2063 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2064 (n_cell * n_input), "InputToForgetWeights");
2065
2066 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2067 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2068 (n_cell * n_input), "InputToCellWeights");
2069
2070 if ( m_RecurrentToInputWeights )
2071 {
2072 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2073 (n_cell * n_output), "RecurrentToInputWeights");
2074 }
2075
2076 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2077 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2078 (n_cell * n_output), "RecurrentToForgetWeights");
2079
2080 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2081 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2082 (n_cell * n_output), "RecurrentToCellWeights");
2083
2084 // Make sure the input-gate's parameters are either both present (regular
2085 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2086 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2087 !m_Parameters.m_CifgEnabled) ||
2088 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2089 m_Parameters.m_CifgEnabled));
2090 if (!cifg_weights_all_or_none)
2091 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002092 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2093 "RecurrentToInputWeights must either both be present (regular LSTM) "
2094 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2095 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002096 }
2097
2098 if ( m_CellToInputWeights )
2099 {
2100 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2101 n_cell, "CellToInputWeights");
2102 }
2103 if ( m_CellToForgetWeights )
2104 {
2105 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2106 n_cell, "CellToForgetWeights");
2107 }
2108 if ( m_CellToOutputWeights )
2109 {
2110 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2111 n_cell, "CellToOutputWeights");
2112 }
2113
2114 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2115 bool peephole_weights_all_or_none =
2116 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2117 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2118 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2119 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2120 if (!peephole_weights_all_or_none)
2121 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002122 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002123 }
2124
2125 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2126 if (m_Parameters.m_CifgEnabled)
2127 {
2128 if (m_InputGateBias)
2129 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002130 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002131 }
2132 }
2133 else
2134 {
2135 if (!m_InputGateBias)
2136 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002137 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2138 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002139 }
2140 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2141 n_cell, "InputGateBias");
2142 }
2143
2144 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2145 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2146
2147 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2148 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2149
2150 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2151 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2152
2153 if (m_ProjectionWeights)
2154 {
2155 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2156 (n_cell * n_output), "ProjectionWeights");
2157 }
2158 if (m_ProjectionBias)
2159 {
2160 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2161 }
2162
2163 // Making sure the projection tensors are consistent:
2164 // 1) If projection weight is not present, then projection bias should not be
2165 // present.
2166 // 2) If projection weight is present, then projection bias is optional.
2167 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2168 !m_Parameters.m_ProjectionEnabled)
2169 || (m_ProjectionWeights && !m_ProjectionBias &&
2170 m_Parameters.m_ProjectionEnabled)
2171 || (m_ProjectionWeights && m_ProjectionBias &&
2172 m_Parameters.m_ProjectionEnabled));
2173 if (!projecton_tensors_consistent)
2174 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002175 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002176 }
2177
2178 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2179 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2180 // either all have values or none of them have values. Layer normalization is used when the values of all the
2181 // layer normalization weights are present
2182 if (m_InputLayerNormWeights)
2183 {
2184 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2185 }
2186 if (m_ForgetLayerNormWeights)
2187 {
2188 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2189 }
2190 if (m_CellLayerNormWeights)
2191 {
2192 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2193 }
2194 if (m_OutputLayerNormWeights)
2195 {
2196 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2197 }
2198
Jan Eilers38e05bd2019-06-26 13:10:09 +01002199 if (m_Parameters.m_LayerNormEnabled)
2200 {
2201 if (!m_Parameters.m_CifgEnabled)
2202 {
2203 if (!m_InputLayerNormWeights)
2204 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002205 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2206 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002207 }
2208 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2209 1, n_cell, "InputLayerNormWeights");
2210 }
2211 else if (m_InputLayerNormWeights)
2212 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002213 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2214 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002215 }
2216
2217 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2218 "ForgetLayerNormWeights");
2219 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2220
2221 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2222 "OutputLayerNormWeights");
2223 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2224
2225 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2226 "CellLayerNormWeights");
2227 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2228 }
2229 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2230 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002231 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2232 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002233 }
telsoa01c577f2c2018-08-31 09:22:23 +01002234}
2235
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002236void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2237{
2238 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2239
2240 ValidateNumInputs(workloadInfo, descriptorName, 1);
2241 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2242
2243 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2244 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2245
2246 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2247 {
2248 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2249 }
2250
2251 if (outputTensorInfo.GetDataType() != DataType::Float32)
2252 {
2253 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2254 }
2255
2256 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2257}
2258
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002259void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2260{
2261 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2262
2263 ValidateNumInputs(workloadInfo, descriptorName, 1);
2264 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2265
2266 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2267 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2268
2269 if (inputTensorInfo.GetDataType() != DataType::Float32)
2270 {
2271 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2272 }
2273
2274 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2275 {
2276 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2277 }
2278
2279 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2280}
2281
telsoa01c577f2c2018-08-31 09:22:23 +01002282void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2283{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002286 ValidateNumInputs(workloadInfo, descriptorName, 1);
2287 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2288
2289 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2290 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2291
2292 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002293 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002294 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002295 }
2296
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002297 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002298 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002299 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002300 }
2301
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002302 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002303}
2304
2305void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2306{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002307 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002308
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002309 ValidateNumInputs(workloadInfo, descriptorName, 1);
2310 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2311
2312 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2313 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2314
2315 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002316 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002317 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002318 }
2319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002320 if (outputTensorInfo.GetDataType() != DataType::Float32)
2321 {
2322 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2323 }
2324
2325 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002326}
2327
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002328void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2329{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002331
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002332 ValidateNumInputs(workloadInfo, descriptorName, 2);
2333 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2334
2335 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2336 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2337 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2338
2339 std::vector<DataType> supportedTypes =
2340 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002341 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002342 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002343 DataType::Float32,
2344 DataType::QAsymmS8,
2345 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002346 DataType::QSymmS16,
2347 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002348 };
2349
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002350 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2351 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2352 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002353
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002354 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2355 inputTensorInfo1,
2356 outputTensorInfo,
2357 descriptorName,
2358 "input_0",
2359 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002360}
2361
David Beckc2044fe2018-09-05 15:00:38 +01002362void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2363{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002365
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002366 ValidateNumInputs(workloadInfo, descriptorName, 2);
2367 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2368
2369 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2370 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2371 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2372
2373 std::vector<DataType> supportedTypes =
2374 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002375 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002376 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002377 DataType::Float32,
2378 DataType::QAsymmS8,
2379 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002380 DataType::QSymmS16,
2381 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002382 };
2383
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002384 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2385 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2386 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002387
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002388 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2389 inputTensorInfo1,
2390 outputTensorInfo,
2391 descriptorName,
2392 "input_0",
2393 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002394}
2395
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002396void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2397{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002398 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002399
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002400 ValidateNumInputs(workloadInfo, descriptorName, 2);
2401 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2402
2403 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2404 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2405 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2406
2407 std::vector<DataType> supportedTypes =
2408 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002409 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002410 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002411 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002412 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002413 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002414 DataType::QSymmS16,
2415 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002416 };
2417
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002418 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2419 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2420 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002421
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002422 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2423 inputTensorInfo1,
2424 outputTensorInfo,
2425 descriptorName,
2426 "input_0",
2427 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002428}
2429
narpra01a6bf9122018-09-10 09:50:09 +01002430void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2431{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002432 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002433
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002434 ValidateNumInputs(workloadInfo, descriptorName, 1);
2435 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2436
2437 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2438 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002439
2440 std::vector<DataType> supportedTypes =
2441 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002442 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002443 DataType::Float32,
2444 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002445 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002446 DataType::QAsymmU8,
2447 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002448 };
narpra01eb061912018-09-10 17:35:27 +01002449
James Conroy4d1ff582019-06-10 17:06:39 +01002450 // First check if input tensor data type is supported, then
2451 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002452 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2453 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002454
narpra0132b90462018-09-13 11:07:48 +01002455 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002456 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002457 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002458 }
narpra0132b90462018-09-13 11:07:48 +01002459 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002460 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002461 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002462 }
2463 else
2464 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002465 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002466 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 ValidateTensorNumDimensions(outputTensorInfo,
2468 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002469 outputDim > 0 ? outputDim : 1,
2470 "output");
2471 }
narpra01a6bf9122018-09-10 09:50:09 +01002472}
2473
jimfly012c9322a2018-09-19 10:59:49 +01002474void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2475{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002476 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002478 ValidateNumInputs(workloadInfo, descriptorName, 1);
2479 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2480
2481 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2482 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002483
jimfly012c9322a2018-09-19 10:59:49 +01002484 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002485 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2486
jimfly012c9322a2018-09-19 10:59:49 +01002487 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002488 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2489 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2490 "as there are dimensions in the input tensor that is " +
2491 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2492 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002493 }
2494}
2495
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002496void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2497{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002500 ValidateNumInputs(workloadInfo, descriptorName, 1);
2501 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002503 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2504 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2505
Sadik Armagan2208b602019-07-31 16:36:27 +01002506 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002507 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002508 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002509 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002510 DataType::Float16,
2511 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002512 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002513 DataType::QAsymmU8,
2514 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002515 };
2516
2517 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002518
Keith Davis0c2eeac2020-02-11 16:51:50 +00002519 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002520 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002521 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002522 }
2523}
2524
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002525void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2526{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002527 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002528
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002529 ValidateNumInputs(workloadInfo, descriptorName, 1);
2530 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002531
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002532 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2533 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002534
2535 std::vector<DataType> supportedTypes =
2536 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002537 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002538 DataType::Float32,
2539 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002540 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002541 DataType::QAsymmU8,
2542 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002543 };
2544
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002545 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2546 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002547}
2548
Conor Kennedy430b5d82018-11-14 15:28:28 +00002549void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2550{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002551 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002552
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002553 ValidateNumInputs(workloadInfo, descriptorName, 1);
2554 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2555
2556 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2557 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002558
2559 std::vector<DataType> supportedTypes =
2560 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002561 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002562 DataType::Float16,
2563 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002564 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002565 DataType::QAsymmU8,
2566 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002567 };
2568
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002569 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2570 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002574 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002575 if (rank > 4)
2576 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002577 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002578 }
2579
Conor Kennedy430b5d82018-11-14 15:28:28 +00002580 // Begin, End & Stride length must be of rank(input0)
2581 if (m_Parameters.m_Begin.size() != rank)
2582 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002583 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002584 }
2585
2586 if (m_Parameters.m_End.size() != rank)
2587 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002588 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002589 }
2590
2591 if (m_Parameters.m_Stride.size() != rank)
2592 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002593 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002594 }
2595
2596 // Stride entries must be non-zero
2597 for (auto& stride : m_Parameters.m_Stride)
2598 {
2599 if (stride == 0)
2600 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002601 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002602 }
2603 }
2604}
2605
kevmay0190539692018-11-29 08:40:19 +00002606void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2607{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002608 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002609
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002610 ValidateNumInputs(workloadInfo, descriptorName, 2);
2611 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2612
2613 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2614 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2615 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2616
2617 std::vector<DataType> supportedTypes =
2618 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002619 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002620 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002621 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002623 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002624 DataType::QSymmS16,
2625 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002626 };
2627
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002628 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2629 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2630 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002631
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002632 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2633 inputTensorInfo1,
2634 outputTensorInfo,
2635 descriptorName,
2636 "input_0",
2637 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002638}
2639
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002640void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2641{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002642 const std::string descriptorName{"DebugQueueDescriptor"};
2643
2644 ValidateNumInputs(workloadInfo, descriptorName, 1);
2645 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002646}
2647
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002648void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2649{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002650 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002651
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002652 ValidateNumInputs(workloadInfo, descriptorName, 2);
2653 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002654
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002655 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2656 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2657 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2658
2659 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2660 inputTensorInfo1,
2661 outputTensorInfo,
2662 descriptorName,
2663 "input_0",
2664 "input_1");
2665
2666 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002667 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002668 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002669 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002670}
2671
FrancisMurtagh878f0232018-12-19 10:56:15 +00002672void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2673{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002674 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002675
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002676 ValidateNumInputs(workloadInfo, descriptorName, 2);
2677 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002678
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002679 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2680 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2681 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2682
2683 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2684 inputTensorInfo1,
2685 outputTensorInfo,
2686 descriptorName,
2687 "input_0",
2688 "input_1");
2689
2690 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002691 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002692 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002693 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002694}
2695
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002696void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2697{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002698 const std::string descriptorName{"RsqrtQueueDescriptor"};
2699
2700 ValidateNumInputs(workloadInfo, descriptorName, 1);
2701 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2702
2703 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2704 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2705
2706 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002707
2708 std::vector<DataType> supportedTypes =
2709 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002710 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002711 DataType::Float16,
2712 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002713 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002714 DataType::QAsymmU8,
2715 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002716 };
2717
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002718 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2719 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002720}
2721
narpra01b89b05f2019-01-16 09:53:09 +00002722void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2723{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002724 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002725
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002726 ValidateNumInputs(workloadInfo, descriptorName, 2);
2727 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002728
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002729 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2730 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002731 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002732 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002733 }
2734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002735 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2736 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2737
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002738 std::vector<DataType> supportedTypes =
2739 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002740 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002741 DataType::Float16,
2742 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002743 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002744 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002745 DataType::QSymmS16,
2746 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002747 };
2748
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002749 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002750
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002751 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002752
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002753 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2754 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002755}
2756
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002757void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2758{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002759 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2760
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002761 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002762
2763 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2764 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002765 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002766 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2767 }
2768
2769 if (m_Anchors == nullptr)
2770 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002771 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002772 }
2773
2774 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002775 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2776 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2777
2778 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002779 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002780 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2781 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002782
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002783 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2784 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2785 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002786
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002787 const std::vector<DataType> supportedInputTypes =
2788 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002789 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002790 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002791 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002792 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002793 DataType::QAsymmU8,
2794 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002795 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002796
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002797 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2798 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2799 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2800
2801 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2802 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2803 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2804 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2805
2806 // NOTE: Output is always Float32 regardless of input type
2807 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2808 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2809 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2810 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002811
2812 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2813 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002814 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002815 "must be positive and less than or equal to 1.");
2816 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002817
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002818 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2819 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002820 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002821 "should be equal to number of classes + 1.");
2822 }
2823}
2824
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002825void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2826{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002827 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002828
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002829 ValidateNumInputs(workloadInfo, descriptorName, 1);
2830 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2831
2832 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2833 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2834
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002835 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002836 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002837 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002838 }
2839
Sadik Armagan2208b602019-07-31 16:36:27 +01002840 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002841 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002842 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002843 DataType::Float32,
2844 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002845 };
2846
2847 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002848}
2849
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002850void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2851{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002852 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002853
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002854 ValidateNumInputs(workloadInfo, descriptorName, 2);
2855 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002856
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002857 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2858 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2859 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002860
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002861 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2862 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2863
2864 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2865 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002866}
2867
Keith Davis3ae3f972021-05-21 16:33:48 +01002868void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2869{
2870 const std::string& descriptorName{"ShapeQueueDescriptor"};
2871
2872 ValidateNumInputs(workloadInfo, descriptorName, 1);
2873 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2874
2875 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2876 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2877
2878 std::vector<DataType> supportedTypes =
2879 {
2880 DataType::BFloat16,
2881 DataType::Float16,
2882 DataType::Float32,
2883 DataType::QAsymmS8,
2884 DataType::QAsymmU8,
2885 DataType::QAsymmS8,
2886 DataType::QSymmS8,
2887 DataType::QSymmS16,
2888 DataType::Signed32
2889 };
2890
2891 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2892 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2893}
2894
Sadik Armaganeff363d2019-04-05 15:25:46 +01002895void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2896{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002897 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002898
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002899 ValidateNumInputs(workloadInfo, descriptorName, 2);
2900 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2901
2902 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2903 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2904
2905 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2906 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2907
2908 std::vector<DataType> supportedTypes =
2909 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002910 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002911 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002912 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002913 DataType::QAsymmU8,
2914 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002915 };
2916
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002917 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2918 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002919
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002920 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2921 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002922
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002923 ValidateTensorShapesMatch(inputTensorInfo0,
2924 outputTensorInfo0,
2925 descriptorName,
2926 "input_0",
2927 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002928
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002929 ValidateTensorShapesMatch(inputTensorInfo0,
2930 outputTensorInfo1,
2931 descriptorName,
2932 "input_0",
2933 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002934}
2935
Derek Lamberti901ea112019-12-10 22:07:09 +00002936void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002937{
2938 // This is internally generated so it should not need validation.
2939}
2940
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002941void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2942{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002943 const std::string& descriptorName{"PreluQueueDescriptor"};
2944
2945 ValidateNumInputs(workloadInfo, descriptorName, 2);
2946 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2947
2948 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2949 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2950 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002951
2952 std::vector<DataType> supportedTypes
2953 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002954 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002955 DataType::Float16,
2956 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002957 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002958 DataType::QAsymmU8,
2959 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002960 };
2961
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002962 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2963 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002964
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002965 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002966
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002967 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2968 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002970 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2971 alphaTensorInfo,
2972 outputTensorInfo,
2973 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002974 "input",
2975 "alpha");
2976}
2977
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002978void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2979{
2980 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2981
2982 ValidateNumInputs(workloadInfo, descriptorName, 1);
2983 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2984
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002985 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2986 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2987
2988 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2989 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002990
2991 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002993 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2994 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002995
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002996 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2997
2998 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002999 if (m_Parameters.m_BiasEnabled)
3000 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003001 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003002
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003003 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
3004 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003005
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003006 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003007 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003008 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003009
3010 ValidatePerAxisQuantization(inputTensorInfo,
3011 outputTensorInfo,
3012 weightTensorInfo,
3013 optionalBiasTensorInfo,
3014 descriptorName);
3015
3016 std::vector<DataType> supportedTypes =
3017 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003018 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003019 DataType::Float32,
3020 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003021 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003022 DataType::QAsymmU8,
3023 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00003024 };
3025
3026 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3027 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01003028}
3029
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003030void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3031{
3032 const std::string descriptorName{"TransposeQueueDescriptor"};
3033
3034 ValidateNumInputs(workloadInfo, descriptorName, 1);
3035 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3036
3037 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3038
3039 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3040 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3041
3042 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3043 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3044
3045 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3046 {
3047 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3048 {
3049 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3050 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3051 "must match dst dimension " + to_string(i) +
3052 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3053 }
3054 }
3055
3056 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3057}
3058
Simon Obute51f67772021-09-03 15:50:13 +01003059void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3060{
3061 const std::string descriptorName{"TransposeQueueDescriptor"};
3062
3063 ValidateNumInputs(workloadInfo, descriptorName, 1);
3064 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3065
3066 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3067 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3068
3069 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3070}
3071
James Conroy4f1f8992020-04-29 20:01:10 +01003072void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3073{
3074 const std::string descriptorName{"QLstmQueueDescriptor"};
3075
3076 // Validate number of inputs/outputs
3077 ValidateNumInputs(workloadInfo, descriptorName, 3);
3078 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3079
3080 // Input/output tensor info
3081 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3082 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3083 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3084
3085 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3086 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3087 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3088
3089 // Supported types for various tensors in QLSTM
3090 std::vector<DataType> inputOutputSupportedTypes =
3091 {
3092 DataType::QAsymmS8
3093 };
3094
3095 std::vector<DataType> cellStateSupportedTypes =
3096 {
3097 DataType::QSymmS16
3098 };
3099
3100 std::vector<DataType> weightsSupportedTypes =
3101 {
3102 DataType::QSymmS8
3103 };
3104
3105 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3106 {
3107 DataType::QSymmS16
3108 };
3109
3110 std::vector<DataType> biasSupportedTypes =
3111 {
3112 DataType::Signed32
3113 };
3114
3115 // Validate types of input/output tensors
3116 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3117 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3118 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3119
3120 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3121 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3122 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3123
3124 // Validate matching types of input/output tensors
3125 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3126 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3127 "outputStateIn", "outputStateOut");
3128 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3129
3130 // Infer number of batches, number of units, input size and output size from tensor dimensions
3131 const uint32_t numBatches = inputInfo.GetShape()[0];
3132 const uint32_t inputSize = inputInfo.GetShape()[1];
3133 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3134 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3135
3136 // Validate number of dimensions and number of elements for input/output tensors
3137 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3138 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3139 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3140
3141 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3142 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3143 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3144
3145 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3146 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3147 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3148 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3149
3150 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3151 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3152 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3153
3154 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3155 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3156 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3157
3158 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3159 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3160 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3161 " RecurrentToForgetWeights");
3162
3163 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3164 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3165 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3166
3167 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3168 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3169 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3170
3171 // Validate data types for MANDATORY weights tensors (all should match each other)
3172 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3173
3174 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3175 "inputToForgetWeights", "inputToCellWeights");
3176 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3177 "inputToForgetWeights", "inputToOutputWeights");
3178
3179 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3180 "inputToForgetWeights", "recurrentToForgeteights");
3181 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3182 "inputToForgetWeights", "recurrentToCellWeights");
3183 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3184 "inputToForgetWeights", "recurrentToOutputWeights");
3185
3186 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3187 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3188 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3189 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3190
3191 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3192 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3193 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3194
3195 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3196 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3197 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3198
3199 // Validate data types for MANDATORY bias tensors
3200 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3201
3202 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3203 "forgetGateBias", "cellBias");
3204 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3205 "forgetGateBias", "outputGateBias");
3206
3207 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3208 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3209 !m_Parameters.m_CifgEnabled) ||
3210 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3211 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3212
3213 if (!allCifgParamsPresentOrNot)
3214 {
3215 throw InvalidArgumentException(descriptorName +
3216 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3217 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3218 "set appropriately.");
3219 }
3220
3221 if (!m_Parameters.m_CifgEnabled)
3222 {
3223 // Validate number of dimensions and number of elements
3224 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3225 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3226
3227 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3228 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3229 " RecurrentToInputWeights");
3230
3231 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3232 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3233
3234 // Validate data types
3235 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3236 "inputToForgetWeights", "inputToInputWeights");
3237 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3238 "inputToForgetWeights", "recurrentToInputWeights");
3239 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3240 "forgetGateBias", "inputGateBias");
3241 }
3242
3243 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3244 bool allPeepholeWeightsPresentOrNot =
3245 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3246 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3247 || (!m_CellToInputWeights && !m_CellToForgetWeights
3248 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3249
3250 if (!allPeepholeWeightsPresentOrNot)
3251 {
3252 throw InvalidArgumentException(descriptorName +
3253 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3254 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3255 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3256 "appropriately.");
3257 }
3258
3259 if (m_Parameters.m_PeepholeEnabled)
3260 {
3261 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3262 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3263 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3264
3265 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3266 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3267 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3268 "cellToForgetWeight", "cellToOutputWeights");
3269
3270 if (!m_Parameters.m_CifgEnabled)
3271 {
3272 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3273 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3274 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3275 "cellToForgetWeights", "cellToInputWeights");
3276 }
3277 }
3278
3279 // Validate OPTIONAL params: Layer Norm Weights
3280 bool allLayerNormWeightsPresentOrNot =
3281 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3282 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3283 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3284 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3285
3286 if (!allLayerNormWeightsPresentOrNot)
3287 {
3288 throw InvalidArgumentException(descriptorName +
3289 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3290 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3291 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3292 "only be present when Layer Norm is enabled and CIFG is disabled. "
3293 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3294 }
3295
3296 if (m_Parameters.m_LayerNormEnabled)
3297 {
3298 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3299 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3300 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3301
3302 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3303 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3304 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3305 "forgetLayerNormWeights", "cellLayerNormWeights");
3306
3307 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3308 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3309 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3310 "forgetLayerNormWeights", "outputLayerNormWeights");
3311
3312 if (!m_Parameters.m_CifgEnabled)
3313 {
3314 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3315 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3316 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3317 "forgetLayerNormWeights", "inputLayerNormWeights");
3318 }
3319 }
3320
3321 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3322 bool correctProjectionTensorsPresent =
3323 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3324 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3325 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3326
3327 if (!correctProjectionTensorsPresent)
3328 {
3329 throw InvalidArgumentException(descriptorName +
3330 ": If projection is enabled, ProjectionWeights should be present and "
3331 "ProjectionBias is optional. If projection is disabled, neither "
3332 "ProjectionWeights nor ProjectionBias should be present.");
3333 }
3334
3335 if (m_Parameters.m_ProjectionEnabled)
3336 {
3337 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3338 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3339 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3340
3341 if (m_ProjectionBias)
3342 {
3343 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003344 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003345 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3346 }
3347
3348 }
3349 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3350 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3351 throw InvalidArgumentException(descriptorName +
3352 ": If projection is disabled, output quantization info (scale, offset) "
3353 "should match HiddenStateScale and HiddenStateZeroPoint.");
3354 }
3355
3356}
3357
James Conroy9c3cae82019-08-01 16:01:48 +01003358void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3359{
3360 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3361
3362 // Validate number of inputs/outputs
3363 ValidateNumInputs(workloadInfo, descriptorName, 3);
3364 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3365
3366 // Input/output tensor infos
3367 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3368 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3369 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3370
3371 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3372 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3373
3374 std::vector<DataType> inputOutputSupportedTypes =
3375 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003376 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003377 };
3378
3379 std::vector<DataType> cellStateSupportedTypes =
3380 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003381 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003382 };
3383
3384 std::vector<DataType> weightsSupportedTypes =
3385 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003386 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003387 };
3388
3389 std::vector<DataType> biasSupportedTypes =
3390 {
3391 DataType::Signed32
3392 };
3393
3394 // Validate types of input/output tensors
3395 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3396 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3397 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3398
3399 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3400 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3401
3402 // Validate matching types of input/output tensors
3403 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3404 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3405 "outputStateIn", "outputStateOut");
3406 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3407
3408 // Validate matching quantization info for input/output tensors
3409 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3410 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3411 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003412
James Conroy9c3cae82019-08-01 16:01:48 +01003413 // Infer number of batches, input size and output size from tensor dimensions
3414 const uint32_t numBatches = inputInfo.GetShape()[0];
3415 const uint32_t inputSize = inputInfo.GetShape()[1];
3416 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3417
3418 // Validate number of dimensions and number of elements for input/output tensors
3419 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3420 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3421 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3422 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3423 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3424
3425 // Validate number of dimensions and number of elements for weights tensors
3426 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3427 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3428 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3429
3430 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3431 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3432 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3433
3434 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3435 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3436 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3437
3438 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3439 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3440 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3441
3442 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3443 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3444 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3445
3446 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3447 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3448 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3449 " RecurrentToForgetWeights");
3450
3451 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3452 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3453 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3454
3455 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3456 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3457 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3458
3459 // Validate data types for weights tensors (all should match each other)
3460 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3461
3462 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3463 "inputToInputWeights", "inputToForgetWeights");
3464 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3465 "inputToInputWeights", "inputToCellWeights");
3466 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3467 "inputToInputWeights", "inputToOutputWeights");
3468
3469 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3470 "inputToInputWeights", "recurrentToInputWeights");
3471 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3472 "inputToInputWeights", "recurrentToForgeteights");
3473 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3474 "inputToInputWeights", "recurrentToCellWeights");
3475 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3476 "inputToInputWeights", "recurrentToOutputWeights");
3477
3478 // Validate matching quantization info for weight tensors (all should match each other)
3479 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3480 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3481 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3482 descriptorName, "inputToInputWeights", "inputToCellWeights");
3483 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3484 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3485
3486 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3487 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3488 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3489 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3490 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3491 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3492 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3493 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3494
3495 // Validate number of dimensions and number of elements in bias tensors
3496 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3497 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3498 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3499
3500 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3501 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3502 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3503
3504 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3505 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3506 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3507
3508 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3509 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3510 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3511
3512 // Validate data types for bias tensors (all should match each other)
3513 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3514
3515 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3516 "inputGateBias", "forgetGateBias");
3517 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3518 "inputGateBias", "cellBias");
3519 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3520 "inputGateBias", "outputGateBias");
3521
3522 // Validate bias tensor quantization info
3523 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3524 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3525 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3526 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3527}
3528
Kevin May868eb142019-09-04 17:29:31 +01003529void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3530{
3531 const std::string descriptorName{"AbsQueueDescriptor"};
3532
3533 ValidateNumInputs(workloadInfo, descriptorName, 1);
3534 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3535
3536 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3537 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3538
3539 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3540
3541 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003542 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003543 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003544 DataType::Float16,
3545 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003546 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003547 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003548 DataType::QSymmS16,
3549 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003550 };
Kevin May868eb142019-09-04 17:29:31 +01003551
3552 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3553 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3554}
3555
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003556void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3557{
3558 const std::string descriptorName{"SliceQueueDescriptor"};
3559
3560 ValidateNumInputs(workloadInfo, descriptorName, 1);
3561 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3562
3563 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3564 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3565
3566 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3567
3568 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3569 if (rank > 4)
3570 {
3571 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3572 }
3573
3574 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3575
3576 // Check if m_Begin and m_Size have the expected length
3577 if (m_Parameters.m_Begin.size() != rank)
3578 {
3579 throw InvalidArgumentException(descriptorName +
3580 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3581 }
3582 if (m_Parameters.m_Size.size() != rank)
3583 {
3584 throw InvalidArgumentException(descriptorName +
3585 ": Length of size descriptor must equal rank " + std::to_string(rank));
3586 }
3587
3588 // Check if the shape of the output tensor matches m_Size
3589 const TensorShape& outputShape = outputTensorInfo.GetShape();
3590 for (unsigned int i = 0u; i < rank; ++i)
3591 {
3592 if (m_Parameters.m_Size[i] != outputShape[i])
3593 {
3594 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3595 }
3596 }
3597
3598 // Check if the sum of begin offset and size in a given dimension
3599 // does not exceed the size of corresponding input
3600 const TensorShape& inputShape = inputTensorInfo.GetShape();
3601 for(unsigned int i = 0u; i < rank; ++i)
3602 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003603 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003604 {
3605 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3606 std::to_string(i) + " exceeds input size.");
3607 }
3608 }
3609}
3610
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003611void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3612{
3613 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3614
3615 ValidateNumInputs(workloadInfo, descriptorName, 1);
3616 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3617
3618 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3619 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3620
3621 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3622 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3623
3624 std::vector<DataType> supportedTypes =
3625 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003626 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003627 DataType::Float32,
3628 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003629 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003630 DataType::QAsymmU8,
3631 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003632 };
3633
3634 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3635 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3636
3637 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3638
3639 if (m_Parameters.m_BlockSize == 0)
3640 {
3641 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3642 }
3643
3644 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3645 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3646 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3647 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3648
3649 const TensorShape& outputShape = outputInfo.GetShape();
3650 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3651 {
3652 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3653 "must be divisible by block size.");
3654 }
3655
3656 const TensorShape& inputShape = inputInfo.GetShape();
3657 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3658 {
3659 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3660 "must be divisible by the square of block size." );
3661 }
3662}
3663
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003664void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3665{
3666 const std::string descriptorName{"ComparisonQueueDescriptor"};
3667
3668 ValidateNumInputs(workloadInfo, descriptorName, 2);
3669 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3670
3671 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3672 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3673 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3674
3675 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3676 inputTensorInfo1,
3677 outputTensorInfo,
3678 descriptorName,
3679 "input_0",
3680 "input_1");
3681
3682 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3683 {
3684 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3685 }
3686}
3687
josh minor4a3c6102020-01-06 16:40:46 -06003688void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3689{
3690 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3691
3692 ValidateNumInputs(workloadInfo, descriptorName, 1);
3693 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3694
3695 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3696 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3697
3698 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3699
3700 std::vector<DataType> supportedTypes =
3701 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003702 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003703 DataType::Float16,
3704 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003705 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003706 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003707 DataType::QSymmS16,
3708 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003709 };
3710
James Conroyaba90cd2020-11-06 16:28:18 +00003711 std::vector<DataType> logicalSupportedTypes =
3712 {
3713 DataType::Boolean
3714 };
3715
3716 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3717 {
3718 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3719 }
3720 else
3721 {
3722 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3723 }
3724
3725
josh minor4a3c6102020-01-06 16:40:46 -06003726 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3727}
3728
Finn Williams2605b232020-06-10 15:53:46 +01003729void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3730{
3731 const std::string descriptorName{"RankQueueDescriptor"};
3732
3733 ValidateNumInputs(workloadInfo, descriptorName, 1);
3734 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3735
3736 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3737 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3738
3739 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3740 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3741
3742 std::vector<DataType> supportedTypes =
3743 {
3744 DataType::BFloat16,
3745 DataType::Float16,
3746 DataType::Float32,
3747 DataType::QAsymmS8,
3748 DataType::QAsymmU8,
3749 DataType::QSymmS8,
3750 DataType::QSymmS16,
3751 DataType::Signed32
3752 };
3753
3754 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3755 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3756}
3757
James Conroyaba90cd2020-11-06 16:28:18 +00003758void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3759{
3760 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3761
3762 ValidateNumInputs(workloadInfo, descriptorName, 2);
3763 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3764
3765 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3766 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3767 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3768
3769 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3770 inputTensorInfo1,
3771 outputTensorInfo,
3772 descriptorName,
3773 "input_0",
3774 "input_1");
3775
3776 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3777 {
3778 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3779 }
3780
3781 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3782 {
3783 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3784 }
3785
3786 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3787 {
3788 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3789 }
3790}
3791
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003792void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3793{
3794 const std::string descriptorName{"ReduceQueueDescriptor"};
3795
3796 ValidateNumInputs(workloadInfo, descriptorName, 1);
3797 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3798
3799 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3800 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3801
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003802 std::vector<DataType> supportedTypes =
3803 {
3804 DataType::BFloat16,
3805 DataType::Float16,
3806 DataType::Float32,
3807 DataType::QAsymmS8,
3808 DataType::QAsymmU8,
3809 DataType::QSymmS16,
3810 DataType::Signed32
3811 };
3812
3813 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3814 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3815}
3816
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003817void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3818{
3819 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3820
3821 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3822
3823 // check dimensions of all inputs and outputs
3824 if (workloadInfo.m_InputTensorInfos.size() != 3)
3825 {
3826 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3827 }
3828 if (workloadInfo.m_OutputTensorInfos.size() != 1)
3829 {
3830 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3831 }
3832
3833 std::vector<DataType> supportedTypes =
3834 {
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01003835 DataType::Float32
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003836 };
3837
3838 // check for supported type of one input and match them with all the other input and output
3839 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3840
3841 // type matches all other inputs
3842 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
3843 {
3844 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
3845 workloadInfo.m_InputTensorInfos[i],
3846 descriptorName,
3847 "input_0",
3848 "input_" + std::to_string(i));
3849 }
3850 // type matches all other outputs
3851 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
3852 {
3853 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
3854 workloadInfo.m_OutputTensorInfos[i],
3855 "LstmQueueDescriptor",
3856 "input_0",
3857 "output_" + std::to_string(i));
3858 }
3859
3860 // Making sure clipping parameters have valid values.
3861 // == 0 means no clipping
3862 // > 0 means clipping
3863 if (m_Parameters.m_ClippingThresCell < 0.0f)
3864 {
3865 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3866 }
3867 if (m_Parameters.m_ClippingThresProj < 0.0f)
3868 {
3869 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3870 }
3871
3872 unsigned int batchIndx = 0;
3873 unsigned int inputIndx = 1;
3874 uint32_t timeStep = 1;
3875 unsigned int timeIndx = 1;
3876 inputIndx = 2;
3877 if (m_Parameters.m_TimeMajor)
3878 {
3879 batchIndx = 1;
3880 timeIndx = 0;
3881
3882 }
3883 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3884
3885 // Inferring batch size, number of outputs and number of cells from the inputs.
3886 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3887 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3888 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3889 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3890 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3891 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3892
3893 // input tensor
3894 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3895 descriptorName + " input_0");
3896 // outputStateInTensor
3897 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3898 descriptorName + " input_1");
3899 // outputStateInTensor
3900 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3901 descriptorName + " input_2");
3902
3903 // outputTensor
3904 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output),
3905 descriptorName + " output_0");
3906
3907 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3908 if ( m_InputToInputWeights )
3909 {
3910 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3911 (n_cell * n_input), "InputLayerNormWeights");
3912 }
3913
3914 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3915 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3916 (n_cell * n_input), "InputToForgetWeights");
3917
3918 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3919 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3920 (n_cell * n_input), "InputToCellWeights");
3921
3922 if ( m_RecurrentToInputWeights )
3923 {
3924 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3925 (n_cell * n_output), "RecurrentToInputWeights");
3926 }
3927
3928 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3929 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3930 (n_cell * n_output), "RecurrentToForgetWeights");
3931
3932 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3933 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3934 (n_cell * n_output), "RecurrentToCellWeights");
3935
3936 // Make sure the input-gate's parameters are either both present (regular
3937 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3938 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3939 !m_Parameters.m_CifgEnabled) ||
3940 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3941 m_Parameters.m_CifgEnabled));
3942 if (!cifg_weights_all_or_none)
3943 {
3944 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3945 "RecurrentToInputWeights must either both be present (regular LSTM) "
3946 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3947 "accordingly.");
3948 }
3949
3950 if ( m_CellToInputWeights )
3951 {
3952 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3953 n_cell, "CellToInputWeights");
3954 }
3955 if ( m_CellToForgetWeights )
3956 {
3957 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3958 n_cell, "CellToForgetWeights");
3959 }
3960 if ( m_CellToOutputWeights )
3961 {
3962 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3963 n_cell, "CellToOutputWeights");
3964 }
3965
3966 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3967 bool peephole_weights_all_or_none =
3968 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3969 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3970 || ( !m_CellToInputWeights && !m_CellToForgetWeights
3971 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3972 if (!peephole_weights_all_or_none)
3973 {
3974 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3975 }
3976
3977 // Make sure the input gate bias is present only when not a CIFG-LSTM.
3978 if (m_Parameters.m_CifgEnabled)
3979 {
3980 if (m_InputGateBias)
3981 {
3982 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3983 }
3984 }
3985 else
3986 {
3987 if (!m_InputGateBias)
3988 {
3989 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
3990 "must be present.");
3991 }
3992 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
3993 n_cell, "InputGateBias");
3994 }
3995
3996 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
3997 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
3998
3999 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
4000 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
4001
4002 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
4003 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
4004
4005 if (m_ProjectionWeights)
4006 {
4007 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
4008 (n_cell * n_output), "ProjectionWeights");
4009 }
4010 if (m_ProjectionBias)
4011 {
4012 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4013 }
4014
4015 // Making sure the projection tensors are consistent:
4016 // 1) If projection weight is not present, then projection bias should not be
4017 // present.
4018 // 2) If projection weight is present, then projection bias is optional.
4019 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4020 !m_Parameters.m_ProjectionEnabled)
4021 || (m_ProjectionWeights && !m_ProjectionBias &&
4022 m_Parameters.m_ProjectionEnabled)
4023 || (m_ProjectionWeights && m_ProjectionBias &&
4024 m_Parameters.m_ProjectionEnabled));
4025 if (!projecton_tensors_consistent)
4026 {
4027 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4028 }
4029
4030 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4031 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4032 // either all have values or none of them have values. Layer normalization is used when the values of all the
4033 // layer normalization weights are present
4034 if (m_InputLayerNormWeights)
4035 {
4036 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4037 }
4038 if (m_ForgetLayerNormWeights)
4039 {
4040 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4041 }
4042 if (m_CellLayerNormWeights)
4043 {
4044 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4045 }
4046 if (m_OutputLayerNormWeights)
4047 {
4048 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4049 }
4050
4051 if (m_Parameters.m_LayerNormEnabled)
4052 {
4053 if (!m_Parameters.m_CifgEnabled)
4054 {
4055 if (!m_InputLayerNormWeights)
4056 {
4057 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4058 "disabled but InputLayerNormWeights are not present");
4059 }
4060 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4061 1, n_cell, "InputLayerNormWeights");
4062 }
4063 else if (m_InputLayerNormWeights)
4064 {
4065 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4066 "enabled");
4067 }
4068
4069 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4070 "ForgetLayerNormWeights");
4071 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4072
4073 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4074 "OutputLayerNormWeights");
4075 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4076
4077 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4078 "CellLayerNormWeights");
4079 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4080 }
4081 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4082 {
4083 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4084 "normalisation weights are present.");
4085 }
4086}
4087
4088
mathad01df9a3222021-04-28 11:42:57 +01004089} // namespace armnn