blob: 319cdb106b39857909a85c386e8d23389284791b [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
James Conroy1f58f032021-04-27 17:13:27 +01006#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00007#include <backendsCommon/WorkloadData.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010011#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000012
telsoa014fcda012018-03-09 14:13:49 +000013#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000015#include <string>
16#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000017
James Ward47fce872020-09-10 11:57:28 +010018#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000032 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000033 case DataType::Float32:
34 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000035 case DataType::QAsymmS8:
36 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000037 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000038 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
40 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000041 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010042 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000043 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010044 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000045 return DataType::Float32;
46 }
47}
48
49namespace
50{
51
52//---------------------------------------------------------------
// Local stand-in for std::to_string, which the Android NDK does not provide.
// Formats any value that has an operator<< for std::ostream by streaming it into a string.
// @param value Value to format.
// @return The streamed textual form of value.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    std::string text = formatted.str();
    return text;
}
61
62//---------------------------------------------------------------
63void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
64{
65 if (!ptr)
66 {
67 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
68 paramName + " parameter must be set.");
69 }
70}
71
72//---------------------------------------------------------------
73void ValidateTensorShapesMatch(const TensorInfo& first,
74 const TensorInfo& second,
75 std::string const& descName,
76 std::string const& firstName,
77 std::string const& secondName)
78{
79 if (first.GetShape() != second.GetShape())
80 {
81 throw InvalidArgumentException(descName + ": "
82 + firstName + " & " + secondName + " must have identical shapes");
83 }
84}
85
86//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010087void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088{
Sadik Armaganeff363d2019-04-05 15:25:46 +010089 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090 {
91 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010092 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000093 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
94 }
95}
96
97//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010098void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101 {
102 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100103 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000104 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
105 }
106}
107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000110 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100111 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000112 std::string const& tensorName)
113{
114 if (tensor.GetNumDimensions() != numDimensions)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
117 to_string(tensor.GetNumDimensions()) + " dimensions for " +
118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100123void ValidateTensorNumElements(const TensorInfo& tensor,
124 std::string const& descName,
125 unsigned int numElements,
126 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127{
128 if (tensor.GetNumElements() != numElements)
129 {
130 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100131 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100132 tensorName + " tensor.");
133 }
134}
135
136//---------------------------------------------------------------
137void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100138 unsigned int numDimension,
139 unsigned int numElements,
140 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100141{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100142 const std::string functionName{"ValidateTensorNumDimNumElem"};
143 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
144 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100145}
146
147//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000148void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
149 const std::string& descName, std::string const& tensorName)
150{
151 if (tensor.GetDataType() != dataType)
152 {
153 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
154 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
155 }
156}
157
Derek Lambertid466a542020-01-22 15:37:29 +0000158void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
159{
160 ARMNN_NO_DEPRECATE_WARN_BEGIN
161 if (tensor.GetDataType() != DataType::QSymmS8 &&
162 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
163 {
164 throw InvalidArgumentException(descName +
165 ": Expected data type which supports per-axis quantization scheme but got " +
166 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
167 }
168 ARMNN_NO_DEPRECATE_WARN_END
169}
170
telsoa014fcda012018-03-09 14:13:49 +0000171//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100172void ValidateTensorQuantizationSpace(const TensorInfo& first,
173 const TensorInfo& second,
174 const std::string& descName,
175 std::string const& firstName,
176 std::string const& secondName)
177{
178 if (!first.IsQuantized() ||
179 !second.IsQuantized())
180 {
181 // Not a quantized type, ignore the validation
182 return;
183 }
184
185 DataType firstDataType = first.GetDataType();
186 DataType secondDataType = second.GetDataType();
187
188 if (firstDataType != secondDataType)
189 {
190 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
191 " must be of the same quantized type, " +
192 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
193 secondName + " is " + GetDataTypeName(secondDataType));
194 }
195
196 if (!first.IsTypeSpaceMatch(second))
197 {
198 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
199 " must have the same quantization space, " +
200 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
201 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
202 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
203 " and scale " + to_string(second.GetQuantizationScale()));
204 }
205}
206
207//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100208void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
209 const TensorInfo& inputTensorInfo,
210 const TensorInfo& weightsTensorInfo,
211 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000212{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000213 // Helper lambda function to validate a single bias quantization scale value
214 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
215 {
mathad01df9a3222021-04-28 11:42:57 +0100216 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000217 if (std::abs(biasScale - expectedScale) > tolerance)
218 {
219 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100220 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
221 " for bias quantization scale (product of input and weight scales), but got " <<
222 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000223 }
224 };
225
telsoa014fcda012018-03-09 14:13:49 +0000226 if (biasTensor.GetQuantizationOffset() != 0)
227 {
228 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
229 to_string(biasTensor.GetQuantizationOffset()));
230 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000231
James Conroy8502ade2020-11-12 19:26:29 +0000232 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000233 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000234 // Validate per-axis quantization scales
235 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
236 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
237
238 if (weightScales.size() != biasScales.size())
239 {
240 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000241 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
242 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
243 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100309 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100369 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
James Ward47fce872020-09-10 11:57:28 +0100390 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
391 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000392 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000393}
394
395void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
396 const std::string& descName,
397 const std::string& tensorName)
398{
399 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
400 if (quantizationOffset != 0)
401 {
James Ward47fce872020-09-10 11:57:28 +0100402 throw InvalidArgumentException(fmt::format(
403 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
404 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000405 }
406}
407
408void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
409 const TensorInfo& outputInfo,
410 const TensorInfo& weightInfo,
411 const Optional<TensorInfo>& optionalBiasInfo,
412 const std::string& descName)
413{
414 if (weightInfo.HasPerAxisQuantization())
415 {
416 const DataType inputDataType = inputInfo.GetDataType();
417 const DataType outputDataType = outputInfo.GetDataType();
418
Keith Davis0c2eeac2020-02-11 16:51:50 +0000419 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000420
421 if (!canHavePerAxisQuantization)
422 {
James Ward47fce872020-09-10 11:57:28 +0100423 throw InvalidArgumentException(fmt::format(
424 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
425 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000426 }
427
Derek Lambertid466a542020-01-22 15:37:29 +0000428
429 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000430 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
431 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
432
433 if (optionalBiasInfo.has_value())
434 {
435 const TensorInfo& biasInfo = optionalBiasInfo.value();
436 if (!biasInfo.HasPerAxisQuantization())
437 {
James Ward47fce872020-09-10 11:57:28 +0100438 throw InvalidArgumentException(fmt::format(
439 "{}: Per-axis quantization parameters not set on bias tensor, "
440 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000441 }
442
443 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
444 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
445 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
446 }
447 }
448}
449
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100450} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000451
452void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
453 unsigned int numExpectedIn, unsigned int numExpectedOut) const
454{
455 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
456 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
457}
458
459//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100460void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
461{
462 const std::string descriptorName{"MapQueueDescriptor"};
463
464 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100465 ValidateNumOutputs(workloadInfo, descriptorName, 0);
466
467 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
468 {
469 if (!m_Inputs[i])
470 {
471 throw InvalidArgumentException(
472 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
473 }
474 }
475}
476
477//---------------------------------------------------------------
478void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
479{
480 const std::string descriptorName{"UnmapQueueDescriptor"};
481
482 ValidateNumInputs(workloadInfo, descriptorName, 1);
483 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100484
485 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
486 {
487 if (!m_Inputs[i])
488 {
489 throw InvalidArgumentException(
490 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
491 }
492 }
493}
494
495//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000496void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
497{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100500 ValidateNumInputs(workloadInfo, descriptorName, 1);
501 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100503 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
504 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
505
506 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
507 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000508
509 if (m_Inputs.size() != m_Outputs.size())
510 {
James Ward47fce872020-09-10 11:57:28 +0100511 throw InvalidArgumentException(fmt::format(
512 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
513 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000514 }
515
516 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
517 {
518 if (!m_Inputs[i])
519 {
James Ward47fce872020-09-10 11:57:28 +0100520 throw InvalidArgumentException(fmt::format(
521 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000522 }
523
524 if (!m_Outputs[i])
525 {
James Ward47fce872020-09-10 11:57:28 +0100526 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000527 }
528 }
529}
530
Derek Lambertif674aa02019-08-01 15:56:25 +0100531//---------------------------------------------------------------
532void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
533{
534 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
535 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
536
537 if (workloadInfo.m_InputTensorInfos.size() != 1)
538 {
James Ward47fce872020-09-10 11:57:28 +0100539 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
540 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100541
542 }
543
544 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
545 {
James Ward47fce872020-09-10 11:57:28 +0100546 throw InvalidArgumentException(fmt::format(
547 "Number of input infos ({0}) does not match the number of output infos ({1})",
548 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100549 }
550
551 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
552 {
553 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
554 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
555 {
James Ward47fce872020-09-10 11:57:28 +0100556 throw InvalidArgumentException(fmt::format(
557 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100558 }
559 }
560
561 if (m_Inputs.size() != 1)
562 {
James Ward47fce872020-09-10 11:57:28 +0100563 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100564 }
565
566 if (m_Inputs.size() != m_Outputs.size())
567 {
James Ward47fce872020-09-10 11:57:28 +0100568 throw InvalidArgumentException(fmt::format(
569 "Number of inputs ({0}) does not match the number of outputs ({1})",
570 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 }
572
573 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
574 {
575 if (!m_Inputs[i])
576 {
James Ward47fce872020-09-10 11:57:28 +0100577 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100578 }
579
580 if (!m_Outputs[i])
581 {
James Ward47fce872020-09-10 11:57:28 +0100582 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100583 }
584 }
585}
586
587//---------------------------------------------------------------
588void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
589{
590 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
591 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
592
Derek Lambertif674aa02019-08-01 15:56:25 +0100593 if (m_Inputs.size() != 1)
594 {
James Ward47fce872020-09-10 11:57:28 +0100595 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100596 }
597
598 if (m_Outputs.size() != 0)
599 {
James Ward47fce872020-09-10 11:57:28 +0100600 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100601 }
602
603 if (!m_Inputs[0])
604 {
James Ward47fce872020-09-10 11:57:28 +0100605 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100606 }
607}
608
609//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000610void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
611{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100612 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100613
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100614 ValidateNumInputs(workloadInfo, descriptorName, 1);
615 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100616
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100617 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
618 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100619
620 std::vector<DataType> supportedTypes =
621 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000622 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100623 DataType::Float16,
624 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000625 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000626 DataType::QAsymmU8,
627 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100628 };
629
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100630 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
631 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
632 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000633}
634
Nikhil Rajee391d52019-09-05 17:50:44 +0100635void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
636{
637 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
638
639 ValidateNumInputs(workloadInfo, descriptorName, 1);
640 ValidateNumOutputs(workloadInfo, descriptorName, 1);
641
642 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
643 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
644
Inki Daed4619e22020-09-10 15:33:54 +0900645 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
646 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100647 {
Inki Daed4619e22020-09-10 15:33:54 +0900648 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100649 }
650
James Conroyd47a0642019-09-17 14:22:06 +0100651 std::vector<DataType> supportedInputTypes =
652 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000653 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100654 DataType::Float16,
655 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100656 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000657 DataType::QAsymmU8,
658 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900659 DataType::Signed32,
660 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100661 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100662
James Conroyd47a0642019-09-17 14:22:06 +0100663 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100664
665 auto inputShape = inputTensorInfo.GetShape();
666 auto outputShape = outputTensorInfo.GetShape();
667
668 auto inputNumDimensions = inputShape.GetNumDimensions();
669 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
670
671 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
672
673 // 1D input shape results in scalar output shape
674 if (inputShape.GetNumDimensions() == 1)
675 {
676 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
677 {
678 throw InvalidArgumentException(descriptorName + outputShapeError);
679 }
680 }
681 else
682 {
683 for (unsigned int i = 0; i < unsignedAxis; ++i)
684 {
685 if (outputShape[i] != inputShape[i])
686 {
687 throw InvalidArgumentException(descriptorName + outputShapeError);
688 }
689 }
690
691 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
692 {
693 if (outputShape[i - 1] != inputShape[i])
694 {
695 throw InvalidArgumentException(descriptorName + outputShapeError);
696 }
697 }
698 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100699}
700
mathad01b392e982021-04-07 12:07:30 +0100701void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
702{
703 const std::string descriptorName{"CastQueueDescriptor"};
704
705 ValidateNumInputs(workloadInfo, descriptorName, 1);
706 ValidateNumOutputs(workloadInfo, descriptorName, 1);
707
708 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
709 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
710
711 std::vector<DataType> supportedTypes =
712 {
713 DataType::BFloat16,
714 DataType::Float16,
715 DataType::Float32,
716 DataType::QAsymmS8,
717 DataType::QAsymmU8,
718 DataType::QSymmS8,
719 DataType::QSymmS16,
720 DataType::Signed32,
721 DataType::Signed64
722 };
723
724 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
725 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
726}
727
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100728void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
729{
730 const std::string descriptorName{"SoftmaxQueueDescriptor"};
731
732 ValidateNumInputs(workloadInfo, descriptorName, 1);
733 ValidateNumOutputs(workloadInfo, descriptorName, 1);
734
735 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
736 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
737
738 std::vector<DataType> supportedTypes =
739 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000740 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100741 DataType::Float16,
742 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000743 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000744 DataType::QAsymmU8,
745 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100746 };
747
748 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
749 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
750 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
751}
752
telsoa014fcda012018-03-09 14:13:49 +0000753void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
754{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 const std::string descriptorName{"SplitterQueueDescriptor"};
756
757 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000758
Ruomei Yan25339c32019-05-28 16:48:20 +0100759 // Check the supported data types
760 std::vector<DataType> supportedTypes =
761 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000762 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100763 DataType::Float32,
764 DataType::Float16,
765 DataType::Boolean,
766 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100767 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000768 DataType::QAsymmU8,
769 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100770 };
771
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100772 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
773 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100774 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100775 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
776 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
777
778 const std::string outputName = "output_" + std::to_string(i);
779 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100780 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100781
telsoa014fcda012018-03-09 14:13:49 +0000782 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
783 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000785 }
786
787 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
788 {
789 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100790 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000791 "has to match number of workloadInfo.m_OutputTensorInfos. "
792 "Number of windows: " +
793 to_string(m_ViewOrigins.size()) +
794 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
795 }
796
telsoa01c577f2c2018-08-31 09:22:23 +0100797 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000798 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
799 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
800 {
telsoa01c577f2c2018-08-31 09:22:23 +0100801 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000802 ViewOrigin const& e = m_ViewOrigins[w];
803 if (e.m_Origin.size() != inputDims)
804 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100805 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000806 "have the same dimensionality as the input tensor. "
807 "Window origin (index: " +
808 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
809 " dimensions, the input "
810 "tensor has " +
811 to_string(inputDims) + " dimensions.");
812 }
813 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
814 {
815 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
816 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
817 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100818 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000819 "be smaller or equal than the size of the input in that coord.");
820 }
821 }
822 }
823}
824
Jim Flynne242f2d2019-05-22 14:24:13 +0100825void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000826{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 const std::string descriptorName{"ConcatQueueDescriptor"};
828
829 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000830
831 if (m_Inputs.size() <= 0)
832 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100833 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000834 }
835 if (m_Outputs.size() <= 0)
836 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100837 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000838 }
839
840 if (workloadInfo.m_InputTensorInfos.size() <= 0)
841 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100842 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000843 }
844 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
845 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100846 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000847 }
848
Nikhil Raj8599a412018-11-19 14:51:07 +0000849 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
850 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100851 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000852 }
853
854 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
855 {
856 return;
857 }
858
telsoa014fcda012018-03-09 14:13:49 +0000859 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
860 {
861 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100862 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000863 "has to match number of workloadInfo.m_InputTensorInfos. "
864 "Number of windows: " +
865 to_string(m_ViewOrigins.size()) +
866 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
867 }
868
telsoa01c577f2c2018-08-31 09:22:23 +0100869 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000870 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
871 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
872 {
telsoa01c577f2c2018-08-31 09:22:23 +0100873 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000874 ViewOrigin const& e = m_ViewOrigins[w];
875 if (e.m_Origin.size() != outputDims)
876 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000878 "have the same dimensionality as the output tensor. "
879 "Window origin (index: " +
880 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
881 " dimensions, the output "
882 "tensor has " +
883 to_string(outputDims) + " dimensions.");
884 }
telsoa01c577f2c2018-08-31 09:22:23 +0100885 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000886 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
887 {
888 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
889 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
890 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100891 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000892 "be smaller or equal than the size of the output in that coord.");
893 }
894 }
895 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100896
897 // Check the supported data types
898 std::vector<DataType> supportedTypes =
899 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000900 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100901 DataType::Float32,
902 DataType::Float16,
903 DataType::Boolean,
904 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100905 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000906 DataType::QAsymmU8,
907 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100908 };
909
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100910 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
911 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
914 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
915
916 const std::string inputName = "input_" + std::to_string(i);
917 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100918 }
telsoa014fcda012018-03-09 14:13:49 +0000919}
920
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100921void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
922{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100923 const std::string descriptorName{"StackQueueDescriptor"};
924
925 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100926
927 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
928 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100929 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100930 }
931
932 // All inputs must have the same shape, which is defined in parameters
933 const TensorShape& inputShape = m_Parameters.m_InputShape;
934 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
935 {
936 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
937 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100938 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100939 }
940 }
941
Matthew Jacksondba634f2019-08-15 15:14:18 +0100942 if (inputShape.GetNumDimensions() > 4)
943 {
944 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
945 }
946
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100947 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
948 // since the output tensor has an additional dimension.
949 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
950 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100951 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100952 "than the number of input dimensions.");
953 }
954
955 // Output shape must be as inferred from the input shape
956 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
957 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
958 {
959 if (outputShape[i] != inputShape[i])
960 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100962 "match shape inferred from input tensor.");
963 }
964 }
965
966 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
967 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100969 "match shape inferred from input tensor.");
970 }
971
972 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
973 {
974 if (outputShape[i] != inputShape[i-1])
975 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100976 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100977 "match shape inferred from input tensor.");
978 }
979 }
980
Matthew Jacksondba634f2019-08-15 15:14:18 +0100981 if (outputShape.GetNumDimensions() > 5)
982 {
983 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
984 }
985
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100986 // Check the supported data types
987 std::vector<DataType> supportedTypes =
988 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000989 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100990 DataType::Float32,
991 DataType::Float16,
992 DataType::Boolean,
993 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100994 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000995 DataType::QAsymmU8,
996 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100997 };
998
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100999 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001000
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001001 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001002 {
1003 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1004 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001005 descriptorName,
1006 "input_0",
1007 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001008 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001009
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001010 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1011 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001012 descriptorName,
1013 "input_0",
1014 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001015}
1016
Ryan OSheaec6c6802020-06-05 17:17:06 +01001017void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1018{
1019 const std::string descriptorName{"FillQueueDescriptor"};
1020
1021 ValidateNumInputs(workloadInfo, descriptorName, 1);
1022 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1023
1024 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1025 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1026
1027 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1028
1029 std::vector<DataType> supportedTypes =
1030 {
1031 DataType::BFloat16,
1032 DataType::Float32,
1033 DataType::Float16,
1034 DataType::Signed32
1035 };
1036
1037 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1038}
1039
telsoa014fcda012018-03-09 14:13:49 +00001040void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1041{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001042 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001043
Matthew Sloyan81beae32021-07-13 19:46:11 +01001044 uint32_t numInputs = 2;
1045 if (m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001046 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001047 numInputs = 3;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001048 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001049
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001050 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001051 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1052
1053 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1054 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1055
1056 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1057
1058 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001059 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001060 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001061 }
1062
Matthew Sloyan81beae32021-07-13 19:46:11 +01001063 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001065
1066 if (m_Parameters.m_BiasEnabled)
1067 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001068 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
telsoa01c577f2c2018-08-31 09:22:23 +01001069 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001070 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001071 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1072 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001073 }
1074
Francis Murtagh46c09d02019-05-28 08:15:28 +01001075 // Check the supported data types
1076 std::vector<DataType> supportedTypes =
1077 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001078 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001079 DataType::Float32,
1080 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001081 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001082 DataType::QAsymmU8,
1083 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001084 };
1085
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001087
1088 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1089 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1090 {
1091 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1092 {
1093 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1094 "for BFloat16 input.");
1095 }
1096 }
1097 else
1098 {
1099 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1100 }
telsoa014fcda012018-03-09 14:13:49 +00001101}
1102
telsoa014fcda012018-03-09 14:13:49 +00001103void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1104{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001105 const std::string descriptorName{"NormalizationQueueDescriptor"};
1106
1107 ValidateNumInputs(workloadInfo, descriptorName, 1);
1108 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1109
1110 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1111 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001112
1113 // Check the supported data types
1114 std::vector<DataType> supportedTypes =
1115 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001116 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001117 DataType::Float16,
1118 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001119 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001120 DataType::QAsymmU8,
1121 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001122 };
1123
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001124 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001125
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001127
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001129}
1130
1131void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1132{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001133 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001134
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001135 ValidateNumInputs(workloadInfo, descriptorName, 2);
1136 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1137
1138 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1139 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1140 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1141
1142 std::vector<DataType> supportedTypes =
1143 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001144 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001145 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001146 DataType::Float16,
1147 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001148 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001149 DataType::QSymmS16,
1150 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001151 };
1152
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001153 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1154 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1155 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1158 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001159
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001160 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1161 inputTensorInfo1,
1162 outputTensorInfo,
1163 descriptorName,
1164 "input_0",
1165 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001166}
1167
telsoa014fcda012018-03-09 14:13:49 +00001168void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1169{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001170 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 ValidateNumInputs(workloadInfo, descriptorName, 2);
1173 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1174
1175 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1176 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1177 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1178
1179 std::vector<DataType> supportedTypes =
1180 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001181 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001182 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001183 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001184 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001185 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001186 DataType::QSymmS16,
1187 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001188 };
1189
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001190 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1191 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1192 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1195 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001197 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1198 inputTensorInfo1,
1199 outputTensorInfo,
1200 descriptorName,
1201 "input_0",
1202 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001203}
1204
1205void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1206{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001207 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001208
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001209 ValidateNumInputs(workloadInfo, descriptorName, 1);
1210 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1211
1212 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1213 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001214
1215 std::vector<DataType> supportedTypes =
1216 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001217 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001218 DataType::Float16,
1219 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001220 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001221 DataType::QAsymmU8,
1222 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001223 };
1224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1226 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001227
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001228 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001230
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001231 ValidatePointer(m_Mean, descriptorName, "mean");
1232 ValidatePointer(m_Variance, descriptorName, "variance");
1233 ValidatePointer(m_Beta, descriptorName, "beta");
1234 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001235
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001236 const TensorInfo& mean = m_Mean->GetTensorInfo();
1237 const TensorInfo& variance = m_Variance->GetTensorInfo();
1238 const TensorInfo& beta = m_Beta->GetTensorInfo();
1239 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001240
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001241 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1242 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1243 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1244 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001245
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001246 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1247 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1248 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001249}
1250
1251void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1252{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001254
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255 ValidateNumInputs(workloadInfo, descriptorName, 1);
1256 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001257
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001258 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1259 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001260
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001261 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1262 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001263
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001264 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001265
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001266 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1267 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001268
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001269 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001270
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001271 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001272 if (m_Parameters.m_BiasEnabled)
1273 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001275
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001276 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1277 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001278
1279 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1280 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001281 }
1282
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001283 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1284 {
1285 throw InvalidArgumentException(
1286 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1287 "cannot be either negative or 0.",
1288 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1289 }
1290
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001291 ValidatePerAxisQuantization(inputTensorInfo,
1292 outputTensorInfo,
1293 weightTensorInfo,
1294 optionalBiasTensorInfo,
1295 descriptorName);
1296
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297 std::vector<DataType> supportedTypes =
1298 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001299 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001300 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001301 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001302 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001303 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001304 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001305 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001306 };
1307
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001308 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001309
1310 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1311 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1312 {
1313 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1314 {
1315 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1316 "for BFloat16 input.");
1317 }
1318 }
1319 else
1320 {
1321 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1322 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001323}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001325void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1326{
1327 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1328
1329 ValidateNumInputs(workloadInfo, descriptorName, 1);
1330 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1331
1332 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1333 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1334
1335 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1336 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1337
1338 ValidatePointer(m_Weight, descriptorName, "weight");
1339
1340 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1341 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1342
1343 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1344 {
1345 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001346 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1347 "cannot be smaller than 1.",
1348 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001349 }
1350
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001351 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1352 {
1353 throw InvalidArgumentException(
1354 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1355 "cannot be either negative or 0.",
1356 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1357 }
1358
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001359 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1360
Jan Eilers53ef7952021-06-02 12:01:25 +01001361 // Expected weight shape: [ 1, H, W, I*M ] - This shape does NOT depend on the data layout
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001362 // inputChannels * channelMultiplier should be equal to outputChannels.
Jan Eilers53ef7952021-06-02 12:01:25 +01001363 const unsigned int numWeightOutputChannels = weightTensorInfo.GetShape()[3]; // I*M=Cout
1364 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1365 if (numWeightOutputChannels != numOutputChannels)
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001366 {
James Ward47fce872020-09-10 11:57:28 +01001367 throw InvalidArgumentException(fmt::format(
Jan Eilers53ef7952021-06-02 12:01:25 +01001368 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1369 "But 4th dimension is not equal to Cout. Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1370 descriptorName,
1371 numOutputChannels,
1372 weightTensorInfo.GetShape()[0],
1373 weightTensorInfo.GetShape()[1],
1374 weightTensorInfo.GetShape()[2],
1375 weightTensorInfo.GetShape()[3]));
1376 }
1377 if (weightTensorInfo.GetShape()[0] != 1)
1378 {
1379 throw InvalidArgumentException(fmt::format(
1380 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1381 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1382 descriptorName,
1383 weightTensorInfo.GetShape()[0],
1384 weightTensorInfo.GetShape()[1],
1385 weightTensorInfo.GetShape()[2],
1386 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001387 }
1388
Teresa Charlind8df0262019-11-11 12:28:15 +00001389 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001390
Teresa Charlind8df0262019-11-11 12:28:15 +00001391 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001392 if (m_Parameters.m_BiasEnabled)
1393 {
1394 ValidatePointer(m_Bias, descriptorName, "bias");
1395
Teresa Charlind8df0262019-11-11 12:28:15 +00001396 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1397 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001398
1399 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1400 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1401 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001402 ValidatePerAxisQuantization(inputTensorInfo,
1403 outputTensorInfo,
1404 weightTensorInfo,
1405 optionalBiasTensorInfo,
1406 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001407
1408 std::vector<DataType> supportedTypes =
1409 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001410 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001411 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001412 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001413 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001414 DataType::QAsymmU8,
1415 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001416 };
1417
1418 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1419 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001420}
1421
1422void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1423{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001424 const std::string descriptorName{"PermuteQueueDescriptor"};
1425
1426 ValidateNumInputs(workloadInfo, descriptorName, 1);
1427 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001428
1429 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1430
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001431 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1432 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001433
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001434 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1435 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001436
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001438 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001439 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001440 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001441 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1442 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1443 "must match dst dimension " + to_string(mapping[i]) +
1444 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001445 }
1446 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001447
1448 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001449}
1450
1451void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1452{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001453 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001454
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455 ValidateNumInputs(workloadInfo, descriptorName, 1);
1456 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1457
1458 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1459 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1460
1461 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1462 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001463
1464 std::vector<DataType> supportedTypes =
1465 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001466 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001467 DataType::Float32,
1468 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001469 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001470 DataType::QAsymmU8,
1471 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001472 };
1473
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001474 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1475 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001476}
1477
1478void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1479{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001480 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001481
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001482 ValidateNumInputs(workloadInfo, descriptorName, 1);
1483 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1484
1485 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1486 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1487
1488 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1489 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001490
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001491 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001492 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001493 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001494 DataType::Float16,
1495 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001496 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001497 DataType::QAsymmU8,
1498 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001499 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001500
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001501 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1502 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001503
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504 // ResizeBilinear only changes width and height: batch and channel count must match.
1505 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1506 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001507 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001508 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001509 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001510 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1511 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001512 }
1513
Teresa Charlin970f43b2019-07-01 13:51:07 +01001514 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001515 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1516 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001517 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001518 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001519 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001520 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1521 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001522 }
1523}
1524
1525void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1526{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001527 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001528
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001529 ValidateNumInputs(workloadInfo, descriptorName, 1);
1530 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1531
1532 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1533 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1534
1535 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1536 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001537
1538 std::vector<DataType> supportedTypes =
1539 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001540 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001541 DataType::Float16,
1542 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001543 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001544 DataType::QAsymmU8,
1545 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001546 };
1547
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001548 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1549 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001550
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001551 // Resize only changes width and height: batch and channel count must match.
1552 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1553 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001554 if (inputBatchSize != outputBatchSize)
1555 {
1556 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001557 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1558 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001559 }
1560
1561 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001562 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1563 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001564 if (inputChannelCount != outputChannelCount)
1565 {
1566 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001567 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1568 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001569 }
1570}
1571
1572void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1573{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001574 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001575
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001576 ValidateNumInputs(workloadInfo, descriptorName, 1);
1577 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1578
1579 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1580 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1581
1582 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1583 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1584
1585 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1586
telsoa014fcda012018-03-09 14:13:49 +00001587 if (m_Parameters.m_Min > m_Parameters.m_Max)
1588 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001589 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001590 }
telsoa014fcda012018-03-09 14:13:49 +00001591}
1592
Kevin Mayce5045a2019-10-02 14:07:47 +01001593void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1594{
1595 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1596
1597 ValidateNumInputs(workloadInfo, descriptorName, 1);
1598 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1599
1600 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1601 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1602
1603 if (inputTensorInfo.GetNumDimensions() > 4)
1604 {
1605 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1606 }
1607
1608 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1609
1610 // Check the supported data types
1611 std::vector<DataType> supportedTypes =
1612 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001613 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001614 DataType::Float32,
1615 DataType::Float16
1616 };
1617
1618 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001619 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001620}
1621
telsoa014fcda012018-03-09 14:13:49 +00001622void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1623{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001624 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001625
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001626 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001627 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1628
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001629 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1630 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1631
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001632 if (inputTensorInfo.GetNumDimensions() > 4)
1633 {
1634 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1635 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001636
1637 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001638
1639 // Check the supported data types
1640 std::vector<DataType> supportedTypes =
1641 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001642 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001643 DataType::Float32,
1644 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001645 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001646 DataType::QAsymmU8,
1647 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001648 };
1649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001650 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001651 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1652}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001653
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001654void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1655{
1656 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1657
1658 ValidateNumInputs(workloadInfo, descriptorName, 1);
1659 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1660
1661 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1662 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1663
1664 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1665
1666 std::vector<DataType> supportedTypes =
1667 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001668 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001669 DataType::Float32,
1670 DataType::Float16,
1671 };
1672
1673 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001674 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001675}
1676
1677void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1678{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001679 const std::string descriptorName{"ConstantQueueDescriptor"};
1680
1681 ValidateNumInputs(workloadInfo, descriptorName, 0);
1682 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001683
1684 if (!m_LayerOutput)
1685 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001686 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001687 }
1688
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001689 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1690 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001691
1692 // Check the supported data types
1693 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001694 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001695 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001696 DataType::Float32,
1697 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001698 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001699 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001700 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001701 DataType::QSymmS16,
1702 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001703 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001704
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001705 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001706}
1707
1708void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1709{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001710 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001711
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001712 ValidateNumInputs(workloadInfo, descriptorName, 1);
1713 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1714
1715 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1716 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1717
1718 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001719
1720 // Check the supported data types
1721 std::vector<DataType> supportedTypes =
1722 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001723 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001724 DataType::Float32,
1725 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001726 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001727 DataType::QAsymmU8,
1728 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001729 DataType::Signed32,
1730 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001731 };
1732
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001733 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1734 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001735}
1736
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001737void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1738{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001740
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001741 ValidateNumInputs(workloadInfo, descriptorName, 1);
1742 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1743
1744 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1745 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1746
1747 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1748 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001749
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001750 if (m_Parameters.m_BlockShape.size() != 2)
1751 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001752 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001753 }
1754
1755 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1756 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001757 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1758 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001759 }
1760
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001761 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001762
1763 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001764 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001765
Matthew Bentham8800c002018-11-19 13:19:28 +00001766 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001767
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001768 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1769 widthPad.first + widthPad.second;
1770 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1771 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001772
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001773 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1774 inputShape[dimensionIndices.GetChannelsIndex()];
1775 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001776
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001777 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001778 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001780 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001781 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001782 }
1783
1784 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001785 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001786 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1787 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001788 }
nikraj01120522a2019-05-31 11:33:07 +01001789
1790 std::vector<DataType> supportedTypes =
1791 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001792 DataType::BFloat16,
1793 DataType::Float16,
1794 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001795 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001796 DataType::QAsymmU8,
1797 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001798 };
1799
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1801 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001802}
1803
Keith Davisa57eccb2019-06-14 17:33:22 +01001804void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1805{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001806 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001807
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001808 ValidateNumInputs(workloadInfo, descriptorName, 1);
1809 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001810
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001811 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1812 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1813
1814 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1815 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001816
1817 std::vector<DataType> supportedTypes =
1818 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001819 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001820 DataType::Float32,
1821 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001822 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001823 DataType::QAsymmU8,
1824 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001825 };
1826
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001827 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1828 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001829
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001830 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1831
1832 if (m_Parameters.m_BlockSize == 0)
1833 {
1834 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1835 }
1836
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001837 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1838 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1839 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1840 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001841
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001842 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001843 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001844 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001845 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1846 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001847 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001848
1849 const TensorShape& outputShape = outputTensorInfo.GetShape();
1850 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1851 {
1852 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1853 "must be divisible by the square of block size." );
1854 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001855}
1856
telsoa014fcda012018-03-09 14:13:49 +00001857void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1858{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001859 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001860
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001861 ValidateNumInputs(workloadInfo, descriptorName, 1);
1862 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1863
1864 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1865 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001866
1867 std::vector<DataType> supportedTypes =
1868 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001869 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001870 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001871 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001872 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001873 };
1874
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001875 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matthew Sloyan81beae32021-07-13 19:46:11 +01001876 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1877 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1878 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001879}
1880
telsoa01c577f2c2018-08-31 09:22:23 +01001881void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1882{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001883 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1884
1885 const std::string descriptorName{"LstmQueueDescriptor"};
1886
1887 // check dimensions of all inputs and outputs
1888 if (workloadInfo.m_InputTensorInfos.size() != 3)
1889 {
1890 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1891 }
1892 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1893 {
1894 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1895 }
1896
1897 std::vector<DataType> supportedTypes =
1898 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001899 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001900 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001901 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001902 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001903 };
1904
Jan Eilers38e05bd2019-06-26 13:10:09 +01001905 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001906 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1907
Jan Eilers38e05bd2019-06-26 13:10:09 +01001908 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001909 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001910 {
1911 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1912 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001913 descriptorName,
1914 "input_0",
1915 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001916 }
1917 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001918 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001919 {
1920 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1921 workloadInfo.m_OutputTensorInfos[i],
1922 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001923 "input_0",
1924 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001925 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001926
janeil0117d8d852019-11-15 15:00:16 +00001927 // Making sure clipping parameters have valid values.
1928 // == 0 means no clipping
1929 // > 0 means clipping
1930 if (m_Parameters.m_ClippingThresCell < 0.0f)
1931 {
1932 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1933 }
1934 if (m_Parameters.m_ClippingThresProj < 0.0f)
1935 {
1936 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1937 }
1938
Jan Eilers38e05bd2019-06-26 13:10:09 +01001939 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001940 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1941 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1942 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1943 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1944 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1945 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1946
Jan Eilers38e05bd2019-06-26 13:10:09 +01001947 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001948 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1949 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001950 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001951 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1952 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001953 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001954 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1955 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001956 // scratchBufferTensor
1957 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001958 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1959 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001960 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001961 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1962 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001963 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001964 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1965 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001966 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001967 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1968 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001969
Jan Eilers38e05bd2019-06-26 13:10:09 +01001970 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1971 if ( m_InputToInputWeights )
1972 {
1973 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1974 (n_cell * n_input), "InputLayerNormWeights");
1975 }
1976
1977 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1978 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1979 (n_cell * n_input), "InputToForgetWeights");
1980
1981 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1982 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1983 (n_cell * n_input), "InputToCellWeights");
1984
1985 if ( m_RecurrentToInputWeights )
1986 {
1987 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1988 (n_cell * n_output), "RecurrentToInputWeights");
1989 }
1990
1991 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1992 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1993 (n_cell * n_output), "RecurrentToForgetWeights");
1994
1995 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1996 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1997 (n_cell * n_output), "RecurrentToCellWeights");
1998
1999 // Make sure the input-gate's parameters are either both present (regular
2000 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2001 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2002 !m_Parameters.m_CifgEnabled) ||
2003 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2004 m_Parameters.m_CifgEnabled));
2005 if (!cifg_weights_all_or_none)
2006 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2008 "RecurrentToInputWeights must either both be present (regular LSTM) "
2009 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2010 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002011 }
2012
2013 if ( m_CellToInputWeights )
2014 {
2015 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2016 n_cell, "CellToInputWeights");
2017 }
2018 if ( m_CellToForgetWeights )
2019 {
2020 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2021 n_cell, "CellToForgetWeights");
2022 }
2023 if ( m_CellToOutputWeights )
2024 {
2025 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2026 n_cell, "CellToOutputWeights");
2027 }
2028
2029 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2030 bool peephole_weights_all_or_none =
2031 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2032 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2033 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2034 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2035 if (!peephole_weights_all_or_none)
2036 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002037 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002038 }
2039
2040 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2041 if (m_Parameters.m_CifgEnabled)
2042 {
2043 if (m_InputGateBias)
2044 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002045 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002046 }
2047 }
2048 else
2049 {
2050 if (!m_InputGateBias)
2051 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002052 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2053 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002054 }
2055 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2056 n_cell, "InputGateBias");
2057 }
2058
2059 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2060 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2061
2062 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2063 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2064
2065 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2066 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2067
2068 if (m_ProjectionWeights)
2069 {
2070 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2071 (n_cell * n_output), "ProjectionWeights");
2072 }
2073 if (m_ProjectionBias)
2074 {
2075 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2076 }
2077
2078 // Making sure the projection tensors are consistent:
2079 // 1) If projection weight is not present, then projection bias should not be
2080 // present.
2081 // 2) If projection weight is present, then projection bias is optional.
2082 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2083 !m_Parameters.m_ProjectionEnabled)
2084 || (m_ProjectionWeights && !m_ProjectionBias &&
2085 m_Parameters.m_ProjectionEnabled)
2086 || (m_ProjectionWeights && m_ProjectionBias &&
2087 m_Parameters.m_ProjectionEnabled));
2088 if (!projecton_tensors_consistent)
2089 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002091 }
2092
2093 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2094 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2095 // either all have values or none of them have values. Layer normalization is used when the values of all the
2096 // layer normalization weights are present
2097 if (m_InputLayerNormWeights)
2098 {
2099 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2100 }
2101 if (m_ForgetLayerNormWeights)
2102 {
2103 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2104 }
2105 if (m_CellLayerNormWeights)
2106 {
2107 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2108 }
2109 if (m_OutputLayerNormWeights)
2110 {
2111 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2112 }
2113
Jan Eilers38e05bd2019-06-26 13:10:09 +01002114 if (m_Parameters.m_LayerNormEnabled)
2115 {
2116 if (!m_Parameters.m_CifgEnabled)
2117 {
2118 if (!m_InputLayerNormWeights)
2119 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002120 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2121 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002122 }
2123 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2124 1, n_cell, "InputLayerNormWeights");
2125 }
2126 else if (m_InputLayerNormWeights)
2127 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2129 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002130 }
2131
2132 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2133 "ForgetLayerNormWeights");
2134 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2135
2136 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2137 "OutputLayerNormWeights");
2138 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2139
2140 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2141 "CellLayerNormWeights");
2142 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2143 }
2144 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2145 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2147 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002148 }
telsoa01c577f2c2018-08-31 09:22:23 +01002149}
2150
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002151void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2152{
2153 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2154
2155 ValidateNumInputs(workloadInfo, descriptorName, 1);
2156 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2157
2158 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2159 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2160
2161 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2162 {
2163 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2164 }
2165
2166 if (outputTensorInfo.GetDataType() != DataType::Float32)
2167 {
2168 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2169 }
2170
2171 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2172}
2173
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002174void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2175{
2176 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2177
2178 ValidateNumInputs(workloadInfo, descriptorName, 1);
2179 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2180
2181 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2182 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2183
2184 if (inputTensorInfo.GetDataType() != DataType::Float32)
2185 {
2186 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2187 }
2188
2189 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2190 {
2191 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2192 }
2193
2194 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2195}
2196
telsoa01c577f2c2018-08-31 09:22:23 +01002197void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2198{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002199 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002200
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002201 ValidateNumInputs(workloadInfo, descriptorName, 1);
2202 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2203
2204 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2205 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2206
2207 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002208 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002209 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002210 }
2211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002212 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002213 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002214 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002215 }
2216
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002217 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002218}
2219
2220void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2221{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002222 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002223
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002224 ValidateNumInputs(workloadInfo, descriptorName, 1);
2225 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2226
2227 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2228 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2229
2230 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002231 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002232 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002233 }
2234
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002235 if (outputTensorInfo.GetDataType() != DataType::Float32)
2236 {
2237 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2238 }
2239
2240 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002241}
2242
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002243void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2244{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002245 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002246
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002247 ValidateNumInputs(workloadInfo, descriptorName, 2);
2248 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2249
2250 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2251 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2252 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2253
2254 std::vector<DataType> supportedTypes =
2255 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002256 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002257 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002258 DataType::Float32,
2259 DataType::QAsymmS8,
2260 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002261 DataType::QSymmS16,
2262 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002263 };
2264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002265 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2266 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2267 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002268
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002269 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2270 inputTensorInfo1,
2271 outputTensorInfo,
2272 descriptorName,
2273 "input_0",
2274 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002275}
2276
David Beckc2044fe2018-09-05 15:00:38 +01002277void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2278{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002279 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002280
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002281 ValidateNumInputs(workloadInfo, descriptorName, 2);
2282 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2283
2284 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2285 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2286 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2287
2288 std::vector<DataType> supportedTypes =
2289 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002290 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002291 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002292 DataType::Float32,
2293 DataType::QAsymmS8,
2294 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002295 DataType::QSymmS16,
2296 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002297 };
2298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002299 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2300 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2301 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002303 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2304 inputTensorInfo1,
2305 outputTensorInfo,
2306 descriptorName,
2307 "input_0",
2308 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002309}
2310
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002311void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2312{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002313 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002314
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 ValidateNumInputs(workloadInfo, descriptorName, 2);
2316 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2317
2318 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2319 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2320 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2321
2322 std::vector<DataType> supportedTypes =
2323 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002324 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002325 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002326 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002327 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002328 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002329 DataType::QSymmS16,
2330 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002331 };
2332
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2334 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2335 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002336
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002337 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2338 inputTensorInfo1,
2339 outputTensorInfo,
2340 descriptorName,
2341 "input_0",
2342 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002343}
2344
narpra01a6bf9122018-09-10 09:50:09 +01002345void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2346{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002347 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002348
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002349 ValidateNumInputs(workloadInfo, descriptorName, 1);
2350 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2351
2352 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2353 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002354
2355 std::vector<DataType> supportedTypes =
2356 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002357 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002358 DataType::Float32,
2359 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002360 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002361 DataType::QAsymmU8,
2362 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002363 };
narpra01eb061912018-09-10 17:35:27 +01002364
James Conroy4d1ff582019-06-10 17:06:39 +01002365 // First check if input tensor data type is supported, then
2366 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002367 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2368 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002369
narpra0132b90462018-09-13 11:07:48 +01002370 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002371 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002372 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002373 }
narpra0132b90462018-09-13 11:07:48 +01002374 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002375 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002376 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002377 }
2378 else
2379 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002380 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002381 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002382 ValidateTensorNumDimensions(outputTensorInfo,
2383 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002384 outputDim > 0 ? outputDim : 1,
2385 "output");
2386 }
narpra01a6bf9122018-09-10 09:50:09 +01002387}
2388
jimfly012c9322a2018-09-19 10:59:49 +01002389void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2390{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002391 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 ValidateNumInputs(workloadInfo, descriptorName, 1);
2394 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2395
2396 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2397 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002398
jimfly012c9322a2018-09-19 10:59:49 +01002399 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002400 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2401
jimfly012c9322a2018-09-19 10:59:49 +01002402 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002403 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2404 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2405 "as there are dimensions in the input tensor that is " +
2406 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2407 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002408 }
2409}
2410
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002411void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2412{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002413 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002414
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002415 ValidateNumInputs(workloadInfo, descriptorName, 1);
2416 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002417
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002418 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2419 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2420
Sadik Armagan2208b602019-07-31 16:36:27 +01002421 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002422 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002423 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002424 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002425 DataType::Float16,
2426 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002427 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002428 DataType::QAsymmU8,
2429 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002430 };
2431
2432 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002433
Keith Davis0c2eeac2020-02-11 16:51:50 +00002434 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002435 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002436 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002437 }
2438}
2439
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002440void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2441{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002442 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002443
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002444 ValidateNumInputs(workloadInfo, descriptorName, 1);
2445 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002446
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002447 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2448 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002449
2450 std::vector<DataType> supportedTypes =
2451 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002452 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002453 DataType::Float32,
2454 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002455 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002456 DataType::QAsymmU8,
2457 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002458 };
2459
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002460 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2461 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002462}
2463
Conor Kennedy430b5d82018-11-14 15:28:28 +00002464void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2465{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002466 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002467
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002468 ValidateNumInputs(workloadInfo, descriptorName, 1);
2469 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2470
2471 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2472 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002473
2474 std::vector<DataType> supportedTypes =
2475 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002476 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002477 DataType::Float16,
2478 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002479 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002480 DataType::QAsymmU8,
2481 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002482 };
2483
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002484 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2485 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002486
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002487 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002488
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002489 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002490 if (rank > 4)
2491 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002492 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002493 }
2494
Conor Kennedy430b5d82018-11-14 15:28:28 +00002495 // Begin, End & Stride length must be of rank(input0)
2496 if (m_Parameters.m_Begin.size() != rank)
2497 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002499 }
2500
2501 if (m_Parameters.m_End.size() != rank)
2502 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002503 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002504 }
2505
2506 if (m_Parameters.m_Stride.size() != rank)
2507 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002508 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002509 }
2510
2511 // Stride entries must be non-zero
2512 for (auto& stride : m_Parameters.m_Stride)
2513 {
2514 if (stride == 0)
2515 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002517 }
2518 }
2519}
2520
kevmay0190539692018-11-29 08:40:19 +00002521void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2522{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002523 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002525 ValidateNumInputs(workloadInfo, descriptorName, 2);
2526 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2527
2528 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2529 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2530 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2531
2532 std::vector<DataType> supportedTypes =
2533 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002534 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002535 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002536 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002537 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002538 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002539 DataType::QSymmS16,
2540 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002541 };
2542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002543 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2544 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2545 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002546
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002547 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2548 inputTensorInfo1,
2549 outputTensorInfo,
2550 descriptorName,
2551 "input_0",
2552 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002553}
2554
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002555void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2556{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002557 const std::string descriptorName{"DebugQueueDescriptor"};
2558
2559 ValidateNumInputs(workloadInfo, descriptorName, 1);
2560 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002561}
2562
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002563void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2564{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002565 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002566
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002567 ValidateNumInputs(workloadInfo, descriptorName, 2);
2568 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002569
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002570 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2571 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2572 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2573
2574 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2575 inputTensorInfo1,
2576 outputTensorInfo,
2577 descriptorName,
2578 "input_0",
2579 "input_1");
2580
2581 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002582 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002583 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002584 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002585}
2586
FrancisMurtagh878f0232018-12-19 10:56:15 +00002587void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2588{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002589 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002590
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002591 ValidateNumInputs(workloadInfo, descriptorName, 2);
2592 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002593
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002594 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2595 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2596 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2597
2598 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2599 inputTensorInfo1,
2600 outputTensorInfo,
2601 descriptorName,
2602 "input_0",
2603 "input_1");
2604
2605 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002606 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002607 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002608 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002609}
2610
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002611void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2612{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002613 const std::string descriptorName{"RsqrtQueueDescriptor"};
2614
2615 ValidateNumInputs(workloadInfo, descriptorName, 1);
2616 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2617
2618 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2619 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2620
2621 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002622
2623 std::vector<DataType> supportedTypes =
2624 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002625 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002626 DataType::Float16,
2627 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002628 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002629 DataType::QAsymmU8,
2630 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002631 };
2632
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002633 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2634 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002635}
2636
narpra01b89b05f2019-01-16 09:53:09 +00002637void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2638{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002639 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002640
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002641 ValidateNumInputs(workloadInfo, descriptorName, 2);
2642 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002643
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002644 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2645 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002646 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002647 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002648 }
2649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002650 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2651 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2652
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002653 std::vector<DataType> supportedTypes =
2654 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002655 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002656 DataType::Float16,
2657 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002658 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002659 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002660 DataType::QSymmS16,
2661 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002662 };
2663
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002664 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002665
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002666 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002668 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2669 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002670}
2671
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002672void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2673{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002674 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2675
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002676 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002677
2678 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2679 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002680 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002681 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2682 }
2683
2684 if (m_Anchors == nullptr)
2685 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002686 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002687 }
2688
2689 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002690 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2691 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2692
2693 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002694 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002695 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2696 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002697
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002698 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2699 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2700 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002701
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002702 const std::vector<DataType> supportedInputTypes =
2703 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002704 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002705 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002706 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002707 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002708 DataType::QAsymmU8,
2709 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002710 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002711
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002712 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2713 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2714 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2715
2716 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2717 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2718 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2719 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2720
2721 // NOTE: Output is always Float32 regardless of input type
2722 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2723 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2724 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2725 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002726
2727 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2728 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002729 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002730 "must be positive and less than or equal to 1.");
2731 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002732
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002733 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2734 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002735 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002736 "should be equal to number of classes + 1.");
2737 }
2738}
2739
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002740void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2741{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002742 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002743
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002744 ValidateNumInputs(workloadInfo, descriptorName, 1);
2745 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2746
2747 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2748 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2749
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002750 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002751 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002752 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002753 }
2754
Sadik Armagan2208b602019-07-31 16:36:27 +01002755 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002756 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002757 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002758 DataType::Float32,
2759 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002760 };
2761
2762 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002763}
2764
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002765void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2766{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002767 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002768
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002769 ValidateNumInputs(workloadInfo, descriptorName, 2);
2770 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002771
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002772 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2773 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2774 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002775
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002776 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2777 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2778
2779 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2780 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002781}
2782
Keith Davis3ae3f972021-05-21 16:33:48 +01002783void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2784{
2785 const std::string& descriptorName{"ShapeQueueDescriptor"};
2786
2787 ValidateNumInputs(workloadInfo, descriptorName, 1);
2788 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2789
2790 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2791 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2792
2793 std::vector<DataType> supportedTypes =
2794 {
2795 DataType::BFloat16,
2796 DataType::Float16,
2797 DataType::Float32,
2798 DataType::QAsymmS8,
2799 DataType::QAsymmU8,
2800 DataType::QAsymmS8,
2801 DataType::QSymmS8,
2802 DataType::QSymmS16,
2803 DataType::Signed32
2804 };
2805
2806 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2807 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2808}
2809
Sadik Armaganeff363d2019-04-05 15:25:46 +01002810void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2811{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002812 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002813
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002814 ValidateNumInputs(workloadInfo, descriptorName, 2);
2815 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2816
2817 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2818 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2819
2820 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2821 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2822
2823 std::vector<DataType> supportedTypes =
2824 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002825 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002826 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002827 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002828 DataType::QAsymmU8,
2829 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002830 };
2831
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002832 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2833 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002834
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002835 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2836 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002837
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002838 ValidateTensorShapesMatch(inputTensorInfo0,
2839 outputTensorInfo0,
2840 descriptorName,
2841 "input_0",
2842 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002844 ValidateTensorShapesMatch(inputTensorInfo0,
2845 outputTensorInfo1,
2846 descriptorName,
2847 "input_0",
2848 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002849}
2850
Derek Lamberti901ea112019-12-10 22:07:09 +00002851void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002852{
2853 // This is internally generated so it should not need validation.
2854}
2855
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002856void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2857{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002858 const std::string& descriptorName{"PreluQueueDescriptor"};
2859
2860 ValidateNumInputs(workloadInfo, descriptorName, 2);
2861 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2862
2863 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2864 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2865 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002866
2867 std::vector<DataType> supportedTypes
2868 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002869 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002870 DataType::Float16,
2871 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002872 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002873 DataType::QAsymmU8,
2874 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002875 };
2876
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002877 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2878 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002879
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002880 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002882 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2883 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002884
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002885 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2886 alphaTensorInfo,
2887 outputTensorInfo,
2888 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002889 "input",
2890 "alpha");
2891}
2892
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002893void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2894{
2895 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2896
2897 ValidateNumInputs(workloadInfo, descriptorName, 1);
2898 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2899
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002900 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2901 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2902
2903 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2904 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002905
2906 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002907
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002908 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2909 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002910
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002911 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2912
2913 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002914 if (m_Parameters.m_BiasEnabled)
2915 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002916 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002917
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002918 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2919 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002920
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002921 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002922 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002923 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002924
2925 ValidatePerAxisQuantization(inputTensorInfo,
2926 outputTensorInfo,
2927 weightTensorInfo,
2928 optionalBiasTensorInfo,
2929 descriptorName);
2930
2931 std::vector<DataType> supportedTypes =
2932 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002933 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002934 DataType::Float32,
2935 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002936 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002937 DataType::QAsymmU8,
2938 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002939 };
2940
2941 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2942 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002943}
2944
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002945void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2946{
2947 const std::string descriptorName{"TransposeQueueDescriptor"};
2948
2949 ValidateNumInputs(workloadInfo, descriptorName, 1);
2950 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2951
2952 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2953
2954 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2955 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2956
2957 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2958 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2959
2960 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2961 {
2962 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2963 {
2964 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2965 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2966 "must match dst dimension " + to_string(i) +
2967 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2968 }
2969 }
2970
2971 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2972}
2973
James Conroy4f1f8992020-04-29 20:01:10 +01002974void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2975{
2976 const std::string descriptorName{"QLstmQueueDescriptor"};
2977
2978 // Validate number of inputs/outputs
2979 ValidateNumInputs(workloadInfo, descriptorName, 3);
2980 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2981
2982 // Input/output tensor info
2983 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2984 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2985 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2986
2987 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2988 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2989 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2990
2991 // Supported types for various tensors in QLSTM
2992 std::vector<DataType> inputOutputSupportedTypes =
2993 {
2994 DataType::QAsymmS8
2995 };
2996
2997 std::vector<DataType> cellStateSupportedTypes =
2998 {
2999 DataType::QSymmS16
3000 };
3001
3002 std::vector<DataType> weightsSupportedTypes =
3003 {
3004 DataType::QSymmS8
3005 };
3006
3007 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3008 {
3009 DataType::QSymmS16
3010 };
3011
3012 std::vector<DataType> biasSupportedTypes =
3013 {
3014 DataType::Signed32
3015 };
3016
3017 // Validate types of input/output tensors
3018 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3019 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3020 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3021
3022 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3023 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3024 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3025
3026 // Validate matching types of input/output tensors
3027 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3028 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3029 "outputStateIn", "outputStateOut");
3030 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3031
3032 // Infer number of batches, number of units, input size and output size from tensor dimensions
3033 const uint32_t numBatches = inputInfo.GetShape()[0];
3034 const uint32_t inputSize = inputInfo.GetShape()[1];
3035 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3036 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3037
3038 // Validate number of dimensions and number of elements for input/output tensors
3039 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3040 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3041 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3042
3043 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3044 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3045 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3046
3047 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3048 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3049 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3050 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3051
3052 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3053 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3054 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3055
3056 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3057 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3058 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3059
3060 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3061 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3062 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3063 " RecurrentToForgetWeights");
3064
3065 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3066 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3067 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3068
3069 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3070 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3071 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3072
3073 // Validate data types for MANDATORY weights tensors (all should match each other)
3074 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3075
3076 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3077 "inputToForgetWeights", "inputToCellWeights");
3078 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3079 "inputToForgetWeights", "inputToOutputWeights");
3080
3081 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3082 "inputToForgetWeights", "recurrentToForgeteights");
3083 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3084 "inputToForgetWeights", "recurrentToCellWeights");
3085 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3086 "inputToForgetWeights", "recurrentToOutputWeights");
3087
3088 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3089 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3090 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3091 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3092
3093 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3094 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3095 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3096
3097 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3098 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3099 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3100
3101 // Validate data types for MANDATORY bias tensors
3102 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3103
3104 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3105 "forgetGateBias", "cellBias");
3106 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3107 "forgetGateBias", "outputGateBias");
3108
3109 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3110 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3111 !m_Parameters.m_CifgEnabled) ||
3112 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3113 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3114
3115 if (!allCifgParamsPresentOrNot)
3116 {
3117 throw InvalidArgumentException(descriptorName +
3118 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3119 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3120 "set appropriately.");
3121 }
3122
3123 if (!m_Parameters.m_CifgEnabled)
3124 {
3125 // Validate number of dimensions and number of elements
3126 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3127 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3128
3129 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3130 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3131 " RecurrentToInputWeights");
3132
3133 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3134 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3135
3136 // Validate data types
3137 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3138 "inputToForgetWeights", "inputToInputWeights");
3139 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3140 "inputToForgetWeights", "recurrentToInputWeights");
3141 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3142 "forgetGateBias", "inputGateBias");
3143 }
3144
3145 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3146 bool allPeepholeWeightsPresentOrNot =
3147 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3148 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3149 || (!m_CellToInputWeights && !m_CellToForgetWeights
3150 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3151
3152 if (!allPeepholeWeightsPresentOrNot)
3153 {
3154 throw InvalidArgumentException(descriptorName +
3155 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3156 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3157 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3158 "appropriately.");
3159 }
3160
3161 if (m_Parameters.m_PeepholeEnabled)
3162 {
3163 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3164 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3165 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3166
3167 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3168 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3169 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3170 "cellToForgetWeight", "cellToOutputWeights");
3171
3172 if (!m_Parameters.m_CifgEnabled)
3173 {
3174 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3175 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3176 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3177 "cellToForgetWeights", "cellToInputWeights");
3178 }
3179 }
3180
3181 // Validate OPTIONAL params: Layer Norm Weights
3182 bool allLayerNormWeightsPresentOrNot =
3183 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3184 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3185 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3186 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3187
3188 if (!allLayerNormWeightsPresentOrNot)
3189 {
3190 throw InvalidArgumentException(descriptorName +
3191 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3192 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3193 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3194 "only be present when Layer Norm is enabled and CIFG is disabled. "
3195 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3196 }
3197
3198 if (m_Parameters.m_LayerNormEnabled)
3199 {
3200 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3201 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3202 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3203
3204 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3205 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3206 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3207 "forgetLayerNormWeights", "cellLayerNormWeights");
3208
3209 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3210 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3211 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3212 "forgetLayerNormWeights", "outputLayerNormWeights");
3213
3214 if (!m_Parameters.m_CifgEnabled)
3215 {
3216 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3217 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3218 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3219 "forgetLayerNormWeights", "inputLayerNormWeights");
3220 }
3221 }
3222
3223 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3224 bool correctProjectionTensorsPresent =
3225 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3226 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3227 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3228
3229 if (!correctProjectionTensorsPresent)
3230 {
3231 throw InvalidArgumentException(descriptorName +
3232 ": If projection is enabled, ProjectionWeights should be present and "
3233 "ProjectionBias is optional. If projection is disabled, neither "
3234 "ProjectionWeights nor ProjectionBias should be present.");
3235 }
3236
3237 if (m_Parameters.m_ProjectionEnabled)
3238 {
3239 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3240 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3241 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3242
3243 if (m_ProjectionBias)
3244 {
3245 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003246 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003247 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3248 }
3249
3250 }
3251 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3252 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3253 throw InvalidArgumentException(descriptorName +
3254 ": If projection is disabled, output quantization info (scale, offset) "
3255 "should match HiddenStateScale and HiddenStateZeroPoint.");
3256 }
3257
3258}
3259
James Conroy9c3cae82019-08-01 16:01:48 +01003260void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3261{
3262 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3263
3264 // Validate number of inputs/outputs
3265 ValidateNumInputs(workloadInfo, descriptorName, 3);
3266 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3267
3268 // Input/output tensor infos
3269 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3270 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3271 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3272
3273 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3274 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3275
3276 std::vector<DataType> inputOutputSupportedTypes =
3277 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003278 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003279 };
3280
3281 std::vector<DataType> cellStateSupportedTypes =
3282 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003283 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003284 };
3285
3286 std::vector<DataType> weightsSupportedTypes =
3287 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003288 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003289 };
3290
3291 std::vector<DataType> biasSupportedTypes =
3292 {
3293 DataType::Signed32
3294 };
3295
3296 // Validate types of input/output tensors
3297 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3298 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3299 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3300
3301 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3302 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3303
3304 // Validate matching types of input/output tensors
3305 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3306 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3307 "outputStateIn", "outputStateOut");
3308 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3309
3310 // Validate matching quantization info for input/output tensors
3311 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3312 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3313 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003314
James Conroy9c3cae82019-08-01 16:01:48 +01003315 // Infer number of batches, input size and output size from tensor dimensions
3316 const uint32_t numBatches = inputInfo.GetShape()[0];
3317 const uint32_t inputSize = inputInfo.GetShape()[1];
3318 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3319
3320 // Validate number of dimensions and number of elements for input/output tensors
3321 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3322 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3323 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3324 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3325 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3326
3327 // Validate number of dimensions and number of elements for weights tensors
3328 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3329 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3330 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3331
3332 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3333 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3334 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3335
3336 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3337 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3338 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3339
3340 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3341 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3342 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3343
3344 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3345 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3346 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3347
3348 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3349 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3350 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3351 " RecurrentToForgetWeights");
3352
3353 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3354 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3355 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3356
3357 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3358 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3359 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3360
3361 // Validate data types for weights tensors (all should match each other)
3362 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3363
3364 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3365 "inputToInputWeights", "inputToForgetWeights");
3366 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3367 "inputToInputWeights", "inputToCellWeights");
3368 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3369 "inputToInputWeights", "inputToOutputWeights");
3370
3371 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3372 "inputToInputWeights", "recurrentToInputWeights");
3373 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3374 "inputToInputWeights", "recurrentToForgeteights");
3375 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3376 "inputToInputWeights", "recurrentToCellWeights");
3377 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3378 "inputToInputWeights", "recurrentToOutputWeights");
3379
3380 // Validate matching quantization info for weight tensors (all should match each other)
3381 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3382 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3383 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3384 descriptorName, "inputToInputWeights", "inputToCellWeights");
3385 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3386 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3387
3388 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3389 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3390 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3391 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3392 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3393 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3394 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3395 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3396
3397 // Validate number of dimensions and number of elements in bias tensors
3398 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3399 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3400 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3401
3402 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3403 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3404 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3405
3406 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3407 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3408 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3409
3410 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3411 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3412 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3413
3414 // Validate data types for bias tensors (all should match each other)
3415 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3416
3417 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3418 "inputGateBias", "forgetGateBias");
3419 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3420 "inputGateBias", "cellBias");
3421 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3422 "inputGateBias", "outputGateBias");
3423
3424 // Validate bias tensor quantization info
3425 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3426 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3427 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3428 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3429}
3430
Kevin May868eb142019-09-04 17:29:31 +01003431void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3432{
3433 const std::string descriptorName{"AbsQueueDescriptor"};
3434
3435 ValidateNumInputs(workloadInfo, descriptorName, 1);
3436 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3437
3438 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3439 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3440
3441 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3442
3443 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003444 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003445 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003446 DataType::Float16,
3447 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003448 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003449 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003450 DataType::QSymmS16,
3451 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003452 };
Kevin May868eb142019-09-04 17:29:31 +01003453
3454 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3455 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3456}
3457
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003458void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3459{
3460 const std::string descriptorName{"SliceQueueDescriptor"};
3461
3462 ValidateNumInputs(workloadInfo, descriptorName, 1);
3463 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3464
3465 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3466 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3467
3468 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3469
3470 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3471 if (rank > 4)
3472 {
3473 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3474 }
3475
3476 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3477
3478 // Check if m_Begin and m_Size have the expected length
3479 if (m_Parameters.m_Begin.size() != rank)
3480 {
3481 throw InvalidArgumentException(descriptorName +
3482 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3483 }
3484 if (m_Parameters.m_Size.size() != rank)
3485 {
3486 throw InvalidArgumentException(descriptorName +
3487 ": Length of size descriptor must equal rank " + std::to_string(rank));
3488 }
3489
3490 // Check if the shape of the output tensor matches m_Size
3491 const TensorShape& outputShape = outputTensorInfo.GetShape();
3492 for (unsigned int i = 0u; i < rank; ++i)
3493 {
3494 if (m_Parameters.m_Size[i] != outputShape[i])
3495 {
3496 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3497 }
3498 }
3499
3500 // Check if the sum of begin offset and size in a given dimension
3501 // does not exceed the size of corresponding input
3502 const TensorShape& inputShape = inputTensorInfo.GetShape();
3503 for(unsigned int i = 0u; i < rank; ++i)
3504 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003505 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003506 {
3507 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3508 std::to_string(i) + " exceeds input size.");
3509 }
3510 }
3511}
3512
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003513void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3514{
3515 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3516
3517 ValidateNumInputs(workloadInfo, descriptorName, 1);
3518 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3519
3520 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3521 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3522
3523 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3524 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3525
3526 std::vector<DataType> supportedTypes =
3527 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003528 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003529 DataType::Float32,
3530 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003531 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003532 DataType::QAsymmU8,
3533 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003534 };
3535
3536 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3537 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3538
3539 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3540
3541 if (m_Parameters.m_BlockSize == 0)
3542 {
3543 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3544 }
3545
3546 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3547 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3548 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3549 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3550
3551 const TensorShape& outputShape = outputInfo.GetShape();
3552 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3553 {
3554 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3555 "must be divisible by block size.");
3556 }
3557
3558 const TensorShape& inputShape = inputInfo.GetShape();
3559 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3560 {
3561 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3562 "must be divisible by the square of block size." );
3563 }
3564}
3565
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003566void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3567{
3568 const std::string descriptorName{"ComparisonQueueDescriptor"};
3569
3570 ValidateNumInputs(workloadInfo, descriptorName, 2);
3571 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3572
3573 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3574 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3575 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3576
3577 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3578 inputTensorInfo1,
3579 outputTensorInfo,
3580 descriptorName,
3581 "input_0",
3582 "input_1");
3583
3584 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3585 {
3586 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3587 }
3588}
3589
josh minor4a3c6102020-01-06 16:40:46 -06003590void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3591{
3592 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3593
3594 ValidateNumInputs(workloadInfo, descriptorName, 1);
3595 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3596
3597 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3598 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3599
3600 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3601
3602 std::vector<DataType> supportedTypes =
3603 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003604 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003605 DataType::Float16,
3606 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003607 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003608 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003609 DataType::QSymmS16,
3610 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003611 };
3612
James Conroyaba90cd2020-11-06 16:28:18 +00003613 std::vector<DataType> logicalSupportedTypes =
3614 {
3615 DataType::Boolean
3616 };
3617
3618 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3619 {
3620 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3621 }
3622 else
3623 {
3624 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3625 }
3626
3627
josh minor4a3c6102020-01-06 16:40:46 -06003628 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3629}
3630
Finn Williams2605b232020-06-10 15:53:46 +01003631void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3632{
3633 const std::string descriptorName{"RankQueueDescriptor"};
3634
3635 ValidateNumInputs(workloadInfo, descriptorName, 1);
3636 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3637
3638 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3639 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3640
3641 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3642 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3643
3644 std::vector<DataType> supportedTypes =
3645 {
3646 DataType::BFloat16,
3647 DataType::Float16,
3648 DataType::Float32,
3649 DataType::QAsymmS8,
3650 DataType::QAsymmU8,
3651 DataType::QSymmS8,
3652 DataType::QSymmS16,
3653 DataType::Signed32
3654 };
3655
3656 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3657 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3658}
3659
James Conroyaba90cd2020-11-06 16:28:18 +00003660void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3661{
3662 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3663
3664 ValidateNumInputs(workloadInfo, descriptorName, 2);
3665 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3666
3667 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3668 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3669 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3670
3671 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3672 inputTensorInfo1,
3673 outputTensorInfo,
3674 descriptorName,
3675 "input_0",
3676 "input_1");
3677
3678 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3679 {
3680 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3681 }
3682
3683 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3684 {
3685 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3686 }
3687
3688 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3689 {
3690 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3691 }
3692}
3693
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003694void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3695{
3696 const std::string descriptorName{"ReduceQueueDescriptor"};
3697
3698 ValidateNumInputs(workloadInfo, descriptorName, 1);
3699 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3700
3701 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3703
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003704 std::vector<DataType> supportedTypes =
3705 {
3706 DataType::BFloat16,
3707 DataType::Float16,
3708 DataType::Float32,
3709 DataType::QAsymmS8,
3710 DataType::QAsymmU8,
3711 DataType::QSymmS16,
3712 DataType::Signed32
3713 };
3714
3715 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3716 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3717}
3718
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01003719void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3720{
3721 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3722
3723 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3724
3725 // check dimensions of all inputs and outputs
3726 if (workloadInfo.m_InputTensorInfos.size() != 3)
3727 {
3728 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3729 }
3730 if (workloadInfo.m_OutputTensorInfos.size() != 1)
3731 {
3732 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3733 }
3734
3735 std::vector<DataType> supportedTypes =
3736 {
3737 DataType::Float16,
3738 DataType::Float32,
3739 DataType::QAsymmS8
3740 };
3741
3742 // check for supported type of one input and match them with all the other input and output
3743 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3744
3745 // type matches all other inputs
3746 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
3747 {
3748 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
3749 workloadInfo.m_InputTensorInfos[i],
3750 descriptorName,
3751 "input_0",
3752 "input_" + std::to_string(i));
3753 }
3754 // type matches all other outputs
3755 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
3756 {
3757 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
3758 workloadInfo.m_OutputTensorInfos[i],
3759 "LstmQueueDescriptor",
3760 "input_0",
3761 "output_" + std::to_string(i));
3762 }
3763
3764 // Making sure clipping parameters have valid values.
3765 // == 0 means no clipping
3766 // > 0 means clipping
3767 if (m_Parameters.m_ClippingThresCell < 0.0f)
3768 {
3769 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3770 }
3771 if (m_Parameters.m_ClippingThresProj < 0.0f)
3772 {
3773 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3774 }
3775
3776 unsigned int batchIndx = 0;
3777 unsigned int inputIndx = 1;
3778 uint32_t timeStep = 1;
3779 unsigned int timeIndx = 1;
3780 inputIndx = 2;
3781 if (m_Parameters.m_TimeMajor)
3782 {
3783 batchIndx = 1;
3784 timeIndx = 0;
3785
3786 }
3787 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3788
3789 // Inferring batch size, number of outputs and number of cells from the inputs.
3790 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3791 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3792 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3793 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3794 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3795 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3796
3797 // input tensor
3798 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3799 descriptorName + " input_0");
3800 // outputStateInTensor
3801 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3802 descriptorName + " input_1");
3803 // outputStateInTensor
3804 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3805 descriptorName + " input_2");
3806
3807 // outputTensor
3808 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output),
3809 descriptorName + " output_0");
3810
3811 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3812 if ( m_InputToInputWeights )
3813 {
3814 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3815 (n_cell * n_input), "InputLayerNormWeights");
3816 }
3817
3818 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3819 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3820 (n_cell * n_input), "InputToForgetWeights");
3821
3822 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3823 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3824 (n_cell * n_input), "InputToCellWeights");
3825
3826 if ( m_RecurrentToInputWeights )
3827 {
3828 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3829 (n_cell * n_output), "RecurrentToInputWeights");
3830 }
3831
3832 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3833 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3834 (n_cell * n_output), "RecurrentToForgetWeights");
3835
3836 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3837 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3838 (n_cell * n_output), "RecurrentToCellWeights");
3839
3840 // Make sure the input-gate's parameters are either both present (regular
3841 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3842 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3843 !m_Parameters.m_CifgEnabled) ||
3844 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3845 m_Parameters.m_CifgEnabled));
3846 if (!cifg_weights_all_or_none)
3847 {
3848 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3849 "RecurrentToInputWeights must either both be present (regular LSTM) "
3850 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3851 "accordingly.");
3852 }
3853
3854 if ( m_CellToInputWeights )
3855 {
3856 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3857 n_cell, "CellToInputWeights");
3858 }
3859 if ( m_CellToForgetWeights )
3860 {
3861 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3862 n_cell, "CellToForgetWeights");
3863 }
3864 if ( m_CellToOutputWeights )
3865 {
3866 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3867 n_cell, "CellToOutputWeights");
3868 }
3869
3870 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3871 bool peephole_weights_all_or_none =
3872 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3873 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3874 || ( !m_CellToInputWeights && !m_CellToForgetWeights
3875 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3876 if (!peephole_weights_all_or_none)
3877 {
3878 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3879 }
3880
3881 // Make sure the input gate bias is present only when not a CIFG-LSTM.
3882 if (m_Parameters.m_CifgEnabled)
3883 {
3884 if (m_InputGateBias)
3885 {
3886 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3887 }
3888 }
3889 else
3890 {
3891 if (!m_InputGateBias)
3892 {
3893 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
3894 "must be present.");
3895 }
3896 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
3897 n_cell, "InputGateBias");
3898 }
3899
3900 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
3901 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
3902
3903 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
3904 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
3905
3906 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
3907 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
3908
3909 if (m_ProjectionWeights)
3910 {
3911 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
3912 (n_cell * n_output), "ProjectionWeights");
3913 }
3914 if (m_ProjectionBias)
3915 {
3916 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
3917 }
3918
3919 // Making sure the projection tensors are consistent:
3920 // 1) If projection weight is not present, then projection bias should not be
3921 // present.
3922 // 2) If projection weight is present, then projection bias is optional.
3923 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
3924 !m_Parameters.m_ProjectionEnabled)
3925 || (m_ProjectionWeights && !m_ProjectionBias &&
3926 m_Parameters.m_ProjectionEnabled)
3927 || (m_ProjectionWeights && m_ProjectionBias &&
3928 m_Parameters.m_ProjectionEnabled));
3929 if (!projecton_tensors_consistent)
3930 {
3931 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
3932 }
3933
3934 // The four layer normalization weights either all have values or none of them have values. Additionally, if
3935 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
3936 // either all have values or none of them have values. Layer normalization is used when the values of all the
3937 // layer normalization weights are present
3938 if (m_InputLayerNormWeights)
3939 {
3940 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
3941 }
3942 if (m_ForgetLayerNormWeights)
3943 {
3944 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
3945 }
3946 if (m_CellLayerNormWeights)
3947 {
3948 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
3949 }
3950 if (m_OutputLayerNormWeights)
3951 {
3952 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
3953 }
3954
3955 if (m_Parameters.m_LayerNormEnabled)
3956 {
3957 if (!m_Parameters.m_CifgEnabled)
3958 {
3959 if (!m_InputLayerNormWeights)
3960 {
3961 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
3962 "disabled but InputLayerNormWeights are not present");
3963 }
3964 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
3965 1, n_cell, "InputLayerNormWeights");
3966 }
3967 else if (m_InputLayerNormWeights)
3968 {
3969 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
3970 "enabled");
3971 }
3972
3973 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
3974 "ForgetLayerNormWeights");
3975 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
3976
3977 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
3978 "OutputLayerNormWeights");
3979 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
3980
3981 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
3982 "CellLayerNormWeights");
3983 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
3984 }
3985 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
3986 {
3987 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
3988 "normalisation weights are present.");
3989 }
3990}
3991
3992
mathad01df9a3222021-04-28 11:42:57 +01003993} // namespace armnn