blob: 2716c827aff2735d584d3bac9dc951ffa1a8f949 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
#include <backendsCommon/TensorHandle.hpp>
#include <backendsCommon/WorkloadData.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <armnnUtils/TensorUtils.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/Logging.hpp>

#include <algorithm>
#include <cmath>
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>

#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000032 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000033 case DataType::Float32:
34 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000035 case DataType::QAsymmS8:
36 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000037 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000038 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
40 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000041 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010042 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000043 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010044 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000045 return DataType::Float32;
46 }
47}
48
49namespace
50{
51
52//---------------------------------------------------------------
// Local replacement for std::to_string, which the Android NDK does not support.
// Formats any value that can be written with operator<< using default stream formatting.
template <typename T>
std::string to_string(T value)
{
    std::stringstream stream;
    stream << value;
    return stream.str();
}
61
62//---------------------------------------------------------------
63void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
64{
65 if (!ptr)
66 {
67 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
68 paramName + " parameter must be set.");
69 }
70}
71
72//---------------------------------------------------------------
73void ValidateTensorShapesMatch(const TensorInfo& first,
74 const TensorInfo& second,
75 std::string const& descName,
76 std::string const& firstName,
77 std::string const& secondName)
78{
79 if (first.GetShape() != second.GetShape())
80 {
81 throw InvalidArgumentException(descName + ": "
82 + firstName + " & " + secondName + " must have identical shapes");
83 }
84}
85
86//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010087void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088{
Sadik Armaganeff363d2019-04-05 15:25:46 +010089 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090 {
91 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010092 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000093 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
94 }
95}
96
97//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010098void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101 {
102 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100103 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000104 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
105 }
106}
107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000110 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100111 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000112 std::string const& tensorName)
113{
114 if (tensor.GetNumDimensions() != numDimensions)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
117 to_string(tensor.GetNumDimensions()) + " dimensions for " +
118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100123void ValidateTensorNumElements(const TensorInfo& tensor,
124 std::string const& descName,
125 unsigned int numElements,
126 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127{
128 if (tensor.GetNumElements() != numElements)
129 {
130 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100131 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100132 tensorName + " tensor.");
133 }
134}
135
136//---------------------------------------------------------------
137void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100138 unsigned int numDimension,
139 unsigned int numElements,
140 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100141{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100142 const std::string functionName{"ValidateTensorNumDimNumElem"};
143 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
144 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100145}
146
147//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000148void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
149 const std::string& descName, std::string const& tensorName)
150{
151 if (tensor.GetDataType() != dataType)
152 {
153 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
154 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
155 }
156}
157
Derek Lambertid466a542020-01-22 15:37:29 +0000158void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
159{
Jan Eilers1b2654f2021-09-24 15:45:46 +0100160 if (tensor.GetDataType() != DataType::QSymmS8)
Derek Lambertid466a542020-01-22 15:37:29 +0000161 {
162 throw InvalidArgumentException(descName +
163 ": Expected data type which supports per-axis quantization scheme but got " +
164 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
165 }
Derek Lambertid466a542020-01-22 15:37:29 +0000166}
167
telsoa014fcda012018-03-09 14:13:49 +0000168//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100169void ValidateTensorQuantizationSpace(const TensorInfo& first,
170 const TensorInfo& second,
171 const std::string& descName,
172 std::string const& firstName,
173 std::string const& secondName)
174{
175 if (!first.IsQuantized() ||
176 !second.IsQuantized())
177 {
178 // Not a quantized type, ignore the validation
179 return;
180 }
181
182 DataType firstDataType = first.GetDataType();
183 DataType secondDataType = second.GetDataType();
184
185 if (firstDataType != secondDataType)
186 {
187 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
188 " must be of the same quantized type, " +
189 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
190 secondName + " is " + GetDataTypeName(secondDataType));
191 }
192
193 if (!first.IsTypeSpaceMatch(second))
194 {
195 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
196 " must have the same quantization space, " +
197 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
198 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
199 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
200 " and scale " + to_string(second.GetQuantizationScale()));
201 }
202}
203
204//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100205void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
206 const TensorInfo& inputTensorInfo,
207 const TensorInfo& weightsTensorInfo,
208 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000209{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000210 // Helper lambda function to validate a single bias quantization scale value
211 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
212 {
mathad01df9a3222021-04-28 11:42:57 +0100213 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000214 if (std::abs(biasScale - expectedScale) > tolerance)
215 {
216 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100217 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
218 " for bias quantization scale (product of input and weight scales), but got " <<
219 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000220 }
221 };
222
telsoa014fcda012018-03-09 14:13:49 +0000223 if (biasTensor.GetQuantizationOffset() != 0)
224 {
225 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
226 to_string(biasTensor.GetQuantizationOffset()));
227 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000228
James Conroy8502ade2020-11-12 19:26:29 +0000229 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000230 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000231 // Validate per-axis quantization scales
232 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
233 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
234
235 if (weightScales.size() != biasScales.size())
236 {
237 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000238 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
239 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
240 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000241 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
242 }
243
244 for (size_t i = 0ul; i < biasScales.size(); ++i)
245 {
246 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
247 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
248 }
249 }
250 else
251 {
252 // Validate per-tensor quantization scale
253 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
254 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000255 }
256}
257
258//---------------------------------------------------------------
259void ValidateTensors(const std::vector<ITensorHandle*>& vec,
260 unsigned int numExpected,
261 const std::string& descName,
262 const std::string& varName)
263{
264 if (vec.empty() && numExpected > 0)
265 {
266 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
267 }
268
269 for (unsigned int i = 0; i < numExpected; ++i)
270 {
271 if (!vec[i])
272 {
273 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
274 }
275 }
276}
277
278//---------------------------------------------------------------
279void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
280 const TensorInfo& second,
281 const TensorInfo& output,
282 std::string const& descName,
283 std::string const& firstName,
284 std::string const& secondName)
285{
286 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
287 // broadcasted.
288 if (first.GetNumDimensions() != second.GetNumDimensions())
289 {
290 throw InvalidArgumentException(descName + ": Tensors "
291 + firstName + " & " + secondName
292 + " must have the same number of dimensions in order to be broadcasted");
293 }
294 uint32_t numDims = first.GetNumDimensions();
295 std::vector<uint32_t> outputDims(numDims, 0u);
296 for (uint32_t i = 0; i < numDims; i++)
297 {
298 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
299 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
300 if (dimsNotEqual && dimsNotOne)
301 {
302 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
303 }
304 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
305 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100306 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000307 if (broadcastShape != output.GetShape())
308 {
309 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
310 + firstName + " & " + secondName
311 + " does not match the output shape");
312 }
313}
314
315//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100316void ValidateDataTypes(const TensorInfo& info,
317 const std::vector<armnn::DataType>& supportedTypes,
318 std::string const& descName)
319{
320 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
321 if (iterator == supportedTypes.end())
322 {
323 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
324 }
325}
326
James Conroy4d1ff582019-06-10 17:06:39 +0100327//---------------------------------------------------------------
328void ValidateTensorDataTypesMatch(const TensorInfo& first,
329 const TensorInfo& second,
330 std::string const& descName,
331 std::string const& firstName,
332 std::string const& secondName)
333{
334 if (first.GetDataType() != second.GetDataType())
335 {
336 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
337 " must have identical data types.");
338 }
339}
340
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100341//---------------------------------------------------------------
342void ValidateTensorNumElementsMatch(const TensorInfo& first,
343 const TensorInfo& second,
344 std::string const& descName,
345 std::string const& firstName,
346 std::string const& secondName)
347{
348 if (first.GetNumElements() != second.GetNumElements())
349 {
350 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
351 " must have the same number of elements.");
352 }
353}
354
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000355void ValidateWeightDataType(const TensorInfo& inputInfo,
356 const TensorInfo& weightInfo,
357 const std::string& descName)
358{
359 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000360 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000361 {
362 const std::vector<DataType> validTypes =
363 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000364 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100365 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +0100366 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000367 };
368
369 ValidateDataTypes(weightInfo, validTypes, descName);
370 }
371 else
372 {
373 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
374 }
375}
376
377void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
378 const std::string& descName,
379 const std::string& tensorName)
380{
381 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
382 if (!quantizationDim.has_value())
383 {
James Ward47fce872020-09-10 11:57:28 +0100384 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
385 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000386 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000387}
388
389void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
390 const std::string& descName,
391 const std::string& tensorName)
392{
393 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
394 if (quantizationOffset != 0)
395 {
James Ward47fce872020-09-10 11:57:28 +0100396 throw InvalidArgumentException(fmt::format(
397 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
398 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000399 }
400}
401
402void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
403 const TensorInfo& outputInfo,
404 const TensorInfo& weightInfo,
405 const Optional<TensorInfo>& optionalBiasInfo,
406 const std::string& descName)
407{
408 if (weightInfo.HasPerAxisQuantization())
409 {
410 const DataType inputDataType = inputInfo.GetDataType();
411 const DataType outputDataType = outputInfo.GetDataType();
412
Keith Davis0c2eeac2020-02-11 16:51:50 +0000413 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000414
415 if (!canHavePerAxisQuantization)
416 {
James Ward47fce872020-09-10 11:57:28 +0100417 throw InvalidArgumentException(fmt::format(
418 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
419 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000420 }
421
Derek Lambertid466a542020-01-22 15:37:29 +0000422
423 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000424 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
425 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
426
427 if (optionalBiasInfo.has_value())
428 {
429 const TensorInfo& biasInfo = optionalBiasInfo.value();
430 if (!biasInfo.HasPerAxisQuantization())
431 {
James Ward47fce872020-09-10 11:57:28 +0100432 throw InvalidArgumentException(fmt::format(
433 "{}: Per-axis quantization parameters not set on bias tensor, "
434 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000435 }
436
437 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
438 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
439 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
440 }
441 }
442}
443
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100444} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000445
446void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
447 unsigned int numExpectedIn, unsigned int numExpectedOut) const
448{
449 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
450 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
451}
452
453//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100454void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
455{
456 const std::string descriptorName{"MapQueueDescriptor"};
457
458 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100459 ValidateNumOutputs(workloadInfo, descriptorName, 0);
460
461 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
462 {
463 if (!m_Inputs[i])
464 {
465 throw InvalidArgumentException(
466 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
467 }
468 }
469}
470
471//---------------------------------------------------------------
472void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
473{
474 const std::string descriptorName{"UnmapQueueDescriptor"};
475
476 ValidateNumInputs(workloadInfo, descriptorName, 1);
477 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100478
479 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
480 {
481 if (!m_Inputs[i])
482 {
483 throw InvalidArgumentException(
484 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
485 }
486 }
487}
488
489//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000490void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
491{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100494 ValidateNumInputs(workloadInfo, descriptorName, 1);
495 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100497 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
498 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
499
500 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
501 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000502
503 if (m_Inputs.size() != m_Outputs.size())
504 {
James Ward47fce872020-09-10 11:57:28 +0100505 throw InvalidArgumentException(fmt::format(
506 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
507 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000508 }
509
510 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
511 {
512 if (!m_Inputs[i])
513 {
James Ward47fce872020-09-10 11:57:28 +0100514 throw InvalidArgumentException(fmt::format(
515 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000516 }
517
518 if (!m_Outputs[i])
519 {
James Ward47fce872020-09-10 11:57:28 +0100520 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000521 }
522 }
523}
524
Derek Lambertif674aa02019-08-01 15:56:25 +0100525//---------------------------------------------------------------
526void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
527{
528 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
529 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
530
531 if (workloadInfo.m_InputTensorInfos.size() != 1)
532 {
James Ward47fce872020-09-10 11:57:28 +0100533 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
534 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100535
536 }
537
538 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
539 {
James Ward47fce872020-09-10 11:57:28 +0100540 throw InvalidArgumentException(fmt::format(
541 "Number of input infos ({0}) does not match the number of output infos ({1})",
542 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100543 }
544
545 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
546 {
547 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
548 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
549 {
James Ward47fce872020-09-10 11:57:28 +0100550 throw InvalidArgumentException(fmt::format(
551 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100552 }
553 }
554
555 if (m_Inputs.size() != 1)
556 {
James Ward47fce872020-09-10 11:57:28 +0100557 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100558 }
559
560 if (m_Inputs.size() != m_Outputs.size())
561 {
James Ward47fce872020-09-10 11:57:28 +0100562 throw InvalidArgumentException(fmt::format(
563 "Number of inputs ({0}) does not match the number of outputs ({1})",
564 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100565 }
566
567 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
568 {
569 if (!m_Inputs[i])
570 {
James Ward47fce872020-09-10 11:57:28 +0100571 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573
574 if (!m_Outputs[i])
575 {
James Ward47fce872020-09-10 11:57:28 +0100576 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100577 }
578 }
579}
580
581//---------------------------------------------------------------
582void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
583{
584 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
585 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
586
Derek Lambertif674aa02019-08-01 15:56:25 +0100587 if (m_Inputs.size() != 1)
588 {
James Ward47fce872020-09-10 11:57:28 +0100589 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100590 }
591
592 if (m_Outputs.size() != 0)
593 {
James Ward47fce872020-09-10 11:57:28 +0100594 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100595 }
596
597 if (!m_Inputs[0])
598 {
James Ward47fce872020-09-10 11:57:28 +0100599 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100600 }
601}
602
603//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000604void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
605{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100608 ValidateNumInputs(workloadInfo, descriptorName, 1);
609 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100611 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
612 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100613
614 std::vector<DataType> supportedTypes =
615 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000616 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100617 DataType::Float16,
618 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000620 DataType::QAsymmU8,
621 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100622 };
623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100624 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
625 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
626 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000627}
628
Nikhil Rajee391d52019-09-05 17:50:44 +0100629void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
630{
631 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
632
633 ValidateNumInputs(workloadInfo, descriptorName, 1);
634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
635
636 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
638
Inki Daed4619e22020-09-10 15:33:54 +0900639 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
640 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100641 {
Inki Daed4619e22020-09-10 15:33:54 +0900642 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100643 }
644
James Conroyd47a0642019-09-17 14:22:06 +0100645 std::vector<DataType> supportedInputTypes =
646 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000647 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100648 DataType::Float16,
649 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100650 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000651 DataType::QAsymmU8,
652 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900653 DataType::Signed32,
654 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100655 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100656
James Conroyd47a0642019-09-17 14:22:06 +0100657 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100658
659 auto inputShape = inputTensorInfo.GetShape();
660 auto outputShape = outputTensorInfo.GetShape();
661
662 auto inputNumDimensions = inputShape.GetNumDimensions();
663 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
664
665 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
666
667 // 1D input shape results in scalar output shape
668 if (inputShape.GetNumDimensions() == 1)
669 {
670 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
671 {
672 throw InvalidArgumentException(descriptorName + outputShapeError);
673 }
674 }
675 else
676 {
677 for (unsigned int i = 0; i < unsignedAxis; ++i)
678 {
679 if (outputShape[i] != inputShape[i])
680 {
681 throw InvalidArgumentException(descriptorName + outputShapeError);
682 }
683 }
684
685 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
686 {
687 if (outputShape[i - 1] != inputShape[i])
688 {
689 throw InvalidArgumentException(descriptorName + outputShapeError);
690 }
691 }
692 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100693}
694
mathad01b392e982021-04-07 12:07:30 +0100695void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
696{
697 const std::string descriptorName{"CastQueueDescriptor"};
698
699 ValidateNumInputs(workloadInfo, descriptorName, 1);
700 ValidateNumOutputs(workloadInfo, descriptorName, 1);
701
702 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
703 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
704
705 std::vector<DataType> supportedTypes =
706 {
707 DataType::BFloat16,
708 DataType::Float16,
709 DataType::Float32,
710 DataType::QAsymmS8,
711 DataType::QAsymmU8,
712 DataType::QSymmS8,
713 DataType::QSymmS16,
714 DataType::Signed32,
715 DataType::Signed64
716 };
717
718 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
719 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
720}
721
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100722void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
723{
724 const std::string descriptorName{"SoftmaxQueueDescriptor"};
725
726 ValidateNumInputs(workloadInfo, descriptorName, 1);
727 ValidateNumOutputs(workloadInfo, descriptorName, 1);
728
729 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
730 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
731
732 std::vector<DataType> supportedTypes =
733 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000734 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100735 DataType::Float16,
736 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000737 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000738 DataType::QAsymmU8,
739 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100740 };
741
742 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
743 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
744 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
745}
746
telsoa014fcda012018-03-09 14:13:49 +0000747void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
748{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100749 const std::string descriptorName{"SplitterQueueDescriptor"};
750
751 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000752
Ruomei Yan25339c32019-05-28 16:48:20 +0100753 // Check the supported data types
754 std::vector<DataType> supportedTypes =
755 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000756 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100757 DataType::Float32,
758 DataType::Float16,
759 DataType::Boolean,
760 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100761 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000762 DataType::QAsymmU8,
763 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100764 };
765
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100766 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
767 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100769 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
770 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
771
772 const std::string outputName = "output_" + std::to_string(i);
773 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100774 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100775
telsoa014fcda012018-03-09 14:13:49 +0000776 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
777 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000779 }
780
781 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
782 {
783 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000785 "has to match number of workloadInfo.m_OutputTensorInfos. "
786 "Number of windows: " +
787 to_string(m_ViewOrigins.size()) +
788 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
789 }
790
telsoa01c577f2c2018-08-31 09:22:23 +0100791 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000792 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
793 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
794 {
telsoa01c577f2c2018-08-31 09:22:23 +0100795 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000796 ViewOrigin const& e = m_ViewOrigins[w];
797 if (e.m_Origin.size() != inputDims)
798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100799 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000800 "have the same dimensionality as the input tensor. "
801 "Window origin (index: " +
802 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
803 " dimensions, the input "
804 "tensor has " +
805 to_string(inputDims) + " dimensions.");
806 }
807 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
808 {
809 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
810 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
811 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100812 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000813 "be smaller or equal than the size of the input in that coord.");
814 }
815 }
816 }
817}
818
Jim Flynne242f2d2019-05-22 14:24:13 +0100819void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000820{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100821 const std::string descriptorName{"ConcatQueueDescriptor"};
822
823 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000824
825 if (m_Inputs.size() <= 0)
826 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000828 }
829 if (m_Outputs.size() <= 0)
830 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100831 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000832 }
833
834 if (workloadInfo.m_InputTensorInfos.size() <= 0)
835 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100836 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000837 }
838 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
839 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100840 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000841 }
842
Nikhil Raj8599a412018-11-19 14:51:07 +0000843 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
844 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100845 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000846 }
847
848 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
849 {
850 return;
851 }
852
telsoa014fcda012018-03-09 14:13:49 +0000853 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
854 {
855 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100856 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000857 "has to match number of workloadInfo.m_InputTensorInfos. "
858 "Number of windows: " +
859 to_string(m_ViewOrigins.size()) +
860 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
861 }
862
telsoa01c577f2c2018-08-31 09:22:23 +0100863 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000864 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
865 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
866 {
telsoa01c577f2c2018-08-31 09:22:23 +0100867 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000868 ViewOrigin const& e = m_ViewOrigins[w];
869 if (e.m_Origin.size() != outputDims)
870 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100871 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000872 "have the same dimensionality as the output tensor. "
873 "Window origin (index: " +
874 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
875 " dimensions, the output "
876 "tensor has " +
877 to_string(outputDims) + " dimensions.");
878 }
telsoa01c577f2c2018-08-31 09:22:23 +0100879 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000880 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
881 {
882 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
883 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
884 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100885 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000886 "be smaller or equal than the size of the output in that coord.");
887 }
888 }
889 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100890
891 // Check the supported data types
892 std::vector<DataType> supportedTypes =
893 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000894 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100895 DataType::Float32,
896 DataType::Float16,
897 DataType::Boolean,
898 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100899 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000900 DataType::QAsymmU8,
901 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100902 };
903
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100904 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
905 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100906 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100907 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
908 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
909
910 const std::string inputName = "input_" + std::to_string(i);
911 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100912 }
telsoa014fcda012018-03-09 14:13:49 +0000913}
914
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100915void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
916{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100917 const std::string descriptorName{"StackQueueDescriptor"};
918
919 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100920
921 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
922 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100923 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100924 }
925
926 // All inputs must have the same shape, which is defined in parameters
927 const TensorShape& inputShape = m_Parameters.m_InputShape;
928 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
929 {
930 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
931 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100932 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100933 }
934 }
935
Matthew Jacksondba634f2019-08-15 15:14:18 +0100936 if (inputShape.GetNumDimensions() > 4)
937 {
938 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
939 }
940
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100941 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
942 // since the output tensor has an additional dimension.
943 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
944 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100945 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100946 "than the number of input dimensions.");
947 }
948
949 // Output shape must be as inferred from the input shape
950 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
951 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
952 {
953 if (outputShape[i] != inputShape[i])
954 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100955 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100956 "match shape inferred from input tensor.");
957 }
958 }
959
960 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
961 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100962 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100963 "match shape inferred from input tensor.");
964 }
965
966 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
967 {
968 if (outputShape[i] != inputShape[i-1])
969 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100970 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100971 "match shape inferred from input tensor.");
972 }
973 }
974
Matthew Jacksondba634f2019-08-15 15:14:18 +0100975 if (outputShape.GetNumDimensions() > 5)
976 {
977 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
978 }
979
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100980 // Check the supported data types
981 std::vector<DataType> supportedTypes =
982 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000983 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100984 DataType::Float32,
985 DataType::Float16,
986 DataType::Boolean,
987 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100988 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000989 DataType::QAsymmU8,
990 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100991 };
992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100993 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100994
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100995 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100996 {
997 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
998 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100999 descriptorName,
1000 "input_0",
1001 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001002 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001003
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001004 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1005 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001006 descriptorName,
1007 "input_0",
1008 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001009}
1010
Ryan OSheaec6c6802020-06-05 17:17:06 +01001011void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1012{
1013 const std::string descriptorName{"FillQueueDescriptor"};
1014
1015 ValidateNumInputs(workloadInfo, descriptorName, 1);
1016 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1017
1018 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1019 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1020
1021 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1022
1023 std::vector<DataType> supportedTypes =
1024 {
1025 DataType::BFloat16,
1026 DataType::Float32,
1027 DataType::Float16,
1028 DataType::Signed32
1029 };
1030
1031 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1032}
1033
telsoa014fcda012018-03-09 14:13:49 +00001034void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1035{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001036 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001037
Matthew Sloyan81beae32021-07-13 19:46:11 +01001038 uint32_t numInputs = 2;
1039 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001040 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001041 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001042 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001043
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001044 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001045 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1046
1047 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1048 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1049
1050 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1051
1052 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001053 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001054 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001055 }
1056
Matthew Sloyan81beae32021-07-13 19:46:11 +01001057 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001058 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001059
1060 if (m_Parameters.m_BiasEnabled)
1061 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001062 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001063 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001065 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1066 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001067 }
1068
Francis Murtagh46c09d02019-05-28 08:15:28 +01001069 // Check the supported data types
1070 std::vector<DataType> supportedTypes =
1071 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001072 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001073 DataType::Float32,
1074 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001075 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001076 DataType::QAsymmU8,
1077 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001078 };
1079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001080 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001081
1082 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1083 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1084 {
1085 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1086 {
1087 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1088 "for BFloat16 input.");
1089 }
1090 }
1091 else
1092 {
1093 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1094 }
telsoa014fcda012018-03-09 14:13:49 +00001095}
1096
telsoa014fcda012018-03-09 14:13:49 +00001097void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1098{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 const std::string descriptorName{"NormalizationQueueDescriptor"};
1100
1101 ValidateNumInputs(workloadInfo, descriptorName, 1);
1102 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1103
1104 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1105 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001106
1107 // Check the supported data types
1108 std::vector<DataType> supportedTypes =
1109 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001110 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001111 DataType::Float16,
1112 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001113 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001114 DataType::QAsymmU8,
1115 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001116 };
1117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001119
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001120 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001121
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001122 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001123}
1124
1125void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1126{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001127 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001128
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001129 ValidateNumInputs(workloadInfo, descriptorName, 2);
1130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1131
1132 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1133 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1134 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1135
1136 std::vector<DataType> supportedTypes =
1137 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001138 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001139 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001140 DataType::Float16,
1141 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001142 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001143 DataType::QSymmS16,
1144 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001145 };
1146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001147 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1148 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1149 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001150
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001151 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1152 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1155 inputTensorInfo1,
1156 outputTensorInfo,
1157 descriptorName,
1158 "input_0",
1159 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001160}
1161
telsoa014fcda012018-03-09 14:13:49 +00001162void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1163{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001164 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001166 ValidateNumInputs(workloadInfo, descriptorName, 2);
1167 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1168
1169 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1170 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1171 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1172
1173 std::vector<DataType> supportedTypes =
1174 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001175 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001176 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001177 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001178 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001179 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001180 DataType::QSymmS16,
1181 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001182 };
1183
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001184 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1185 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1186 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001187
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001188 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1189 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001190
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001191 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1192 inputTensorInfo1,
1193 outputTensorInfo,
1194 descriptorName,
1195 "input_0",
1196 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001197}
1198
1199void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1200{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001201 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001202
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001203 ValidateNumInputs(workloadInfo, descriptorName, 1);
1204 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1205
1206 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1207 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001208
1209 std::vector<DataType> supportedTypes =
1210 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001211 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001212 DataType::Float16,
1213 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001214 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001215 DataType::QAsymmU8,
1216 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001217 };
1218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1220 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001223 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidatePointer(m_Mean, descriptorName, "mean");
1226 ValidatePointer(m_Variance, descriptorName, "variance");
1227 ValidatePointer(m_Beta, descriptorName, "beta");
1228 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001229
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001230 const TensorInfo& mean = m_Mean->GetTensorInfo();
1231 const TensorInfo& variance = m_Variance->GetTensorInfo();
1232 const TensorInfo& beta = m_Beta->GetTensorInfo();
1233 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001234
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001235 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1236 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1237 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1238 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1241 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1242 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001243}
1244
1245void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1246{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001248
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001249 ValidateNumInputs(workloadInfo, descriptorName, 1);
1250 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001251
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001252 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1253 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001254
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1256 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001257
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001258 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1261 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001262
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001263 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001264
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001265 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001266 if (m_Parameters.m_BiasEnabled)
1267 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001269
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001270 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1271 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272
1273 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1274 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001275 }
1276
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001277 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1278 {
1279 throw InvalidArgumentException(
1280 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1281 "cannot be either negative or 0.",
1282 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1283 }
1284
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001285 ValidatePerAxisQuantization(inputTensorInfo,
1286 outputTensorInfo,
1287 weightTensorInfo,
1288 optionalBiasTensorInfo,
1289 descriptorName);
1290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001291 std::vector<DataType> supportedTypes =
1292 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001293 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001294 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001295 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001296 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001297 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001298 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001299 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001300 };
1301
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001303
1304 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1305 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1306 {
1307 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1308 {
1309 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1310 "for BFloat16 input.");
1311 }
1312 }
1313 else
1314 {
1315 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1316 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001317}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001318
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001319void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1320{
1321 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1322
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001323 uint32_t numInputs = 2;
1324 if (m_Parameters.m_BiasEnabled)
1325 {
1326 numInputs = 3;
1327 }
1328 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001329 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1330
1331 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1332 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1333
1334 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1335 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1336
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001337 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001338 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1339
1340 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1341
1342 Optional<TensorInfo> optionalBiasTensorInfo;
1343 if (m_Parameters.m_BiasEnabled)
1344 {
Matthew Sloyan5d7b0a32021-10-18 13:07:49 +01001345 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001346 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1347
1348 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1349 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1350 }
1351
1352 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1353 {
1354 throw InvalidArgumentException(
1355 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1356 "cannot be either negative or 0.",
1357 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1358 }
1359
1360 ValidatePerAxisQuantization(inputTensorInfo,
1361 outputTensorInfo,
1362 weightTensorInfo,
1363 optionalBiasTensorInfo,
1364 descriptorName);
1365
1366 std::vector<DataType> supportedTypes =
1367 {
1368 DataType::BFloat16,
1369 DataType::Float16,
1370 DataType::Float32,
1371 DataType::QAsymmS8,
1372 DataType::QAsymmU8,
1373 DataType::QSymmS16,
1374 DataType::QSymmS8
1375 };
1376
1377 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1378 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1379}
1380
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001381void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1382{
1383 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1384
1385 ValidateNumInputs(workloadInfo, descriptorName, 1);
1386 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1387
1388 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1389 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1390
1391 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1392 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1393
1394 ValidatePointer(m_Weight, descriptorName, "weight");
1395
1396 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1397 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1398
1399 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1400 {
1401 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001402 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1403 "cannot be smaller than 1.",
1404 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001405 }
1406
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001407 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1408 {
1409 throw InvalidArgumentException(
1410 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1411 "cannot be either negative or 0.",
1412 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1413 }
1414
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001415 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1416
Jan Eilers53ef7952021-06-02 12:01:25 +01001417 // Expected weight shape: [ 1, H, W, I*M ] - This shape does NOT depend on the data layout
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001418 // inputChannels * channelMultiplier should be equal to outputChannels.
Jan Eilers53ef7952021-06-02 12:01:25 +01001419 const unsigned int numWeightOutputChannels = weightTensorInfo.GetShape()[3]; // I*M=Cout
1420 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1421 if (numWeightOutputChannels != numOutputChannels)
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001422 {
James Ward47fce872020-09-10 11:57:28 +01001423 throw InvalidArgumentException(fmt::format(
Jan Eilers53ef7952021-06-02 12:01:25 +01001424 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1425 "But 4th dimension is not equal to Cout. Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1426 descriptorName,
1427 numOutputChannels,
1428 weightTensorInfo.GetShape()[0],
1429 weightTensorInfo.GetShape()[1],
1430 weightTensorInfo.GetShape()[2],
1431 weightTensorInfo.GetShape()[3]));
1432 }
1433 if (weightTensorInfo.GetShape()[0] != 1)
1434 {
1435 throw InvalidArgumentException(fmt::format(
1436 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1437 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1438 descriptorName,
1439 weightTensorInfo.GetShape()[0],
1440 weightTensorInfo.GetShape()[1],
1441 weightTensorInfo.GetShape()[2],
1442 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001443 }
1444
Teresa Charlind8df0262019-11-11 12:28:15 +00001445 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001446
Teresa Charlind8df0262019-11-11 12:28:15 +00001447 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001448 if (m_Parameters.m_BiasEnabled)
1449 {
1450 ValidatePointer(m_Bias, descriptorName, "bias");
1451
Teresa Charlind8df0262019-11-11 12:28:15 +00001452 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1453 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001454
1455 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1456 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1457 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001458 ValidatePerAxisQuantization(inputTensorInfo,
1459 outputTensorInfo,
1460 weightTensorInfo,
1461 optionalBiasTensorInfo,
1462 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001463
1464 std::vector<DataType> supportedTypes =
1465 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001466 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001467 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001468 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001469 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001470 DataType::QAsymmU8,
1471 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001472 };
1473
1474 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1475 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001476}
1477
1478void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1479{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001480 const std::string descriptorName{"PermuteQueueDescriptor"};
1481
1482 ValidateNumInputs(workloadInfo, descriptorName, 1);
1483 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001484
1485 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1486
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001487 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1488 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001489
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001490 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1491 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001492
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001493 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001494 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001495 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001496 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001497 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1498 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1499 "must match dst dimension " + to_string(mapping[i]) +
1500 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001501 }
1502 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503
1504 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001505}
1506
1507void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1508{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001509 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001510
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001511 ValidateNumInputs(workloadInfo, descriptorName, 1);
1512 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1513
1514 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1515 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1516
1517 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1518 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001519
1520 std::vector<DataType> supportedTypes =
1521 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001522 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001523 DataType::Float32,
1524 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001525 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001526 DataType::QAsymmU8,
1527 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001528 };
1529
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001530 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1531 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001532}
1533
1534void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1535{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001536 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001537
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001538 ValidateNumInputs(workloadInfo, descriptorName, 1);
1539 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1540
1541 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1542 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1543
1544 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1545 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001546
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001547 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001548 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001549 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001550 DataType::Float16,
1551 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001552 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001553 DataType::QAsymmU8,
1554 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001555 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001556
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001557 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1558 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001559
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001560 // ResizeBilinear only changes width and height: batch and channel count must match.
1561 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1562 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001563 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001564 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001565 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001566 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1567 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001568 }
1569
Teresa Charlin970f43b2019-07-01 13:51:07 +01001570 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001571 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1572 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001573 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001574 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001575 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001576 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1577 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001578 }
1579}
1580
1581void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1582{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001583 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001584
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001585 ValidateNumInputs(workloadInfo, descriptorName, 1);
1586 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1587
1588 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1589 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1590
1591 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1592 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001593
1594 std::vector<DataType> supportedTypes =
1595 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001596 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001597 DataType::Float16,
1598 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001599 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001600 DataType::QAsymmU8,
1601 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001602 };
1603
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001604 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1605 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001606
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001607 // Resize only changes width and height: batch and channel count must match.
1608 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1609 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001610 if (inputBatchSize != outputBatchSize)
1611 {
1612 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001613 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1614 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001615 }
1616
1617 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001618 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1619 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001620 if (inputChannelCount != outputChannelCount)
1621 {
1622 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001623 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1624 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001625 }
1626}
1627
1628void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1629{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001630 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001631
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001632 ValidateNumInputs(workloadInfo, descriptorName, 1);
1633 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1634
1635 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1636 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1637
1638 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1639 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1640
1641 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1642
telsoa014fcda012018-03-09 14:13:49 +00001643 if (m_Parameters.m_Min > m_Parameters.m_Max)
1644 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001645 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001646 }
telsoa014fcda012018-03-09 14:13:49 +00001647}
1648
Kevin Mayce5045a2019-10-02 14:07:47 +01001649void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1650{
1651 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1652
1653 ValidateNumInputs(workloadInfo, descriptorName, 1);
1654 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1655
1656 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1657 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1658
1659 if (inputTensorInfo.GetNumDimensions() > 4)
1660 {
1661 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1662 }
1663
1664 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1665
1666 // Check the supported data types
1667 std::vector<DataType> supportedTypes =
1668 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001669 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001670 DataType::Float32,
1671 DataType::Float16
1672 };
1673
1674 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001675 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001676}
1677
telsoa014fcda012018-03-09 14:13:49 +00001678void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1679{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001680 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001681
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001682 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001683 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1684
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001685 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1686 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1687
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001688 if (inputTensorInfo.GetNumDimensions() > 4)
1689 {
1690 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1691 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001692
1693 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001694
1695 // Check the supported data types
1696 std::vector<DataType> supportedTypes =
1697 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001698 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001699 DataType::Float32,
1700 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001701 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001702 DataType::QAsymmU8,
1703 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001704 };
1705
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001706 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001707 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1708}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001710void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1711{
1712 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1713
1714 ValidateNumInputs(workloadInfo, descriptorName, 1);
1715 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1716
1717 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1718 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1719
1720 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1721
1722 std::vector<DataType> supportedTypes =
1723 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001724 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001725 DataType::Float32,
1726 DataType::Float16,
1727 };
1728
1729 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001730 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001731}
1732
1733void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1734{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001735 const std::string descriptorName{"ConstantQueueDescriptor"};
1736
1737 ValidateNumInputs(workloadInfo, descriptorName, 0);
1738 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001739
1740 if (!m_LayerOutput)
1741 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001742 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001743 }
1744
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001745 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1746 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001747
1748 // Check the supported data types
1749 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001750 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001751 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001752 DataType::Float32,
1753 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001754 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001755 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001756 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001757 DataType::QSymmS16,
1758 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001759 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001760
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001761 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001762}
1763
1764void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1765{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001766 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001767
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001768 ValidateNumInputs(workloadInfo, descriptorName, 1);
1769 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1770
1771 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1772 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1773
1774 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001775
1776 // Check the supported data types
1777 std::vector<DataType> supportedTypes =
1778 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001779 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001780 DataType::Float32,
1781 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001782 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001783 DataType::QAsymmU8,
1784 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001785 DataType::Signed32,
1786 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001787 };
1788
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001789 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1790 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001791}
1792
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001793void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1794{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001795 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001796
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001797 ValidateNumInputs(workloadInfo, descriptorName, 1);
1798 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1799
1800 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1801 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1802
1803 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1804 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001805
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001806 if (m_Parameters.m_BlockShape.size() != 2)
1807 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001808 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001809 }
1810
1811 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001813 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1814 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001815 }
1816
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001817 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001818
1819 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001820 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001821
Matthew Bentham8800c002018-11-19 13:19:28 +00001822 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001824 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1825 widthPad.first + widthPad.second;
1826 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1827 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001828
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001829 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1830 inputShape[dimensionIndices.GetChannelsIndex()];
1831 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001832
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001833 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001834 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001835 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001836 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001837 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001838 }
1839
1840 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001841 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001842 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1843 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001844 }
nikraj01120522a2019-05-31 11:33:07 +01001845
1846 std::vector<DataType> supportedTypes =
1847 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001848 DataType::BFloat16,
1849 DataType::Float16,
1850 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001851 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001852 DataType::QAsymmU8,
1853 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001854 };
1855
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001856 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1857 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001858}
1859
Keith Davisa57eccb2019-06-14 17:33:22 +01001860void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1861{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001862 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001863
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001864 ValidateNumInputs(workloadInfo, descriptorName, 1);
1865 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001866
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001867 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1868 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1869
1870 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1871 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001872
1873 std::vector<DataType> supportedTypes =
1874 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001875 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001876 DataType::Float32,
1877 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001878 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001879 DataType::QAsymmU8,
1880 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001881 };
1882
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001883 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1884 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001885
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001886 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1887
1888 if (m_Parameters.m_BlockSize == 0)
1889 {
1890 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1891 }
1892
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1894 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1895 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1896 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001899 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001900 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001901 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1902 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001903 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001904
1905 const TensorShape& outputShape = outputTensorInfo.GetShape();
1906 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1907 {
1908 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1909 "must be divisible by the square of block size." );
1910 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001911}
1912
telsoa014fcda012018-03-09 14:13:49 +00001913void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1914{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001915 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001916
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001917 ValidateNumInputs(workloadInfo, descriptorName, 1);
1918 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1919
1920 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1921 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001922
1923 std::vector<DataType> supportedTypes =
1924 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001925 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001926 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001927 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001928 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001929 };
1930
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001931 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001932 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1933 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1934 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001935}
1936
telsoa01c577f2c2018-08-31 09:22:23 +01001937void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1938{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001939 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1940
1941 const std::string descriptorName{"LstmQueueDescriptor"};
1942
1943 // check dimensions of all inputs and outputs
1944 if (workloadInfo.m_InputTensorInfos.size() != 3)
1945 {
1946 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1947 }
1948 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1949 {
1950 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1951 }
1952
1953 std::vector<DataType> supportedTypes =
1954 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001955 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001956 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001957 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001958 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001959 };
1960
Jan Eilers38e05bd2019-06-26 13:10:09 +01001961 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001962 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1963
Jan Eilers38e05bd2019-06-26 13:10:09 +01001964 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001965 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001966 {
1967 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1968 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001969 descriptorName,
1970 "input_0",
1971 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001972 }
1973 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001975 {
1976 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1977 workloadInfo.m_OutputTensorInfos[i],
1978 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001979 "input_0",
1980 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001981 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001982
janeil0117d8d852019-11-15 15:00:16 +00001983 // Making sure clipping parameters have valid values.
1984 // == 0 means no clipping
1985 // > 0 means clipping
1986 if (m_Parameters.m_ClippingThresCell < 0.0f)
1987 {
1988 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1989 }
1990 if (m_Parameters.m_ClippingThresProj < 0.0f)
1991 {
1992 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1993 }
1994
Jan Eilers38e05bd2019-06-26 13:10:09 +01001995 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001996 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1997 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1998 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1999 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2000 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2001 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2002
Jan Eilers38e05bd2019-06-26 13:10:09 +01002003 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2005 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002006 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2008 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002009 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002010 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2011 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002012 // scratchBufferTensor
2013 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002014 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2015 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002016 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002017 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2018 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002019 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002020 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2021 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002022 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002023 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2024 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002025
Jan Eilers38e05bd2019-06-26 13:10:09 +01002026 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2027 if ( m_InputToInputWeights )
2028 {
2029 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2030 (n_cell * n_input), "InputLayerNormWeights");
2031 }
2032
2033 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2034 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2035 (n_cell * n_input), "InputToForgetWeights");
2036
2037 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2038 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2039 (n_cell * n_input), "InputToCellWeights");
2040
2041 if ( m_RecurrentToInputWeights )
2042 {
2043 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2044 (n_cell * n_output), "RecurrentToInputWeights");
2045 }
2046
2047 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2048 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2049 (n_cell * n_output), "RecurrentToForgetWeights");
2050
2051 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2052 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2053 (n_cell * n_output), "RecurrentToCellWeights");
2054
2055 // Make sure the input-gate's parameters are either both present (regular
2056 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2057 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2058 !m_Parameters.m_CifgEnabled) ||
2059 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2060 m_Parameters.m_CifgEnabled));
2061 if (!cifg_weights_all_or_none)
2062 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002063 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2064 "RecurrentToInputWeights must either both be present (regular LSTM) "
2065 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2066 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002067 }
2068
2069 if ( m_CellToInputWeights )
2070 {
2071 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2072 n_cell, "CellToInputWeights");
2073 }
2074 if ( m_CellToForgetWeights )
2075 {
2076 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2077 n_cell, "CellToForgetWeights");
2078 }
2079 if ( m_CellToOutputWeights )
2080 {
2081 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2082 n_cell, "CellToOutputWeights");
2083 }
2084
2085 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2086 bool peephole_weights_all_or_none =
2087 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2088 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2089 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2090 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2091 if (!peephole_weights_all_or_none)
2092 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002093 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002094 }
2095
2096 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2097 if (m_Parameters.m_CifgEnabled)
2098 {
2099 if (m_InputGateBias)
2100 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002101 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002102 }
2103 }
2104 else
2105 {
2106 if (!m_InputGateBias)
2107 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002108 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2109 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002110 }
2111 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2112 n_cell, "InputGateBias");
2113 }
2114
2115 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2116 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2117
2118 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2119 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2120
2121 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2122 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2123
2124 if (m_ProjectionWeights)
2125 {
2126 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2127 (n_cell * n_output), "ProjectionWeights");
2128 }
2129 if (m_ProjectionBias)
2130 {
2131 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2132 }
2133
2134 // Making sure the projection tensors are consistent:
2135 // 1) If projection weight is not present, then projection bias should not be
2136 // present.
2137 // 2) If projection weight is present, then projection bias is optional.
2138 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2139 !m_Parameters.m_ProjectionEnabled)
2140 || (m_ProjectionWeights && !m_ProjectionBias &&
2141 m_Parameters.m_ProjectionEnabled)
2142 || (m_ProjectionWeights && m_ProjectionBias &&
2143 m_Parameters.m_ProjectionEnabled));
2144 if (!projecton_tensors_consistent)
2145 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002147 }
2148
2149 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2150 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2151 // either all have values or none of them have values. Layer normalization is used when the values of all the
2152 // layer normalization weights are present
2153 if (m_InputLayerNormWeights)
2154 {
2155 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2156 }
2157 if (m_ForgetLayerNormWeights)
2158 {
2159 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2160 }
2161 if (m_CellLayerNormWeights)
2162 {
2163 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2164 }
2165 if (m_OutputLayerNormWeights)
2166 {
2167 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2168 }
2169
Jan Eilers38e05bd2019-06-26 13:10:09 +01002170 if (m_Parameters.m_LayerNormEnabled)
2171 {
2172 if (!m_Parameters.m_CifgEnabled)
2173 {
2174 if (!m_InputLayerNormWeights)
2175 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002176 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2177 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002178 }
2179 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2180 1, n_cell, "InputLayerNormWeights");
2181 }
2182 else if (m_InputLayerNormWeights)
2183 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002184 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2185 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002186 }
2187
2188 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2189 "ForgetLayerNormWeights");
2190 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2191
2192 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2193 "OutputLayerNormWeights");
2194 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2195
2196 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2197 "CellLayerNormWeights");
2198 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2199 }
2200 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2201 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002202 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2203 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002204 }
telsoa01c577f2c2018-08-31 09:22:23 +01002205}
2206
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002207void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2208{
2209 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2210
2211 ValidateNumInputs(workloadInfo, descriptorName, 1);
2212 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2213
2214 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2215 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2216
2217 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2218 {
2219 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2220 }
2221
2222 if (outputTensorInfo.GetDataType() != DataType::Float32)
2223 {
2224 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2225 }
2226
2227 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2228}
2229
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002230void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2231{
2232 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2233
2234 ValidateNumInputs(workloadInfo, descriptorName, 1);
2235 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2236
2237 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2238 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2239
2240 if (inputTensorInfo.GetDataType() != DataType::Float32)
2241 {
2242 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2243 }
2244
2245 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2246 {
2247 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2248 }
2249
2250 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2251}
2252
telsoa01c577f2c2018-08-31 09:22:23 +01002253void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2254{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002255 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002256
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002257 ValidateNumInputs(workloadInfo, descriptorName, 1);
2258 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2259
2260 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2261 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2262
2263 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002264 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002265 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002266 }
2267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002268 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002269 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002270 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002271 }
2272
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002273 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002274}
2275
2276void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2277{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002278 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002279
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002280 ValidateNumInputs(workloadInfo, descriptorName, 1);
2281 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2282
2283 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2284 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2285
2286 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002287 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002288 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002289 }
2290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002291 if (outputTensorInfo.GetDataType() != DataType::Float32)
2292 {
2293 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2294 }
2295
2296 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002297}
2298
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002299void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2300{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002301 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002303 ValidateNumInputs(workloadInfo, descriptorName, 2);
2304 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2305
2306 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2307 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2308 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2309
2310 std::vector<DataType> supportedTypes =
2311 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002312 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002313 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002314 DataType::Float32,
2315 DataType::QAsymmS8,
2316 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002317 DataType::QSymmS16,
2318 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002319 };
2320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002321 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2322 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2323 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002325 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2326 inputTensorInfo1,
2327 outputTensorInfo,
2328 descriptorName,
2329 "input_0",
2330 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002331}
2332
David Beckc2044fe2018-09-05 15:00:38 +01002333void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2334{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002335 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002336
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002337 ValidateNumInputs(workloadInfo, descriptorName, 2);
2338 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2339
2340 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2341 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2342 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2343
2344 std::vector<DataType> supportedTypes =
2345 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002346 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002347 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002348 DataType::Float32,
2349 DataType::QAsymmS8,
2350 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002351 DataType::QSymmS16,
2352 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002353 };
2354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002355 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2356 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2357 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002358
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002359 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2360 inputTensorInfo1,
2361 outputTensorInfo,
2362 descriptorName,
2363 "input_0",
2364 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002365}
2366
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002367void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2368{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002369 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002370
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002371 ValidateNumInputs(workloadInfo, descriptorName, 2);
2372 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2373
2374 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2375 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2376 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2377
2378 std::vector<DataType> supportedTypes =
2379 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002380 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002381 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002382 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002383 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002384 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002385 DataType::QSymmS16,
2386 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002387 };
2388
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002389 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2390 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2391 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2394 inputTensorInfo1,
2395 outputTensorInfo,
2396 descriptorName,
2397 "input_0",
2398 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002399}
2400
narpra01a6bf9122018-09-10 09:50:09 +01002401void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2402{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002403 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002404
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002405 ValidateNumInputs(workloadInfo, descriptorName, 1);
2406 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2407
2408 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2409 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002410
2411 std::vector<DataType> supportedTypes =
2412 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002413 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002414 DataType::Float32,
2415 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002416 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002417 DataType::QAsymmU8,
2418 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002419 };
narpra01eb061912018-09-10 17:35:27 +01002420
James Conroy4d1ff582019-06-10 17:06:39 +01002421 // First check if input tensor data type is supported, then
2422 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002423 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2424 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002425
narpra0132b90462018-09-13 11:07:48 +01002426 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002427 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002428 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002429 }
narpra0132b90462018-09-13 11:07:48 +01002430 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002431 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002432 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002433 }
2434 else
2435 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002436 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002437 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 ValidateTensorNumDimensions(outputTensorInfo,
2439 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002440 outputDim > 0 ? outputDim : 1,
2441 "output");
2442 }
narpra01a6bf9122018-09-10 09:50:09 +01002443}
2444
jimfly012c9322a2018-09-19 10:59:49 +01002445void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2446{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002447 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002448
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002449 ValidateNumInputs(workloadInfo, descriptorName, 1);
2450 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2451
2452 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2453 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002454
jimfly012c9322a2018-09-19 10:59:49 +01002455 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002456 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2457
jimfly012c9322a2018-09-19 10:59:49 +01002458 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002459 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2460 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2461 "as there are dimensions in the input tensor that is " +
2462 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2463 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002464 }
2465}
2466
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002467void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2468{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002469 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002470
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002471 ValidateNumInputs(workloadInfo, descriptorName, 1);
2472 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002473
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002474 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2475 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2476
Sadik Armagan2208b602019-07-31 16:36:27 +01002477 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002478 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002479 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002480 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002481 DataType::Float16,
2482 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002483 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002484 DataType::QAsymmU8,
2485 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002486 };
2487
2488 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002489
Keith Davis0c2eeac2020-02-11 16:51:50 +00002490 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002491 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002492 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002493 }
2494}
2495
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002496void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2497{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002500 ValidateNumInputs(workloadInfo, descriptorName, 1);
2501 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002503 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2504 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002505
2506 std::vector<DataType> supportedTypes =
2507 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002508 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002509 DataType::Float32,
2510 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002511 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002512 DataType::QAsymmU8,
2513 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002514 };
2515
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2517 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002518}
2519
Conor Kennedy430b5d82018-11-14 15:28:28 +00002520void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2521{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002522 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002523
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002524 ValidateNumInputs(workloadInfo, descriptorName, 1);
2525 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2526
2527 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2528 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002529
2530 std::vector<DataType> supportedTypes =
2531 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002532 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002533 DataType::Float16,
2534 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002535 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002536 DataType::QAsymmU8,
2537 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002538 };
2539
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2541 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002543 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002544
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002545 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002546 if (rank > 4)
2547 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002548 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002549 }
2550
Conor Kennedy430b5d82018-11-14 15:28:28 +00002551 // Begin, End & Stride length must be of rank(input0)
2552 if (m_Parameters.m_Begin.size() != rank)
2553 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002554 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002555 }
2556
2557 if (m_Parameters.m_End.size() != rank)
2558 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002559 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002560 }
2561
2562 if (m_Parameters.m_Stride.size() != rank)
2563 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002564 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002565 }
2566
2567 // Stride entries must be non-zero
2568 for (auto& stride : m_Parameters.m_Stride)
2569 {
2570 if (stride == 0)
2571 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002573 }
2574 }
2575}
2576
kevmay0190539692018-11-29 08:40:19 +00002577void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2578{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002579 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002580
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002581 ValidateNumInputs(workloadInfo, descriptorName, 2);
2582 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2583
2584 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2585 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2586 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2587
2588 std::vector<DataType> supportedTypes =
2589 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002590 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002591 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002592 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002593 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002594 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002595 DataType::QSymmS16,
2596 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002597 };
2598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002599 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2600 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2601 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002603 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2604 inputTensorInfo1,
2605 outputTensorInfo,
2606 descriptorName,
2607 "input_0",
2608 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002609}
2610
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002611void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2612{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002613 const std::string descriptorName{"DebugQueueDescriptor"};
2614
2615 ValidateNumInputs(workloadInfo, descriptorName, 1);
2616 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002617}
2618
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002619void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2620{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002621 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002622
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002623 ValidateNumInputs(workloadInfo, descriptorName, 2);
2624 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002625
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002626 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2627 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2628 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2629
2630 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2631 inputTensorInfo1,
2632 outputTensorInfo,
2633 descriptorName,
2634 "input_0",
2635 "input_1");
2636
2637 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002638 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002639 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002640 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002641}
2642
FrancisMurtagh878f0232018-12-19 10:56:15 +00002643void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2644{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002645 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002646
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002647 ValidateNumInputs(workloadInfo, descriptorName, 2);
2648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002650 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2651 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2652 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2653
2654 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2655 inputTensorInfo1,
2656 outputTensorInfo,
2657 descriptorName,
2658 "input_0",
2659 "input_1");
2660
2661 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002662 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002663 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002664 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002665}
2666
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002667void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2668{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002669 const std::string descriptorName{"RsqrtQueueDescriptor"};
2670
2671 ValidateNumInputs(workloadInfo, descriptorName, 1);
2672 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2673
2674 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2675 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2676
2677 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002678
2679 std::vector<DataType> supportedTypes =
2680 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002681 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002682 DataType::Float16,
2683 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002684 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002685 DataType::QAsymmU8,
2686 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002687 };
2688
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002689 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2690 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002691}
2692
narpra01b89b05f2019-01-16 09:53:09 +00002693void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2694{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002695 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002697 ValidateNumInputs(workloadInfo, descriptorName, 2);
2698 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002699
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002700 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2701 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002702 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002703 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002704 }
2705
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002706 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2707 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2708
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002709 std::vector<DataType> supportedTypes =
2710 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002711 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002712 DataType::Float16,
2713 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002714 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002715 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002716 DataType::QSymmS16,
2717 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002718 };
2719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002720 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002721
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002722 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002723
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002724 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2725 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002726}
2727
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002728void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2729{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002730 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2731
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002732 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002733
2734 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2735 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002736 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002737 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2738 }
2739
2740 if (m_Anchors == nullptr)
2741 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002742 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002743 }
2744
2745 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002746 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2747 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2748
2749 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002750 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002751 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2752 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002753
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002754 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2755 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2756 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002757
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002758 const std::vector<DataType> supportedInputTypes =
2759 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002760 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002761 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002762 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002763 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002764 DataType::QAsymmU8,
2765 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002766 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002767
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002768 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2769 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2770 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2771
2772 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2773 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2774 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2775 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2776
2777 // NOTE: Output is always Float32 regardless of input type
2778 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2779 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2780 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2781 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002782
2783 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2784 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002785 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002786 "must be positive and less than or equal to 1.");
2787 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002788
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002789 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2790 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002791 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002792 "should be equal to number of classes + 1.");
2793 }
2794}
2795
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002796void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2797{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002798 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002799
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002800 ValidateNumInputs(workloadInfo, descriptorName, 1);
2801 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2802
2803 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2804 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2805
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002806 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002807 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002808 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002809 }
2810
Sadik Armagan2208b602019-07-31 16:36:27 +01002811 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002812 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002813 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002814 DataType::Float32,
2815 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002816 };
2817
2818 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002819}
2820
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002821void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2822{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002823 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002824
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002825 ValidateNumInputs(workloadInfo, descriptorName, 2);
2826 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002827
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002828 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2829 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2830 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002831
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002832 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2833 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2834
2835 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2836 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002837}
2838
Keith Davis3ae3f972021-05-21 16:33:48 +01002839void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2840{
2841 const std::string& descriptorName{"ShapeQueueDescriptor"};
2842
2843 ValidateNumInputs(workloadInfo, descriptorName, 1);
2844 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2845
2846 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2847 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2848
2849 std::vector<DataType> supportedTypes =
2850 {
2851 DataType::BFloat16,
2852 DataType::Float16,
2853 DataType::Float32,
2854 DataType::QAsymmS8,
2855 DataType::QAsymmU8,
2856 DataType::QAsymmS8,
2857 DataType::QSymmS8,
2858 DataType::QSymmS16,
2859 DataType::Signed32
2860 };
2861
2862 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2863 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2864}
2865
Sadik Armaganeff363d2019-04-05 15:25:46 +01002866void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2867{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002868 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002869
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002870 ValidateNumInputs(workloadInfo, descriptorName, 2);
2871 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2872
2873 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2874 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2875
2876 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2877 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2878
2879 std::vector<DataType> supportedTypes =
2880 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002881 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002882 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002883 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002884 DataType::QAsymmU8,
2885 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002886 };
2887
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002888 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2889 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002890
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002891 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2892 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002893
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002894 ValidateTensorShapesMatch(inputTensorInfo0,
2895 outputTensorInfo0,
2896 descriptorName,
2897 "input_0",
2898 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002899
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002900 ValidateTensorShapesMatch(inputTensorInfo0,
2901 outputTensorInfo1,
2902 descriptorName,
2903 "input_0",
2904 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002905}
2906
Derek Lamberti901ea112019-12-10 22:07:09 +00002907void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002908{
2909 // This is internally generated so it should not need validation.
2910}
2911
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002912void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2913{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002914 const std::string& descriptorName{"PreluQueueDescriptor"};
2915
2916 ValidateNumInputs(workloadInfo, descriptorName, 2);
2917 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2918
2919 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2920 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2921 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002922
2923 std::vector<DataType> supportedTypes
2924 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002925 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002926 DataType::Float16,
2927 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002928 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002929 DataType::QAsymmU8,
2930 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002931 };
2932
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002933 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2934 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002935
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002936 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002937
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002938 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2939 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002940
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002941 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2942 alphaTensorInfo,
2943 outputTensorInfo,
2944 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002945 "input",
2946 "alpha");
2947}
2948
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002949void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2950{
2951 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2952
2953 ValidateNumInputs(workloadInfo, descriptorName, 1);
2954 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2955
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002956 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2957 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2958
2959 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2960 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002961
2962 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002963
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002964 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2965 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002966
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002967 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2968
2969 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002970 if (m_Parameters.m_BiasEnabled)
2971 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002972 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002973
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002974 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2975 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002976
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002977 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002978 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002979 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002980
2981 ValidatePerAxisQuantization(inputTensorInfo,
2982 outputTensorInfo,
2983 weightTensorInfo,
2984 optionalBiasTensorInfo,
2985 descriptorName);
2986
2987 std::vector<DataType> supportedTypes =
2988 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002989 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002990 DataType::Float32,
2991 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002992 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002993 DataType::QAsymmU8,
2994 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002995 };
2996
2997 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2998 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002999}
3000
Mike Kellyc9ea45a2020-02-28 18:11:58 +00003001void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3002{
3003 const std::string descriptorName{"TransposeQueueDescriptor"};
3004
3005 ValidateNumInputs(workloadInfo, descriptorName, 1);
3006 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3007
3008 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3009
3010 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3011 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3012
3013 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3014 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3015
3016 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3017 {
3018 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3019 {
3020 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3021 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3022 "must match dst dimension " + to_string(i) +
3023 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3024 }
3025 }
3026
3027 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3028}
3029
Simon Obute51f67772021-09-03 15:50:13 +01003030void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3031{
3032 const std::string descriptorName{"TransposeQueueDescriptor"};
3033
3034 ValidateNumInputs(workloadInfo, descriptorName, 1);
3035 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3036
3037 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3038 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3039
3040 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3041}
3042
James Conroy4f1f8992020-04-29 20:01:10 +01003043void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3044{
3045 const std::string descriptorName{"QLstmQueueDescriptor"};
3046
3047 // Validate number of inputs/outputs
3048 ValidateNumInputs(workloadInfo, descriptorName, 3);
3049 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3050
3051 // Input/output tensor info
3052 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3053 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3054 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3055
3056 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3057 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3058 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3059
3060 // Supported types for various tensors in QLSTM
3061 std::vector<DataType> inputOutputSupportedTypes =
3062 {
3063 DataType::QAsymmS8
3064 };
3065
3066 std::vector<DataType> cellStateSupportedTypes =
3067 {
3068 DataType::QSymmS16
3069 };
3070
3071 std::vector<DataType> weightsSupportedTypes =
3072 {
3073 DataType::QSymmS8
3074 };
3075
3076 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3077 {
3078 DataType::QSymmS16
3079 };
3080
3081 std::vector<DataType> biasSupportedTypes =
3082 {
3083 DataType::Signed32
3084 };
3085
3086 // Validate types of input/output tensors
3087 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3088 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3089 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3090
3091 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3092 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3093 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3094
3095 // Validate matching types of input/output tensors
3096 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3097 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3098 "outputStateIn", "outputStateOut");
3099 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3100
3101 // Infer number of batches, number of units, input size and output size from tensor dimensions
3102 const uint32_t numBatches = inputInfo.GetShape()[0];
3103 const uint32_t inputSize = inputInfo.GetShape()[1];
3104 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3105 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3106
3107 // Validate number of dimensions and number of elements for input/output tensors
3108 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3109 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3110 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3111
3112 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3113 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3114 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3115
3116 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3117 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3118 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3119 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3120
3121 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3122 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3123 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3124
3125 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3126 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3127 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3128
3129 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3130 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3131 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3132 " RecurrentToForgetWeights");
3133
3134 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3135 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3136 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3137
3138 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3139 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3140 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3141
3142 // Validate data types for MANDATORY weights tensors (all should match each other)
3143 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3144
3145 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3146 "inputToForgetWeights", "inputToCellWeights");
3147 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3148 "inputToForgetWeights", "inputToOutputWeights");
3149
3150 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3151 "inputToForgetWeights", "recurrentToForgeteights");
3152 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3153 "inputToForgetWeights", "recurrentToCellWeights");
3154 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3155 "inputToForgetWeights", "recurrentToOutputWeights");
3156
3157 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3158 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3159 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3160 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3161
3162 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3163 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3164 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3165
3166 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3167 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3168 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3169
3170 // Validate data types for MANDATORY bias tensors
3171 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3172
3173 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3174 "forgetGateBias", "cellBias");
3175 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3176 "forgetGateBias", "outputGateBias");
3177
3178 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3179 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3180 !m_Parameters.m_CifgEnabled) ||
3181 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3182 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3183
3184 if (!allCifgParamsPresentOrNot)
3185 {
3186 throw InvalidArgumentException(descriptorName +
3187 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3188 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3189 "set appropriately.");
3190 }
3191
3192 if (!m_Parameters.m_CifgEnabled)
3193 {
3194 // Validate number of dimensions and number of elements
3195 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3196 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3197
3198 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3199 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3200 " RecurrentToInputWeights");
3201
3202 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3203 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3204
3205 // Validate data types
3206 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3207 "inputToForgetWeights", "inputToInputWeights");
3208 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3209 "inputToForgetWeights", "recurrentToInputWeights");
3210 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3211 "forgetGateBias", "inputGateBias");
3212 }
3213
3214 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3215 bool allPeepholeWeightsPresentOrNot =
3216 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3217 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3218 || (!m_CellToInputWeights && !m_CellToForgetWeights
3219 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3220
3221 if (!allPeepholeWeightsPresentOrNot)
3222 {
3223 throw InvalidArgumentException(descriptorName +
3224 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3225 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3226 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3227 "appropriately.");
3228 }
3229
3230 if (m_Parameters.m_PeepholeEnabled)
3231 {
3232 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3233 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3234 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3235
3236 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3237 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3238 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3239 "cellToForgetWeight", "cellToOutputWeights");
3240
3241 if (!m_Parameters.m_CifgEnabled)
3242 {
3243 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3244 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3245 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3246 "cellToForgetWeights", "cellToInputWeights");
3247 }
3248 }
3249
3250 // Validate OPTIONAL params: Layer Norm Weights
3251 bool allLayerNormWeightsPresentOrNot =
3252 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3253 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3254 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3255 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3256
3257 if (!allLayerNormWeightsPresentOrNot)
3258 {
3259 throw InvalidArgumentException(descriptorName +
3260 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3261 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3262 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3263 "only be present when Layer Norm is enabled and CIFG is disabled. "
3264 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3265 }
3266
3267 if (m_Parameters.m_LayerNormEnabled)
3268 {
3269 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3270 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3271 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3272
3273 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3274 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3275 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3276 "forgetLayerNormWeights", "cellLayerNormWeights");
3277
3278 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3279 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3280 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3281 "forgetLayerNormWeights", "outputLayerNormWeights");
3282
3283 if (!m_Parameters.m_CifgEnabled)
3284 {
3285 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3286 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3287 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3288 "forgetLayerNormWeights", "inputLayerNormWeights");
3289 }
3290 }
3291
3292 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3293 bool correctProjectionTensorsPresent =
3294 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3295 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3296 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3297
3298 if (!correctProjectionTensorsPresent)
3299 {
3300 throw InvalidArgumentException(descriptorName +
3301 ": If projection is enabled, ProjectionWeights should be present and "
3302 "ProjectionBias is optional. If projection is disabled, neither "
3303 "ProjectionWeights nor ProjectionBias should be present.");
3304 }
3305
3306 if (m_Parameters.m_ProjectionEnabled)
3307 {
3308 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3309 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3310 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3311
3312 if (m_ProjectionBias)
3313 {
3314 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003315 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003316 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3317 }
3318
3319 }
3320 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3321 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3322 throw InvalidArgumentException(descriptorName +
3323 ": If projection is disabled, output quantization info (scale, offset) "
3324 "should match HiddenStateScale and HiddenStateZeroPoint.");
3325 }
3326
3327}
3328
James Conroy9c3cae82019-08-01 16:01:48 +01003329void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3330{
3331 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3332
3333 // Validate number of inputs/outputs
3334 ValidateNumInputs(workloadInfo, descriptorName, 3);
3335 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3336
3337 // Input/output tensor infos
3338 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3339 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3340 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3341
3342 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3343 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3344
3345 std::vector<DataType> inputOutputSupportedTypes =
3346 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003347 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003348 };
3349
3350 std::vector<DataType> cellStateSupportedTypes =
3351 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003352 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003353 };
3354
3355 std::vector<DataType> weightsSupportedTypes =
3356 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003357 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003358 };
3359
3360 std::vector<DataType> biasSupportedTypes =
3361 {
3362 DataType::Signed32
3363 };
3364
3365 // Validate types of input/output tensors
3366 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3367 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3368 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3369
3370 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3371 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3372
3373 // Validate matching types of input/output tensors
3374 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3375 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3376 "outputStateIn", "outputStateOut");
3377 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3378
3379 // Validate matching quantization info for input/output tensors
3380 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3381 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3382 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003383
James Conroy9c3cae82019-08-01 16:01:48 +01003384 // Infer number of batches, input size and output size from tensor dimensions
3385 const uint32_t numBatches = inputInfo.GetShape()[0];
3386 const uint32_t inputSize = inputInfo.GetShape()[1];
3387 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3388
3389 // Validate number of dimensions and number of elements for input/output tensors
3390 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3391 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3392 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3393 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3394 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3395
3396 // Validate number of dimensions and number of elements for weights tensors
3397 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3398 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3399 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3400
3401 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3402 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3403 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3404
3405 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3406 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3407 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3408
3409 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3410 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3411 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3412
3413 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3414 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3415 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3416
3417 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3418 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3419 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3420 " RecurrentToForgetWeights");
3421
3422 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3423 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3424 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3425
3426 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3427 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3428 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3429
3430 // Validate data types for weights tensors (all should match each other)
3431 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3432
3433 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3434 "inputToInputWeights", "inputToForgetWeights");
3435 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3436 "inputToInputWeights", "inputToCellWeights");
3437 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3438 "inputToInputWeights", "inputToOutputWeights");
3439
3440 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3441 "inputToInputWeights", "recurrentToInputWeights");
3442 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3443 "inputToInputWeights", "recurrentToForgeteights");
3444 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3445 "inputToInputWeights", "recurrentToCellWeights");
3446 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3447 "inputToInputWeights", "recurrentToOutputWeights");
3448
3449 // Validate matching quantization info for weight tensors (all should match each other)
3450 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3451 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3452 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3453 descriptorName, "inputToInputWeights", "inputToCellWeights");
3454 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3455 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3456
3457 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3458 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3459 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3460 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3461 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3462 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3463 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3464 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3465
3466 // Validate number of dimensions and number of elements in bias tensors
3467 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3468 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3469 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3470
3471 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3472 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3473 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3474
3475 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3476 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3477 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3478
3479 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3480 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3481 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3482
3483 // Validate data types for bias tensors (all should match each other)
3484 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3485
3486 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3487 "inputGateBias", "forgetGateBias");
3488 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3489 "inputGateBias", "cellBias");
3490 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3491 "inputGateBias", "outputGateBias");
3492
3493 // Validate bias tensor quantization info
3494 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3495 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3496 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3497 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3498}
3499
Kevin May868eb142019-09-04 17:29:31 +01003500void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3501{
3502 const std::string descriptorName{"AbsQueueDescriptor"};
3503
3504 ValidateNumInputs(workloadInfo, descriptorName, 1);
3505 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3506
3507 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3508 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3509
3510 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3511
3512 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003513 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003514 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003515 DataType::Float16,
3516 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003517 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003518 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003519 DataType::QSymmS16,
3520 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003521 };
Kevin May868eb142019-09-04 17:29:31 +01003522
3523 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3524 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3525}
3526
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003527void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3528{
3529 const std::string descriptorName{"SliceQueueDescriptor"};
3530
3531 ValidateNumInputs(workloadInfo, descriptorName, 1);
3532 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3533
3534 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3535 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3536
3537 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3538
3539 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3540 if (rank > 4)
3541 {
3542 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3543 }
3544
3545 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3546
3547 // Check if m_Begin and m_Size have the expected length
3548 if (m_Parameters.m_Begin.size() != rank)
3549 {
3550 throw InvalidArgumentException(descriptorName +
3551 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3552 }
3553 if (m_Parameters.m_Size.size() != rank)
3554 {
3555 throw InvalidArgumentException(descriptorName +
3556 ": Length of size descriptor must equal rank " + std::to_string(rank));
3557 }
3558
3559 // Check if the shape of the output tensor matches m_Size
3560 const TensorShape& outputShape = outputTensorInfo.GetShape();
3561 for (unsigned int i = 0u; i < rank; ++i)
3562 {
3563 if (m_Parameters.m_Size[i] != outputShape[i])
3564 {
3565 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3566 }
3567 }
3568
3569 // Check if the sum of begin offset and size in a given dimension
3570 // does not exceed the size of corresponding input
3571 const TensorShape& inputShape = inputTensorInfo.GetShape();
3572 for(unsigned int i = 0u; i < rank; ++i)
3573 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003574 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003575 {
3576 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3577 std::to_string(i) + " exceeds input size.");
3578 }
3579 }
3580}
3581
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003582void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3583{
3584 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3585
3586 ValidateNumInputs(workloadInfo, descriptorName, 1);
3587 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3588
3589 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3590 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3591
3592 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3593 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3594
3595 std::vector<DataType> supportedTypes =
3596 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003597 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003598 DataType::Float32,
3599 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003600 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003601 DataType::QAsymmU8,
3602 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003603 };
3604
3605 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3606 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3607
3608 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3609
3610 if (m_Parameters.m_BlockSize == 0)
3611 {
3612 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3613 }
3614
3615 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3616 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3617 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3618 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3619
3620 const TensorShape& outputShape = outputInfo.GetShape();
3621 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3622 {
3623 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3624 "must be divisible by block size.");
3625 }
3626
3627 const TensorShape& inputShape = inputInfo.GetShape();
3628 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3629 {
3630 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3631 "must be divisible by the square of block size." );
3632 }
3633}
3634
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003635void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3636{
3637 const std::string descriptorName{"ComparisonQueueDescriptor"};
3638
3639 ValidateNumInputs(workloadInfo, descriptorName, 2);
3640 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3641
3642 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3643 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3644 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3645
3646 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3647 inputTensorInfo1,
3648 outputTensorInfo,
3649 descriptorName,
3650 "input_0",
3651 "input_1");
3652
3653 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3654 {
3655 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3656 }
3657}
3658
josh minor4a3c6102020-01-06 16:40:46 -06003659void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3660{
3661 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3662
3663 ValidateNumInputs(workloadInfo, descriptorName, 1);
3664 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3665
3666 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3667 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3668
3669 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3670
3671 std::vector<DataType> supportedTypes =
3672 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003673 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003674 DataType::Float16,
3675 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003676 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003677 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003678 DataType::QSymmS16,
3679 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003680 };
3681
James Conroyaba90cd2020-11-06 16:28:18 +00003682 std::vector<DataType> logicalSupportedTypes =
3683 {
3684 DataType::Boolean
3685 };
3686
3687 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3688 {
3689 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3690 }
3691 else
3692 {
3693 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3694 }
3695
3696
josh minor4a3c6102020-01-06 16:40:46 -06003697 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3698}
3699
Finn Williams2605b232020-06-10 15:53:46 +01003700void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3701{
3702 const std::string descriptorName{"RankQueueDescriptor"};
3703
3704 ValidateNumInputs(workloadInfo, descriptorName, 1);
3705 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3706
3707 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3708 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3709
3710 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3711 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3712
3713 std::vector<DataType> supportedTypes =
3714 {
3715 DataType::BFloat16,
3716 DataType::Float16,
3717 DataType::Float32,
3718 DataType::QAsymmS8,
3719 DataType::QAsymmU8,
3720 DataType::QSymmS8,
3721 DataType::QSymmS16,
3722 DataType::Signed32
3723 };
3724
3725 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3726 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3727}
3728
James Conroyaba90cd2020-11-06 16:28:18 +00003729void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3730{
3731 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3732
3733 ValidateNumInputs(workloadInfo, descriptorName, 2);
3734 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3735
3736 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3737 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3738 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3739
3740 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3741 inputTensorInfo1,
3742 outputTensorInfo,
3743 descriptorName,
3744 "input_0",
3745 "input_1");
3746
3747 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3748 {
3749 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3750 }
3751
3752 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3753 {
3754 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3755 }
3756
3757 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3758 {
3759 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3760 }
3761}
3762
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003763void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3764{
3765 const std::string descriptorName{"ReduceQueueDescriptor"};
3766
3767 ValidateNumInputs(workloadInfo, descriptorName, 1);
3768 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3769
3770 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3771 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3772
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003773 std::vector<DataType> supportedTypes =
3774 {
3775 DataType::BFloat16,
3776 DataType::Float16,
3777 DataType::Float32,
3778 DataType::QAsymmS8,
3779 DataType::QAsymmU8,
3780 DataType::QSymmS16,
3781 DataType::Signed32
3782 };
3783
3784 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3785 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3786}
3787
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003788void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3789{
3790 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3791
3792 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3793
3794 // check dimensions of all inputs and outputs
3795 if (workloadInfo.m_InputTensorInfos.size() != 3)
3796 {
3797 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3798 }
3799 if (workloadInfo.m_OutputTensorInfos.size() != 1)
3800 {
3801 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3802 }
3803
3804 std::vector<DataType> supportedTypes =
3805 {
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01003806 DataType::Float32
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003807 };
3808
3809 // check for supported type of one input and match them with all the other input and output
3810 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3811
3812 // type matches all other inputs
3813 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
3814 {
3815 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
3816 workloadInfo.m_InputTensorInfos[i],
3817 descriptorName,
3818 "input_0",
3819 "input_" + std::to_string(i));
3820 }
3821 // type matches all other outputs
3822 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
3823 {
3824 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
3825 workloadInfo.m_OutputTensorInfos[i],
3826 "LstmQueueDescriptor",
3827 "input_0",
3828 "output_" + std::to_string(i));
3829 }
3830
3831 // Making sure clipping parameters have valid values.
3832 // == 0 means no clipping
3833 // > 0 means clipping
3834 if (m_Parameters.m_ClippingThresCell < 0.0f)
3835 {
3836 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3837 }
3838 if (m_Parameters.m_ClippingThresProj < 0.0f)
3839 {
3840 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3841 }
3842
3843 unsigned int batchIndx = 0;
3844 unsigned int inputIndx = 1;
3845 uint32_t timeStep = 1;
3846 unsigned int timeIndx = 1;
3847 inputIndx = 2;
3848 if (m_Parameters.m_TimeMajor)
3849 {
3850 batchIndx = 1;
3851 timeIndx = 0;
3852
3853 }
3854 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3855
3856 // Inferring batch size, number of outputs and number of cells from the inputs.
3857 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3858 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3859 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3860 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3861 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3862 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3863
3864 // input tensor
3865 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3866 descriptorName + " input_0");
3867 // outputStateInTensor
3868 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3869 descriptorName + " input_1");
3870 // outputStateInTensor
3871 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3872 descriptorName + " input_2");
3873
3874 // outputTensor
3875 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output),
3876 descriptorName + " output_0");
3877
3878 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3879 if ( m_InputToInputWeights )
3880 {
3881 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3882 (n_cell * n_input), "InputLayerNormWeights");
3883 }
3884
3885 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3886 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3887 (n_cell * n_input), "InputToForgetWeights");
3888
3889 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3890 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3891 (n_cell * n_input), "InputToCellWeights");
3892
3893 if ( m_RecurrentToInputWeights )
3894 {
3895 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3896 (n_cell * n_output), "RecurrentToInputWeights");
3897 }
3898
3899 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3900 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3901 (n_cell * n_output), "RecurrentToForgetWeights");
3902
3903 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3904 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3905 (n_cell * n_output), "RecurrentToCellWeights");
3906
3907 // Make sure the input-gate's parameters are either both present (regular
3908 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3909 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3910 !m_Parameters.m_CifgEnabled) ||
3911 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3912 m_Parameters.m_CifgEnabled));
3913 if (!cifg_weights_all_or_none)
3914 {
3915 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3916 "RecurrentToInputWeights must either both be present (regular LSTM) "
3917 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3918 "accordingly.");
3919 }
3920
3921 if ( m_CellToInputWeights )
3922 {
3923 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3924 n_cell, "CellToInputWeights");
3925 }
3926 if ( m_CellToForgetWeights )
3927 {
3928 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3929 n_cell, "CellToForgetWeights");
3930 }
3931 if ( m_CellToOutputWeights )
3932 {
3933 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3934 n_cell, "CellToOutputWeights");
3935 }
3936
3937 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3938 bool peephole_weights_all_or_none =
3939 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3940 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3941 || ( !m_CellToInputWeights && !m_CellToForgetWeights
3942 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3943 if (!peephole_weights_all_or_none)
3944 {
3945 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3946 }
3947
3948 // Make sure the input gate bias is present only when not a CIFG-LSTM.
3949 if (m_Parameters.m_CifgEnabled)
3950 {
3951 if (m_InputGateBias)
3952 {
3953 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3954 }
3955 }
3956 else
3957 {
3958 if (!m_InputGateBias)
3959 {
3960 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
3961 "must be present.");
3962 }
3963 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
3964 n_cell, "InputGateBias");
3965 }
3966
3967 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
3968 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
3969
3970 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
3971 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
3972
3973 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
3974 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
3975
3976 if (m_ProjectionWeights)
3977 {
3978 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
3979 (n_cell * n_output), "ProjectionWeights");
3980 }
3981 if (m_ProjectionBias)
3982 {
3983 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
3984 }
3985
3986 // Making sure the projection tensors are consistent:
3987 // 1) If projection weight is not present, then projection bias should not be
3988 // present.
3989 // 2) If projection weight is present, then projection bias is optional.
3990 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
3991 !m_Parameters.m_ProjectionEnabled)
3992 || (m_ProjectionWeights && !m_ProjectionBias &&
3993 m_Parameters.m_ProjectionEnabled)
3994 || (m_ProjectionWeights && m_ProjectionBias &&
3995 m_Parameters.m_ProjectionEnabled));
3996 if (!projecton_tensors_consistent)
3997 {
3998 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
3999 }
4000
4001 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4002 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4003 // either all have values or none of them have values. Layer normalization is used when the values of all the
4004 // layer normalization weights are present
4005 if (m_InputLayerNormWeights)
4006 {
4007 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4008 }
4009 if (m_ForgetLayerNormWeights)
4010 {
4011 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4012 }
4013 if (m_CellLayerNormWeights)
4014 {
4015 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4016 }
4017 if (m_OutputLayerNormWeights)
4018 {
4019 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4020 }
4021
4022 if (m_Parameters.m_LayerNormEnabled)
4023 {
4024 if (!m_Parameters.m_CifgEnabled)
4025 {
4026 if (!m_InputLayerNormWeights)
4027 {
4028 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4029 "disabled but InputLayerNormWeights are not present");
4030 }
4031 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4032 1, n_cell, "InputLayerNormWeights");
4033 }
4034 else if (m_InputLayerNormWeights)
4035 {
4036 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4037 "enabled");
4038 }
4039
4040 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4041 "ForgetLayerNormWeights");
4042 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4043
4044 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4045 "OutputLayerNormWeights");
4046 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4047
4048 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4049 "CellLayerNormWeights");
4050 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4051 }
4052 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4053 {
4054 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4055 "normalisation weights are present.");
4056 }
4057}
4058
4059
mathad01df9a3222021-04-28 11:42:57 +01004060} // namespace armnn