blob: 44a6a17b372758f4aedcbd65134ee4377f71826c [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
James Conroy1f58f032021-04-27 17:13:27 +01006#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00007#include <backendsCommon/WorkloadData.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010011#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000012
telsoa014fcda012018-03-09 14:13:49 +000013#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000015#include <string>
16#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000017
James Ward47fce872020-09-10 11:57:28 +010018#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000032 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000033 case DataType::Float32:
34 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000035 case DataType::QAsymmS8:
36 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000037 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000038 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
40 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000041 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010042 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000043 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010044 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000045 return DataType::Float32;
46 }
47}
48
49namespace
50{
51
52//---------------------------------------------------------------
// Stream-based replacement for std::to_string, which the Android NDK does not support.
// Formats `value` exactly as `operator<<` on a default-configured stream would.
template <typename T>
std::string to_string(T value)
{
    std::stringstream stream;
    stream << value;
    return stream.str();
}
61
62//---------------------------------------------------------------
63void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
64{
65 if (!ptr)
66 {
67 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
68 paramName + " parameter must be set.");
69 }
70}
71
72//---------------------------------------------------------------
73void ValidateTensorShapesMatch(const TensorInfo& first,
74 const TensorInfo& second,
75 std::string const& descName,
76 std::string const& firstName,
77 std::string const& secondName)
78{
79 if (first.GetShape() != second.GetShape())
80 {
81 throw InvalidArgumentException(descName + ": "
82 + firstName + " & " + secondName + " must have identical shapes");
83 }
84}
85
86//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010087void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088{
Sadik Armaganeff363d2019-04-05 15:25:46 +010089 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090 {
91 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010092 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000093 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
94 }
95}
96
97//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010098void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101 {
102 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100103 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000104 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
105 }
106}
107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000110 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100111 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000112 std::string const& tensorName)
113{
114 if (tensor.GetNumDimensions() != numDimensions)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
117 to_string(tensor.GetNumDimensions()) + " dimensions for " +
118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100123void ValidateTensorNumElements(const TensorInfo& tensor,
124 std::string const& descName,
125 unsigned int numElements,
126 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127{
128 if (tensor.GetNumElements() != numElements)
129 {
130 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100131 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100132 tensorName + " tensor.");
133 }
134}
135
136//---------------------------------------------------------------
137void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100138 unsigned int numDimension,
139 unsigned int numElements,
140 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100141{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100142 const std::string functionName{"ValidateTensorNumDimNumElem"};
143 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
144 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100145}
146
147//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000148void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
149 const std::string& descName, std::string const& tensorName)
150{
151 if (tensor.GetDataType() != dataType)
152 {
153 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
154 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
155 }
156}
157
Derek Lambertid466a542020-01-22 15:37:29 +0000158void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
159{
160 ARMNN_NO_DEPRECATE_WARN_BEGIN
161 if (tensor.GetDataType() != DataType::QSymmS8 &&
162 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
163 {
164 throw InvalidArgumentException(descName +
165 ": Expected data type which supports per-axis quantization scheme but got " +
166 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
167 }
168 ARMNN_NO_DEPRECATE_WARN_END
169}
170
telsoa014fcda012018-03-09 14:13:49 +0000171//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100172void ValidateTensorQuantizationSpace(const TensorInfo& first,
173 const TensorInfo& second,
174 const std::string& descName,
175 std::string const& firstName,
176 std::string const& secondName)
177{
178 if (!first.IsQuantized() ||
179 !second.IsQuantized())
180 {
181 // Not a quantized type, ignore the validation
182 return;
183 }
184
185 DataType firstDataType = first.GetDataType();
186 DataType secondDataType = second.GetDataType();
187
188 if (firstDataType != secondDataType)
189 {
190 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
191 " must be of the same quantized type, " +
192 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
193 secondName + " is " + GetDataTypeName(secondDataType));
194 }
195
196 if (!first.IsTypeSpaceMatch(second))
197 {
198 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
199 " must have the same quantization space, " +
200 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
201 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
202 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
203 " and scale " + to_string(second.GetQuantizationScale()));
204 }
205}
206
207//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100208void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
209 const TensorInfo& inputTensorInfo,
210 const TensorInfo& weightsTensorInfo,
211 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000212{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000213 // Helper lambda function to validate a single bias quantization scale value
214 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
215 {
mathad01df9a3222021-04-28 11:42:57 +0100216 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000217 if (std::abs(biasScale - expectedScale) > tolerance)
218 {
219 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100220 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
221 " for bias quantization scale (product of input and weight scales), but got " <<
222 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000223 }
224 };
225
telsoa014fcda012018-03-09 14:13:49 +0000226 if (biasTensor.GetQuantizationOffset() != 0)
227 {
228 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
229 to_string(biasTensor.GetQuantizationOffset()));
230 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000231
James Conroy8502ade2020-11-12 19:26:29 +0000232 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000233 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000234 // Validate per-axis quantization scales
235 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
236 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
237
238 if (weightScales.size() != biasScales.size())
239 {
240 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000241 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
242 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
243 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100309 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100369 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
James Ward47fce872020-09-10 11:57:28 +0100390 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
391 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000392 }
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000393}
394
395void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
396 const std::string& descName,
397 const std::string& tensorName)
398{
399 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
400 if (quantizationOffset != 0)
401 {
James Ward47fce872020-09-10 11:57:28 +0100402 throw InvalidArgumentException(fmt::format(
403 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
404 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000405 }
406}
407
408void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
409 const TensorInfo& outputInfo,
410 const TensorInfo& weightInfo,
411 const Optional<TensorInfo>& optionalBiasInfo,
412 const std::string& descName)
413{
414 if (weightInfo.HasPerAxisQuantization())
415 {
416 const DataType inputDataType = inputInfo.GetDataType();
417 const DataType outputDataType = outputInfo.GetDataType();
418
Keith Davis0c2eeac2020-02-11 16:51:50 +0000419 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000420
421 if (!canHavePerAxisQuantization)
422 {
James Ward47fce872020-09-10 11:57:28 +0100423 throw InvalidArgumentException(fmt::format(
424 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
425 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000426 }
427
Derek Lambertid466a542020-01-22 15:37:29 +0000428
429 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000430 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
431 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
432
433 if (optionalBiasInfo.has_value())
434 {
435 const TensorInfo& biasInfo = optionalBiasInfo.value();
436 if (!biasInfo.HasPerAxisQuantization())
437 {
James Ward47fce872020-09-10 11:57:28 +0100438 throw InvalidArgumentException(fmt::format(
439 "{}: Per-axis quantization parameters not set on bias tensor, "
440 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000441 }
442
443 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
444 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
445 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
446 }
447 }
448}
449
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100450} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000451
452void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
453 unsigned int numExpectedIn, unsigned int numExpectedOut) const
454{
455 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
456 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
457}
458
459//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100460void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
461{
462 const std::string descriptorName{"MapQueueDescriptor"};
463
464 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100465 ValidateNumOutputs(workloadInfo, descriptorName, 0);
466
467 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
468 {
469 if (!m_Inputs[i])
470 {
471 throw InvalidArgumentException(
472 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
473 }
474 }
475}
476
477//---------------------------------------------------------------
478void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
479{
480 const std::string descriptorName{"UnmapQueueDescriptor"};
481
482 ValidateNumInputs(workloadInfo, descriptorName, 1);
483 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100484
485 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
486 {
487 if (!m_Inputs[i])
488 {
489 throw InvalidArgumentException(
490 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
491 }
492 }
493}
494
495//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000496void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
497{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100500 ValidateNumInputs(workloadInfo, descriptorName, 1);
501 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100503 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
504 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
505
506 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
507 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000508
509 if (m_Inputs.size() != m_Outputs.size())
510 {
James Ward47fce872020-09-10 11:57:28 +0100511 throw InvalidArgumentException(fmt::format(
512 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
513 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000514 }
515
516 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
517 {
518 if (!m_Inputs[i])
519 {
James Ward47fce872020-09-10 11:57:28 +0100520 throw InvalidArgumentException(fmt::format(
521 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000522 }
523
524 if (!m_Outputs[i])
525 {
James Ward47fce872020-09-10 11:57:28 +0100526 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000527 }
528 }
529}
530
Derek Lambertif674aa02019-08-01 15:56:25 +0100531//---------------------------------------------------------------
532void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
533{
534 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
535 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
536
537 if (workloadInfo.m_InputTensorInfos.size() != 1)
538 {
James Ward47fce872020-09-10 11:57:28 +0100539 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
540 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100541
542 }
543
544 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
545 {
James Ward47fce872020-09-10 11:57:28 +0100546 throw InvalidArgumentException(fmt::format(
547 "Number of input infos ({0}) does not match the number of output infos ({1})",
548 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100549 }
550
551 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
552 {
553 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
554 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
555 {
James Ward47fce872020-09-10 11:57:28 +0100556 throw InvalidArgumentException(fmt::format(
557 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100558 }
559 }
560
561 if (m_Inputs.size() != 1)
562 {
James Ward47fce872020-09-10 11:57:28 +0100563 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100564 }
565
566 if (m_Inputs.size() != m_Outputs.size())
567 {
James Ward47fce872020-09-10 11:57:28 +0100568 throw InvalidArgumentException(fmt::format(
569 "Number of inputs ({0}) does not match the number of outputs ({1})",
570 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 }
572
573 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
574 {
575 if (!m_Inputs[i])
576 {
James Ward47fce872020-09-10 11:57:28 +0100577 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100578 }
579
580 if (!m_Outputs[i])
581 {
James Ward47fce872020-09-10 11:57:28 +0100582 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100583 }
584 }
585}
586
587//---------------------------------------------------------------
588void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
589{
590 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
591 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
592
Derek Lambertif674aa02019-08-01 15:56:25 +0100593 if (m_Inputs.size() != 1)
594 {
James Ward47fce872020-09-10 11:57:28 +0100595 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100596 }
597
598 if (m_Outputs.size() != 0)
599 {
James Ward47fce872020-09-10 11:57:28 +0100600 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100601 }
602
603 if (!m_Inputs[0])
604 {
James Ward47fce872020-09-10 11:57:28 +0100605 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100606 }
607}
608
609//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000610void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
611{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100612 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100613
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100614 ValidateNumInputs(workloadInfo, descriptorName, 1);
615 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100616
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100617 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
618 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100619
620 std::vector<DataType> supportedTypes =
621 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000622 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100623 DataType::Float16,
624 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000625 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000626 DataType::QAsymmU8,
627 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100628 };
629
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100630 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
631 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
632 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000633}
634
Nikhil Rajee391d52019-09-05 17:50:44 +0100635void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
636{
637 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
638
639 ValidateNumInputs(workloadInfo, descriptorName, 1);
640 ValidateNumOutputs(workloadInfo, descriptorName, 1);
641
642 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
643 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
644
Inki Daed4619e22020-09-10 15:33:54 +0900645 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
646 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100647 {
Inki Daed4619e22020-09-10 15:33:54 +0900648 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100649 }
650
James Conroyd47a0642019-09-17 14:22:06 +0100651 std::vector<DataType> supportedInputTypes =
652 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000653 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100654 DataType::Float16,
655 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100656 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000657 DataType::QAsymmU8,
658 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900659 DataType::Signed32,
660 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100661 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100662
James Conroyd47a0642019-09-17 14:22:06 +0100663 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100664
665 auto inputShape = inputTensorInfo.GetShape();
666 auto outputShape = outputTensorInfo.GetShape();
667
668 auto inputNumDimensions = inputShape.GetNumDimensions();
669 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
670
671 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
672
673 // 1D input shape results in scalar output shape
674 if (inputShape.GetNumDimensions() == 1)
675 {
676 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
677 {
678 throw InvalidArgumentException(descriptorName + outputShapeError);
679 }
680 }
681 else
682 {
683 for (unsigned int i = 0; i < unsignedAxis; ++i)
684 {
685 if (outputShape[i] != inputShape[i])
686 {
687 throw InvalidArgumentException(descriptorName + outputShapeError);
688 }
689 }
690
691 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
692 {
693 if (outputShape[i - 1] != inputShape[i])
694 {
695 throw InvalidArgumentException(descriptorName + outputShapeError);
696 }
697 }
698 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100699}
700
mathad01b392e982021-04-07 12:07:30 +0100701void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
702{
703 const std::string descriptorName{"CastQueueDescriptor"};
704
705 ValidateNumInputs(workloadInfo, descriptorName, 1);
706 ValidateNumOutputs(workloadInfo, descriptorName, 1);
707
708 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
709 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
710
711 std::vector<DataType> supportedTypes =
712 {
713 DataType::BFloat16,
714 DataType::Float16,
715 DataType::Float32,
716 DataType::QAsymmS8,
717 DataType::QAsymmU8,
718 DataType::QSymmS8,
719 DataType::QSymmS16,
720 DataType::Signed32,
721 DataType::Signed64
722 };
723
724 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
725 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
726}
727
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100728void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
729{
730 const std::string descriptorName{"SoftmaxQueueDescriptor"};
731
732 ValidateNumInputs(workloadInfo, descriptorName, 1);
733 ValidateNumOutputs(workloadInfo, descriptorName, 1);
734
735 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
736 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
737
738 std::vector<DataType> supportedTypes =
739 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000740 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100741 DataType::Float16,
742 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000743 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000744 DataType::QAsymmU8,
745 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100746 };
747
748 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
749 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
750 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
751}
752
telsoa014fcda012018-03-09 14:13:49 +0000753void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
754{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 const std::string descriptorName{"SplitterQueueDescriptor"};
756
757 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000758
Ruomei Yan25339c32019-05-28 16:48:20 +0100759 // Check the supported data types
760 std::vector<DataType> supportedTypes =
761 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000762 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100763 DataType::Float32,
764 DataType::Float16,
765 DataType::Boolean,
766 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100767 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000768 DataType::QAsymmU8,
769 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100770 };
771
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100772 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
773 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100774 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100775 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
776 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
777
778 const std::string outputName = "output_" + std::to_string(i);
779 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100780 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100781
telsoa014fcda012018-03-09 14:13:49 +0000782 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
783 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000785 }
786
787 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
788 {
789 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100790 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000791 "has to match number of workloadInfo.m_OutputTensorInfos. "
792 "Number of windows: " +
793 to_string(m_ViewOrigins.size()) +
794 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
795 }
796
telsoa01c577f2c2018-08-31 09:22:23 +0100797 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000798 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
799 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
800 {
telsoa01c577f2c2018-08-31 09:22:23 +0100801 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000802 ViewOrigin const& e = m_ViewOrigins[w];
803 if (e.m_Origin.size() != inputDims)
804 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100805 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000806 "have the same dimensionality as the input tensor. "
807 "Window origin (index: " +
808 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
809 " dimensions, the input "
810 "tensor has " +
811 to_string(inputDims) + " dimensions.");
812 }
813 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
814 {
815 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
816 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
817 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100818 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000819 "be smaller or equal than the size of the input in that coord.");
820 }
821 }
822 }
823}
824
Jim Flynne242f2d2019-05-22 14:24:13 +0100825void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000826{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 const std::string descriptorName{"ConcatQueueDescriptor"};
828
829 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000830
831 if (m_Inputs.size() <= 0)
832 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100833 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000834 }
835 if (m_Outputs.size() <= 0)
836 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100837 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000838 }
839
840 if (workloadInfo.m_InputTensorInfos.size() <= 0)
841 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100842 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000843 }
844 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
845 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100846 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000847 }
848
Nikhil Raj8599a412018-11-19 14:51:07 +0000849 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
850 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100851 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000852 }
853
854 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
855 {
856 return;
857 }
858
telsoa014fcda012018-03-09 14:13:49 +0000859 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
860 {
861 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100862 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000863 "has to match number of workloadInfo.m_InputTensorInfos. "
864 "Number of windows: " +
865 to_string(m_ViewOrigins.size()) +
866 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
867 }
868
telsoa01c577f2c2018-08-31 09:22:23 +0100869 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000870 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
871 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
872 {
telsoa01c577f2c2018-08-31 09:22:23 +0100873 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000874 ViewOrigin const& e = m_ViewOrigins[w];
875 if (e.m_Origin.size() != outputDims)
876 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000878 "have the same dimensionality as the output tensor. "
879 "Window origin (index: " +
880 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
881 " dimensions, the output "
882 "tensor has " +
883 to_string(outputDims) + " dimensions.");
884 }
telsoa01c577f2c2018-08-31 09:22:23 +0100885 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000886 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
887 {
888 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
889 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
890 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100891 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000892 "be smaller or equal than the size of the output in that coord.");
893 }
894 }
895 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100896
897 // Check the supported data types
898 std::vector<DataType> supportedTypes =
899 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000900 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100901 DataType::Float32,
902 DataType::Float16,
903 DataType::Boolean,
904 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100905 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000906 DataType::QAsymmU8,
907 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100908 };
909
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100910 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
911 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
914 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
915
916 const std::string inputName = "input_" + std::to_string(i);
917 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100918 }
telsoa014fcda012018-03-09 14:13:49 +0000919}
920
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100921void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
922{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100923 const std::string descriptorName{"StackQueueDescriptor"};
924
925 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100926
927 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
928 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100929 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100930 }
931
932 // All inputs must have the same shape, which is defined in parameters
933 const TensorShape& inputShape = m_Parameters.m_InputShape;
934 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
935 {
936 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
937 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100938 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100939 }
940 }
941
Matthew Jacksondba634f2019-08-15 15:14:18 +0100942 if (inputShape.GetNumDimensions() > 4)
943 {
944 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
945 }
946
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100947 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
948 // since the output tensor has an additional dimension.
949 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
950 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100951 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100952 "than the number of input dimensions.");
953 }
954
955 // Output shape must be as inferred from the input shape
956 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
957 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
958 {
959 if (outputShape[i] != inputShape[i])
960 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100962 "match shape inferred from input tensor.");
963 }
964 }
965
966 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
967 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100969 "match shape inferred from input tensor.");
970 }
971
972 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
973 {
974 if (outputShape[i] != inputShape[i-1])
975 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100976 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100977 "match shape inferred from input tensor.");
978 }
979 }
980
Matthew Jacksondba634f2019-08-15 15:14:18 +0100981 if (outputShape.GetNumDimensions() > 5)
982 {
983 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
984 }
985
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100986 // Check the supported data types
987 std::vector<DataType> supportedTypes =
988 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000989 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100990 DataType::Float32,
991 DataType::Float16,
992 DataType::Boolean,
993 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100994 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000995 DataType::QAsymmU8,
996 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100997 };
998
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100999 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001000
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001001 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001002 {
1003 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1004 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001005 descriptorName,
1006 "input_0",
1007 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001008 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001009
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001010 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1011 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001012 descriptorName,
1013 "input_0",
1014 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001015}
1016
Ryan OSheaec6c6802020-06-05 17:17:06 +01001017void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1018{
1019 const std::string descriptorName{"FillQueueDescriptor"};
1020
1021 ValidateNumInputs(workloadInfo, descriptorName, 1);
1022 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1023
1024 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1025 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1026
1027 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1028
1029 std::vector<DataType> supportedTypes =
1030 {
1031 DataType::BFloat16,
1032 DataType::Float32,
1033 DataType::Float16,
1034 DataType::Signed32
1035 };
1036
1037 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1038}
1039
telsoa014fcda012018-03-09 14:13:49 +00001040void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1041{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001042 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001043
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001044 uint32_t numInputs = 1;
1045 if (!m_Parameters.m_ConstantWeights)
1046 {
1047 numInputs = 2;
1048 if (m_Parameters.m_BiasEnabled)
1049 {
1050 numInputs = 3;
1051 }
1052 }
1053 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001054 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1055
1056 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1057 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1058
1059 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1060
1061 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001062 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001063 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001064 }
1065
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001066 TensorInfo weightTensorInfo;
1067 if (m_Parameters.m_ConstantWeights)
1068 {
1069 ValidatePointer(m_Weight, descriptorName, "weight");
1070 weightTensorInfo = m_Weight->GetTensorInfo();
1071 }
1072 else
1073 {
1074 weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1075 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001076 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001077
1078 if (m_Parameters.m_BiasEnabled)
1079 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001080 TensorInfo biasTensorInfo;
1081 if (m_Parameters.m_ConstantWeights)
1082 {
1083 ValidatePointer(m_Bias, descriptorName, "bias");
1084 biasTensorInfo = m_Bias->GetTensorInfo();
1085 }
1086 else
1087 {
1088 biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
1089 }
telsoa01c577f2c2018-08-31 09:22:23 +01001090 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001091 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001092 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1093 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001094 }
1095
Francis Murtagh46c09d02019-05-28 08:15:28 +01001096 // Check the supported data types
1097 std::vector<DataType> supportedTypes =
1098 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001099 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001100 DataType::Float32,
1101 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001102 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001103 DataType::QAsymmU8,
1104 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001105 };
1106
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001107 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001108
1109 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1110 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1111 {
1112 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1113 {
1114 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1115 "for BFloat16 input.");
1116 }
1117 }
1118 else
1119 {
1120 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1121 }
telsoa014fcda012018-03-09 14:13:49 +00001122}
1123
telsoa014fcda012018-03-09 14:13:49 +00001124void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1125{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 const std::string descriptorName{"NormalizationQueueDescriptor"};
1127
1128 ValidateNumInputs(workloadInfo, descriptorName, 1);
1129 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1130
1131 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1132 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001133
1134 // Check the supported data types
1135 std::vector<DataType> supportedTypes =
1136 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001137 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001138 DataType::Float16,
1139 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001140 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001141 DataType::QAsymmU8,
1142 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001143 };
1144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001145 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001147 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001148
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001150}
1151
1152void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1153{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001155
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001156 ValidateNumInputs(workloadInfo, descriptorName, 2);
1157 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1158
1159 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1160 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1161 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1162
1163 std::vector<DataType> supportedTypes =
1164 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001165 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001166 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001167 DataType::Float16,
1168 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001169 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001170 DataType::QSymmS16,
1171 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001172 };
1173
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001174 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1175 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1176 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001177
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001178 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1179 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001180
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001181 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1182 inputTensorInfo1,
1183 outputTensorInfo,
1184 descriptorName,
1185 "input_0",
1186 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001187}
1188
telsoa014fcda012018-03-09 14:13:49 +00001189void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1190{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001191 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001193 ValidateNumInputs(workloadInfo, descriptorName, 2);
1194 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1195
1196 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1197 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1198 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1199
1200 std::vector<DataType> supportedTypes =
1201 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001202 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001203 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001204 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001205 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001206 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001207 DataType::QSymmS16,
1208 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001209 };
1210
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001211 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1212 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1213 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001214
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001215 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1216 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001217
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001218 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1219 inputTensorInfo1,
1220 outputTensorInfo,
1221 descriptorName,
1222 "input_0",
1223 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001224}
1225
1226void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1227{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001228 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001229
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001230 ValidateNumInputs(workloadInfo, descriptorName, 1);
1231 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1232
1233 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1234 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001235
1236 std::vector<DataType> supportedTypes =
1237 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001238 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001239 DataType::Float16,
1240 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001241 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001242 DataType::QAsymmU8,
1243 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001244 };
1245
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001246 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1247 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001248
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001249 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001251
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001252 ValidatePointer(m_Mean, descriptorName, "mean");
1253 ValidatePointer(m_Variance, descriptorName, "variance");
1254 ValidatePointer(m_Beta, descriptorName, "beta");
1255 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001256
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001257 const TensorInfo& mean = m_Mean->GetTensorInfo();
1258 const TensorInfo& variance = m_Variance->GetTensorInfo();
1259 const TensorInfo& beta = m_Beta->GetTensorInfo();
1260 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001261
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001262 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1263 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1264 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1265 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001266
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001267 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1268 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1269 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001270}
1271
1272void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1273{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001275
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001276 ValidateNumInputs(workloadInfo, descriptorName, 1);
1277 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001278
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001279 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1280 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001281
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001282 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1283 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001284
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001285 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001287 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1288 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001289
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001290 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001291
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001292 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001293 if (m_Parameters.m_BiasEnabled)
1294 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001295 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001296
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001297 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1298 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001299
1300 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1301 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001302 }
1303
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001304 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1305 {
1306 throw InvalidArgumentException(
1307 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1308 "cannot be either negative or 0.",
1309 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1310 }
1311
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001312 ValidatePerAxisQuantization(inputTensorInfo,
1313 outputTensorInfo,
1314 weightTensorInfo,
1315 optionalBiasTensorInfo,
1316 descriptorName);
1317
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001318 std::vector<DataType> supportedTypes =
1319 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001320 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001321 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001322 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001323 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001324 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001325 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001326 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001327 };
1328
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001329 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001330
1331 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1332 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1333 {
1334 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1335 {
1336 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1337 "for BFloat16 input.");
1338 }
1339 }
1340 else
1341 {
1342 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1343 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001344}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001345
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001346void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1347{
1348 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1349
1350 ValidateNumInputs(workloadInfo, descriptorName, 1);
1351 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1352
1353 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1354 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1355
1356 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1357 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1358
1359 ValidatePointer(m_Weight, descriptorName, "weight");
1360
1361 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1362 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1363
1364 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1365 {
1366 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001367 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1368 "cannot be smaller than 1.",
1369 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001370 }
1371
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001372 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1373 {
1374 throw InvalidArgumentException(
1375 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1376 "cannot be either negative or 0.",
1377 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1378 }
1379
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001380 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1381
Jan Eilers53ef7952021-06-02 12:01:25 +01001382 // Expected weight shape: [ 1, H, W, I*M ] - This shape does NOT depend on the data layout
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001383 // inputChannels * channelMultiplier should be equal to outputChannels.
Jan Eilers53ef7952021-06-02 12:01:25 +01001384 const unsigned int numWeightOutputChannels = weightTensorInfo.GetShape()[3]; // I*M=Cout
1385 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1386 if (numWeightOutputChannels != numOutputChannels)
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001387 {
James Ward47fce872020-09-10 11:57:28 +01001388 throw InvalidArgumentException(fmt::format(
Jan Eilers53ef7952021-06-02 12:01:25 +01001389 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1390 "But 4th dimension is not equal to Cout. Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1391 descriptorName,
1392 numOutputChannels,
1393 weightTensorInfo.GetShape()[0],
1394 weightTensorInfo.GetShape()[1],
1395 weightTensorInfo.GetShape()[2],
1396 weightTensorInfo.GetShape()[3]));
1397 }
1398 if (weightTensorInfo.GetShape()[0] != 1)
1399 {
1400 throw InvalidArgumentException(fmt::format(
1401 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1402 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1403 descriptorName,
1404 weightTensorInfo.GetShape()[0],
1405 weightTensorInfo.GetShape()[1],
1406 weightTensorInfo.GetShape()[2],
1407 weightTensorInfo.GetShape()[3]));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001408 }
1409
Teresa Charlind8df0262019-11-11 12:28:15 +00001410 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001411
Teresa Charlind8df0262019-11-11 12:28:15 +00001412 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001413 if (m_Parameters.m_BiasEnabled)
1414 {
1415 ValidatePointer(m_Bias, descriptorName, "bias");
1416
Teresa Charlind8df0262019-11-11 12:28:15 +00001417 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1418 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001419
1420 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1421 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1422 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001423 ValidatePerAxisQuantization(inputTensorInfo,
1424 outputTensorInfo,
1425 weightTensorInfo,
1426 optionalBiasTensorInfo,
1427 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001428
1429 std::vector<DataType> supportedTypes =
1430 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001431 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001432 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001433 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001434 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001435 DataType::QAsymmU8,
1436 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 };
1438
1439 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1440 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001441}
1442
1443void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1444{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001445 const std::string descriptorName{"PermuteQueueDescriptor"};
1446
1447 ValidateNumInputs(workloadInfo, descriptorName, 1);
1448 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001449
1450 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1451
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001452 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1453 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001454
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1456 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001458 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001459 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001460 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001461 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001462 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1463 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1464 "must match dst dimension " + to_string(mapping[i]) +
1465 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001466 }
1467 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001468
1469 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001470}
1471
1472void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1473{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001474 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001476 ValidateNumInputs(workloadInfo, descriptorName, 1);
1477 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1478
1479 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1480 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1481
1482 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1483 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001484
1485 std::vector<DataType> supportedTypes =
1486 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001487 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001488 DataType::Float32,
1489 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001490 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001491 DataType::QAsymmU8,
1492 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001493 };
1494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001495 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1496 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001497}
1498
1499void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1500{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001501 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503 ValidateNumInputs(workloadInfo, descriptorName, 1);
1504 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1505
1506 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1507 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1508
1509 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1510 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001511
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001512 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001513 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001514 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001515 DataType::Float16,
1516 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001517 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001518 DataType::QAsymmU8,
1519 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001520 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001522 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1523 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 // ResizeBilinear only changes width and height: batch and channel count must match.
1526 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1527 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001528 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001529 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001530 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001531 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1532 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001533 }
1534
Teresa Charlin970f43b2019-07-01 13:51:07 +01001535 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001536 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1537 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001538 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001539 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001540 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001541 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1542 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001543 }
1544}
1545
1546void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1547{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001548 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001549
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001550 ValidateNumInputs(workloadInfo, descriptorName, 1);
1551 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1552
1553 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1554 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1555
1556 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1557 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001558
1559 std::vector<DataType> supportedTypes =
1560 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001561 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001562 DataType::Float16,
1563 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001564 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001565 DataType::QAsymmU8,
1566 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001567 };
1568
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001569 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1570 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001572 // Resize only changes width and height: batch and channel count must match.
1573 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1574 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001575 if (inputBatchSize != outputBatchSize)
1576 {
1577 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001578 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1579 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001580 }
1581
1582 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001583 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1584 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001585 if (inputChannelCount != outputChannelCount)
1586 {
1587 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001588 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1589 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001590 }
1591}
1592
1593void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1594{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001595 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001596
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001597 ValidateNumInputs(workloadInfo, descriptorName, 1);
1598 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1599
1600 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1601 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1602
1603 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1604 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1605
1606 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1607
telsoa014fcda012018-03-09 14:13:49 +00001608 if (m_Parameters.m_Min > m_Parameters.m_Max)
1609 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001610 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001611 }
telsoa014fcda012018-03-09 14:13:49 +00001612}
1613
Kevin Mayce5045a2019-10-02 14:07:47 +01001614void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1615{
1616 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1617
1618 ValidateNumInputs(workloadInfo, descriptorName, 1);
1619 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1620
1621 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1622 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1623
1624 if (inputTensorInfo.GetNumDimensions() > 4)
1625 {
1626 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1627 }
1628
1629 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1630
1631 // Check the supported data types
1632 std::vector<DataType> supportedTypes =
1633 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001634 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001635 DataType::Float32,
1636 DataType::Float16
1637 };
1638
1639 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001640 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001641}
1642
telsoa014fcda012018-03-09 14:13:49 +00001643void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1644{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001645 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001646
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001647 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001650 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1651 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1652
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001653 if (inputTensorInfo.GetNumDimensions() > 4)
1654 {
1655 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1656 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001657
1658 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001659
1660 // Check the supported data types
1661 std::vector<DataType> supportedTypes =
1662 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001663 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001664 DataType::Float32,
1665 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001666 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001667 DataType::QAsymmU8,
1668 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001669 };
1670
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001671 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001672 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1673}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001674
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001675void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1676{
1677 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1678
1679 ValidateNumInputs(workloadInfo, descriptorName, 1);
1680 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1681
1682 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1683 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1684
1685 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1686
1687 std::vector<DataType> supportedTypes =
1688 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001689 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001690 DataType::Float32,
1691 DataType::Float16,
1692 };
1693
1694 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001695 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001696}
1697
1698void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1699{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 const std::string descriptorName{"ConstantQueueDescriptor"};
1701
1702 ValidateNumInputs(workloadInfo, descriptorName, 0);
1703 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001704
1705 if (!m_LayerOutput)
1706 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001707 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001708 }
1709
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001710 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1711 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001712
1713 // Check the supported data types
1714 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001716 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001717 DataType::Float32,
1718 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001719 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001720 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001721 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001722 DataType::QSymmS16,
1723 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001724 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001725
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001727}
1728
1729void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1730{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001731 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001732
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001733 ValidateNumInputs(workloadInfo, descriptorName, 1);
1734 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1735
1736 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1737 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1738
1739 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001740
1741 // Check the supported data types
1742 std::vector<DataType> supportedTypes =
1743 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001744 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001745 DataType::Float32,
1746 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001747 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001748 DataType::QAsymmU8,
1749 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001750 DataType::Signed32,
1751 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001752 };
1753
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001754 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1755 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001756}
1757
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001758void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1759{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001760 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001761
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001762 ValidateNumInputs(workloadInfo, descriptorName, 1);
1763 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1764
1765 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1766 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1767
1768 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1769 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001770
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001771 if (m_Parameters.m_BlockShape.size() != 2)
1772 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001773 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001774 }
1775
1776 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1777 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001778 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1779 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001780 }
1781
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001782 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001783
1784 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001785 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001786
Matthew Bentham8800c002018-11-19 13:19:28 +00001787 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001788
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001789 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1790 widthPad.first + widthPad.second;
1791 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1792 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001794 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1795 inputShape[dimensionIndices.GetChannelsIndex()];
1796 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001797
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001798 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001799 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001801 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001802 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001803 }
1804
1805 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001806 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001807 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1808 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001809 }
nikraj01120522a2019-05-31 11:33:07 +01001810
1811 std::vector<DataType> supportedTypes =
1812 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001813 DataType::BFloat16,
1814 DataType::Float16,
1815 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001816 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001817 DataType::QAsymmU8,
1818 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001819 };
1820
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001821 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1822 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001823}
1824
Keith Davisa57eccb2019-06-14 17:33:22 +01001825void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1826{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001827 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001828
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001829 ValidateNumInputs(workloadInfo, descriptorName, 1);
1830 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001831
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001832 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1833 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1834
1835 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1836 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001837
1838 std::vector<DataType> supportedTypes =
1839 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001840 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001841 DataType::Float32,
1842 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001843 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001844 DataType::QAsymmU8,
1845 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001846 };
1847
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001848 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1849 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001850
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001851 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1852
1853 if (m_Parameters.m_BlockSize == 0)
1854 {
1855 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1856 }
1857
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001858 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1859 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1860 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1861 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001862
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001863 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001864 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001865 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001866 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1867 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001868 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001869
1870 const TensorShape& outputShape = outputTensorInfo.GetShape();
1871 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1872 {
1873 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1874 "must be divisible by the square of block size." );
1875 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001876}
1877
telsoa014fcda012018-03-09 14:13:49 +00001878void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1879{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001880 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001882 ValidateNumInputs(workloadInfo, descriptorName, 1);
1883 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1884
1885 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1886 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001887
1888 std::vector<DataType> supportedTypes =
1889 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001890 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001891 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001892 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001893 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001894 };
1895
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001896 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001899 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001900 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001901 }
1902}
1903
telsoa01c577f2c2018-08-31 09:22:23 +01001904void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1905{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001906 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1907
1908 const std::string descriptorName{"LstmQueueDescriptor"};
1909
1910 // check dimensions of all inputs and outputs
1911 if (workloadInfo.m_InputTensorInfos.size() != 3)
1912 {
1913 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1914 }
1915 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1916 {
1917 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1918 }
1919
1920 std::vector<DataType> supportedTypes =
1921 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001922 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001923 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001924 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001925 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001926 };
1927
Jan Eilers38e05bd2019-06-26 13:10:09 +01001928 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001929 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1930
Jan Eilers38e05bd2019-06-26 13:10:09 +01001931 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001932 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001933 {
1934 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1935 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001936 descriptorName,
1937 "input_0",
1938 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001939 }
1940 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001941 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001942 {
1943 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1944 workloadInfo.m_OutputTensorInfos[i],
1945 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001946 "input_0",
1947 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001948 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001949
janeil0117d8d852019-11-15 15:00:16 +00001950 // Making sure clipping parameters have valid values.
1951 // == 0 means no clipping
1952 // > 0 means clipping
1953 if (m_Parameters.m_ClippingThresCell < 0.0f)
1954 {
1955 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1956 }
1957 if (m_Parameters.m_ClippingThresProj < 0.0f)
1958 {
1959 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1960 }
1961
Jan Eilers38e05bd2019-06-26 13:10:09 +01001962
1963 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001964 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1965 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1966 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1967 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1968 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1969 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1970
Jan Eilers38e05bd2019-06-26 13:10:09 +01001971 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001972 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1973 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001974 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001975 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1976 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001977 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001978 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1979 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001980 // scratchBufferTensor
1981 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001982 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1983 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001984 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001985 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1986 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001987 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1989 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001990 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001991 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1992 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001993
1994
1995 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1996 if ( m_InputToInputWeights )
1997 {
1998 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1999 (n_cell * n_input), "InputLayerNormWeights");
2000 }
2001
2002 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2003 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2004 (n_cell * n_input), "InputToForgetWeights");
2005
2006 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2007 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2008 (n_cell * n_input), "InputToCellWeights");
2009
2010 if ( m_RecurrentToInputWeights )
2011 {
2012 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2013 (n_cell * n_output), "RecurrentToInputWeights");
2014 }
2015
2016 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2017 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2018 (n_cell * n_output), "RecurrentToForgetWeights");
2019
2020 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2021 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2022 (n_cell * n_output), "RecurrentToCellWeights");
2023
2024 // Make sure the input-gate's parameters are either both present (regular
2025 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2026 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2027 !m_Parameters.m_CifgEnabled) ||
2028 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2029 m_Parameters.m_CifgEnabled));
2030 if (!cifg_weights_all_or_none)
2031 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002032 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2033 "RecurrentToInputWeights must either both be present (regular LSTM) "
2034 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2035 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002036 }
2037
2038 if ( m_CellToInputWeights )
2039 {
2040 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2041 n_cell, "CellToInputWeights");
2042 }
2043 if ( m_CellToForgetWeights )
2044 {
2045 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2046 n_cell, "CellToForgetWeights");
2047 }
2048 if ( m_CellToOutputWeights )
2049 {
2050 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2051 n_cell, "CellToOutputWeights");
2052 }
2053
2054 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2055 bool peephole_weights_all_or_none =
2056 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2057 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2058 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2059 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2060 if (!peephole_weights_all_or_none)
2061 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002062 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002063 }
2064
2065 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2066 if (m_Parameters.m_CifgEnabled)
2067 {
2068 if (m_InputGateBias)
2069 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002070 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002071 }
2072 }
2073 else
2074 {
2075 if (!m_InputGateBias)
2076 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002077 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2078 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002079 }
2080 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2081 n_cell, "InputGateBias");
2082 }
2083
2084 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2085 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2086
2087 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2088 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2089
2090 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2091 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2092
2093 if (m_ProjectionWeights)
2094 {
2095 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2096 (n_cell * n_output), "ProjectionWeights");
2097 }
2098 if (m_ProjectionBias)
2099 {
2100 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2101 }
2102
2103 // Making sure the projection tensors are consistent:
2104 // 1) If projection weight is not present, then projection bias should not be
2105 // present.
2106 // 2) If projection weight is present, then projection bias is optional.
2107 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2108 !m_Parameters.m_ProjectionEnabled)
2109 || (m_ProjectionWeights && !m_ProjectionBias &&
2110 m_Parameters.m_ProjectionEnabled)
2111 || (m_ProjectionWeights && m_ProjectionBias &&
2112 m_Parameters.m_ProjectionEnabled));
2113 if (!projecton_tensors_consistent)
2114 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002115 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002116 }
2117
2118 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2119 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2120 // either all have values or none of them have values. Layer normalization is used when the values of all the
2121 // layer normalization weights are present
2122 if (m_InputLayerNormWeights)
2123 {
2124 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2125 }
2126 if (m_ForgetLayerNormWeights)
2127 {
2128 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2129 }
2130 if (m_CellLayerNormWeights)
2131 {
2132 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2133 }
2134 if (m_OutputLayerNormWeights)
2135 {
2136 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2137 }
2138
Jan Eilers38e05bd2019-06-26 13:10:09 +01002139 if (m_Parameters.m_LayerNormEnabled)
2140 {
2141 if (!m_Parameters.m_CifgEnabled)
2142 {
2143 if (!m_InputLayerNormWeights)
2144 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002145 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2146 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002147 }
2148 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2149 1, n_cell, "InputLayerNormWeights");
2150 }
2151 else if (m_InputLayerNormWeights)
2152 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002153 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2154 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002155 }
2156
2157 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2158 "ForgetLayerNormWeights");
2159 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2160
2161 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2162 "OutputLayerNormWeights");
2163 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2164
2165 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2166 "CellLayerNormWeights");
2167 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2168 }
2169 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2170 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002171 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2172 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002173 }
telsoa01c577f2c2018-08-31 09:22:23 +01002174}
2175
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002176void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2177{
2178 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2179
2180 ValidateNumInputs(workloadInfo, descriptorName, 1);
2181 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2182
2183 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2184 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2185
2186 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2187 {
2188 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2189 }
2190
2191 if (outputTensorInfo.GetDataType() != DataType::Float32)
2192 {
2193 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2194 }
2195
2196 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2197}
2198
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002199void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2200{
2201 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2202
2203 ValidateNumInputs(workloadInfo, descriptorName, 1);
2204 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2205
2206 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2207 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2208
2209 if (inputTensorInfo.GetDataType() != DataType::Float32)
2210 {
2211 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2212 }
2213
2214 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2215 {
2216 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2217 }
2218
2219 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2220}
2221
telsoa01c577f2c2018-08-31 09:22:23 +01002222void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2223{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002224 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002225
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002226 ValidateNumInputs(workloadInfo, descriptorName, 1);
2227 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2228
2229 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2230 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2231
2232 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002233 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002234 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002235 }
2236
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002237 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002238 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002239 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002240 }
2241
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002242 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002243}
2244
2245void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2246{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002247 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002248
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002249 ValidateNumInputs(workloadInfo, descriptorName, 1);
2250 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2251
2252 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2253 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2254
2255 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002256 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002257 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002258 }
2259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 if (outputTensorInfo.GetDataType() != DataType::Float32)
2261 {
2262 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2263 }
2264
2265 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002266}
2267
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002268void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2269{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002270 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002271
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002272 ValidateNumInputs(workloadInfo, descriptorName, 2);
2273 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2274
2275 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2276 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2277 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2278
2279 std::vector<DataType> supportedTypes =
2280 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002281 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002282 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002283 DataType::Float32,
2284 DataType::QAsymmS8,
2285 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002286 DataType::QSymmS16,
2287 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002288 };
2289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2291 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2292 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002294 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2295 inputTensorInfo1,
2296 outputTensorInfo,
2297 descriptorName,
2298 "input_0",
2299 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002300}
2301
David Beckc2044fe2018-09-05 15:00:38 +01002302void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2303{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002304 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002305
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 ValidateNumInputs(workloadInfo, descriptorName, 2);
2307 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2308
2309 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2310 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2312
2313 std::vector<DataType> supportedTypes =
2314 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002315 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002316 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002317 DataType::Float32,
2318 DataType::QAsymmS8,
2319 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002320 DataType::QSymmS16,
2321 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002322 };
2323
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002324 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2325 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2326 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002327
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002328 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2329 inputTensorInfo1,
2330 outputTensorInfo,
2331 descriptorName,
2332 "input_0",
2333 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002334}
2335
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002336void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2337{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002338 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002339
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002340 ValidateNumInputs(workloadInfo, descriptorName, 2);
2341 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2342
2343 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2344 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2345 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2346
2347 std::vector<DataType> supportedTypes =
2348 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002349 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002350 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002351 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002352 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002353 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002354 DataType::QSymmS16,
2355 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002356 };
2357
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002358 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2359 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2360 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002361
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002362 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2363 inputTensorInfo1,
2364 outputTensorInfo,
2365 descriptorName,
2366 "input_0",
2367 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002368}
2369
narpra01a6bf9122018-09-10 09:50:09 +01002370void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2371{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002372 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002373
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002374 ValidateNumInputs(workloadInfo, descriptorName, 1);
2375 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2376
2377 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2378 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002379
2380 std::vector<DataType> supportedTypes =
2381 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002382 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002383 DataType::Float32,
2384 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002385 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002386 DataType::QAsymmU8,
2387 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002388 };
narpra01eb061912018-09-10 17:35:27 +01002389
James Conroy4d1ff582019-06-10 17:06:39 +01002390 // First check if input tensor data type is supported, then
2391 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002392 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2393 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002394
narpra0132b90462018-09-13 11:07:48 +01002395 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002396 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002397 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002398 }
narpra0132b90462018-09-13 11:07:48 +01002399 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002400 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002401 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002402 }
2403 else
2404 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002405 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002406 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002407 ValidateTensorNumDimensions(outputTensorInfo,
2408 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002409 outputDim > 0 ? outputDim : 1,
2410 "output");
2411 }
narpra01a6bf9122018-09-10 09:50:09 +01002412}
2413
jimfly012c9322a2018-09-19 10:59:49 +01002414void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2415{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002416 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002417
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002418 ValidateNumInputs(workloadInfo, descriptorName, 1);
2419 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2420
2421 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2422 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002423
jimfly012c9322a2018-09-19 10:59:49 +01002424 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002425 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2426
jimfly012c9322a2018-09-19 10:59:49 +01002427 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002428 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2429 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2430 "as there are dimensions in the input tensor that is " +
2431 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2432 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002433 }
2434}
2435
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002436void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2437{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002439
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002440 ValidateNumInputs(workloadInfo, descriptorName, 1);
2441 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002442
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002443 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2444 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2445
Sadik Armagan2208b602019-07-31 16:36:27 +01002446 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002447 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002448 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002449 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002450 DataType::Float16,
2451 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002452 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002453 DataType::QAsymmU8,
2454 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002455 };
2456
2457 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002458
Keith Davis0c2eeac2020-02-11 16:51:50 +00002459 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002460 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002461 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002462 }
2463}
2464
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002465void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2466{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002469 ValidateNumInputs(workloadInfo, descriptorName, 1);
2470 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002471
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002474
2475 std::vector<DataType> supportedTypes =
2476 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002477 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002478 DataType::Float32,
2479 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002480 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002481 DataType::QAsymmU8,
2482 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002483 };
2484
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002485 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2486 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002487}
2488
Conor Kennedy430b5d82018-11-14 15:28:28 +00002489void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2490{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002491 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002492
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002493 ValidateNumInputs(workloadInfo, descriptorName, 1);
2494 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2495
2496 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2497 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002498
2499 std::vector<DataType> supportedTypes =
2500 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002501 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002502 DataType::Float16,
2503 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002504 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002505 DataType::QAsymmU8,
2506 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002507 };
2508
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002509 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2510 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002511
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002512 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002513
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002514 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002515 if (rank > 4)
2516 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002517 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002518 }
2519
Conor Kennedy430b5d82018-11-14 15:28:28 +00002520 // Begin, End & Stride length must be of rank(input0)
2521 if (m_Parameters.m_Begin.size() != rank)
2522 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002523 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002524 }
2525
2526 if (m_Parameters.m_End.size() != rank)
2527 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002528 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002529 }
2530
2531 if (m_Parameters.m_Stride.size() != rank)
2532 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002533 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002534 }
2535
2536 // Stride entries must be non-zero
2537 for (auto& stride : m_Parameters.m_Stride)
2538 {
2539 if (stride == 0)
2540 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002541 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002542 }
2543 }
2544}
2545
kevmay0190539692018-11-29 08:40:19 +00002546void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2547{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002548 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002549
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002550 ValidateNumInputs(workloadInfo, descriptorName, 2);
2551 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2552
2553 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2554 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2555 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2556
2557 std::vector<DataType> supportedTypes =
2558 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002559 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002560 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002561 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002562 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002563 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002564 DataType::QSymmS16,
2565 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002566 };
2567
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002568 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2569 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2570 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2573 inputTensorInfo1,
2574 outputTensorInfo,
2575 descriptorName,
2576 "input_0",
2577 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002578}
2579
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002580void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2581{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002582 const std::string descriptorName{"DebugQueueDescriptor"};
2583
2584 ValidateNumInputs(workloadInfo, descriptorName, 1);
2585 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002586}
2587
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002588void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2589{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002590 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002591
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002592 ValidateNumInputs(workloadInfo, descriptorName, 2);
2593 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002595 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2596 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2597 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2598
2599 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2600 inputTensorInfo1,
2601 outputTensorInfo,
2602 descriptorName,
2603 "input_0",
2604 "input_1");
2605
2606 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002607 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002608 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002609 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002610}
2611
FrancisMurtagh878f0232018-12-19 10:56:15 +00002612void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2613{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002614 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002615
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002616 ValidateNumInputs(workloadInfo, descriptorName, 2);
2617 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002618
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002619 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2620 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2621 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2622
2623 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2624 inputTensorInfo1,
2625 outputTensorInfo,
2626 descriptorName,
2627 "input_0",
2628 "input_1");
2629
2630 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002631 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002632 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002633 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002634}
2635
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002636void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2637{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002638 const std::string descriptorName{"RsqrtQueueDescriptor"};
2639
2640 ValidateNumInputs(workloadInfo, descriptorName, 1);
2641 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2642
2643 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2644 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2645
2646 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002647
2648 std::vector<DataType> supportedTypes =
2649 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002650 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002651 DataType::Float16,
2652 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002653 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002654 DataType::QAsymmU8,
2655 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002656 };
2657
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002658 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2659 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002660}
2661
narpra01b89b05f2019-01-16 09:53:09 +00002662void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2663{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002664 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002665
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002666 ValidateNumInputs(workloadInfo, descriptorName, 2);
2667 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002668
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002669 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2670 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002671 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002672 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002673 }
2674
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002675 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2676 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2677
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002678 std::vector<DataType> supportedTypes =
2679 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002680 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002681 DataType::Float16,
2682 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002683 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002684 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002685 DataType::QSymmS16,
2686 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002687 };
2688
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002689 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002690
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002691 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002692
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002693 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2694 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002695}
2696
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002697void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2698{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002699 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2700
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002701 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002702
2703 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2704 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002705 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002706 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2707 }
2708
2709 if (m_Anchors == nullptr)
2710 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002711 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002712 }
2713
2714 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002715 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2716 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2717
2718 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002719 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002720 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2721 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002722
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002723 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2724 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2725 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002726
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002727 const std::vector<DataType> supportedInputTypes =
2728 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002729 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002730 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002731 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002732 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002733 DataType::QAsymmU8,
2734 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002735 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002736
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002737 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2738 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2739 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2740
2741 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2742 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2743 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2744 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2745
2746 // NOTE: Output is always Float32 regardless of input type
2747 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2748 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2749 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2750 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002751
2752 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2753 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002754 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002755 "must be positive and less than or equal to 1.");
2756 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002757
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002758 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2759 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002760 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002761 "should be equal to number of classes + 1.");
2762 }
2763}
2764
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002765void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2766{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002767 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002768
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002769 ValidateNumInputs(workloadInfo, descriptorName, 1);
2770 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2771
2772 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2773 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2774
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002775 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002776 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002777 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002778 }
2779
Sadik Armagan2208b602019-07-31 16:36:27 +01002780 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002781 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002782 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002783 DataType::Float32,
2784 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002785 };
2786
2787 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002788}
2789
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002790void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2791{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002792 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002794 ValidateNumInputs(workloadInfo, descriptorName, 2);
2795 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002796
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002797 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2798 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2799 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002800
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002801 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2802 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2803
2804 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2805 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002806}
2807
Sadik Armaganeff363d2019-04-05 15:25:46 +01002808void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2809{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002810 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002811
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002812 ValidateNumInputs(workloadInfo, descriptorName, 2);
2813 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2814
2815 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2816 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2817
2818 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2819 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2820
2821 std::vector<DataType> supportedTypes =
2822 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002823 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002824 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002825 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002826 DataType::QAsymmU8,
2827 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002828 };
2829
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002830 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2831 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002832
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002833 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2834 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002835
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002836 ValidateTensorShapesMatch(inputTensorInfo0,
2837 outputTensorInfo0,
2838 descriptorName,
2839 "input_0",
2840 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002841
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002842 ValidateTensorShapesMatch(inputTensorInfo0,
2843 outputTensorInfo1,
2844 descriptorName,
2845 "input_0",
2846 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002847}
2848
Derek Lamberti901ea112019-12-10 22:07:09 +00002849void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002850{
2851 // This is internally generated so it should not need validation.
2852}
2853
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002854void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2855{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002856 const std::string& descriptorName{"PreluQueueDescriptor"};
2857
2858 ValidateNumInputs(workloadInfo, descriptorName, 2);
2859 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2860
2861 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2862 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2863 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002864
2865 std::vector<DataType> supportedTypes
2866 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002867 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002868 DataType::Float16,
2869 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002870 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002871 DataType::QAsymmU8,
2872 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002873 };
2874
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002875 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2876 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002877
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002878 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002879
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002880 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2881 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002882
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002883 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2884 alphaTensorInfo,
2885 outputTensorInfo,
2886 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002887 "input",
2888 "alpha");
2889}
2890
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002891void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2892{
2893 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2894
2895 ValidateNumInputs(workloadInfo, descriptorName, 1);
2896 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002898 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2899 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2900
2901 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2902 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002903
2904 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002905
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002906 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2907 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002908
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002909 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2910
2911 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002912 if (m_Parameters.m_BiasEnabled)
2913 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002914 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002915
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002916 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2917 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002918
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002919 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002920 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002921 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002922
2923 ValidatePerAxisQuantization(inputTensorInfo,
2924 outputTensorInfo,
2925 weightTensorInfo,
2926 optionalBiasTensorInfo,
2927 descriptorName);
2928
2929 std::vector<DataType> supportedTypes =
2930 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002931 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002932 DataType::Float32,
2933 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002934 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002935 DataType::QAsymmU8,
2936 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002937 };
2938
2939 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2940 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002941}
2942
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002943void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2944{
2945 const std::string descriptorName{"TransposeQueueDescriptor"};
2946
2947 ValidateNumInputs(workloadInfo, descriptorName, 1);
2948 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2949
2950 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2951
2952 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2953 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2954
2955 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2956 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2957
2958 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2959 {
2960 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2961 {
2962 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2963 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2964 "must match dst dimension " + to_string(i) +
2965 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2966 }
2967 }
2968
2969 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2970}
2971
James Conroy4f1f8992020-04-29 20:01:10 +01002972void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2973{
2974 const std::string descriptorName{"QLstmQueueDescriptor"};
2975
2976 // Validate number of inputs/outputs
2977 ValidateNumInputs(workloadInfo, descriptorName, 3);
2978 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2979
2980 // Input/output tensor info
2981 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2982 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2983 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2984
2985 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2986 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2987 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2988
2989 // Supported types for various tensors in QLSTM
2990 std::vector<DataType> inputOutputSupportedTypes =
2991 {
2992 DataType::QAsymmS8
2993 };
2994
2995 std::vector<DataType> cellStateSupportedTypes =
2996 {
2997 DataType::QSymmS16
2998 };
2999
3000 std::vector<DataType> weightsSupportedTypes =
3001 {
3002 DataType::QSymmS8
3003 };
3004
3005 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3006 {
3007 DataType::QSymmS16
3008 };
3009
3010 std::vector<DataType> biasSupportedTypes =
3011 {
3012 DataType::Signed32
3013 };
3014
3015 // Validate types of input/output tensors
3016 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3017 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3018 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3019
3020 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3021 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3022 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3023
3024 // Validate matching types of input/output tensors
3025 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3026 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3027 "outputStateIn", "outputStateOut");
3028 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3029
3030 // Infer number of batches, number of units, input size and output size from tensor dimensions
3031 const uint32_t numBatches = inputInfo.GetShape()[0];
3032 const uint32_t inputSize = inputInfo.GetShape()[1];
3033 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3034 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3035
3036 // Validate number of dimensions and number of elements for input/output tensors
3037 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3038 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3039 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3040
3041 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3042 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3043 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3044
3045 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3046 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3047 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3048 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3049
3050 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3051 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3052 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3053
3054 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3055 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3056 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3057
3058 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3059 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3060 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3061 " RecurrentToForgetWeights");
3062
3063 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3064 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3065 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3066
3067 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3068 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3069 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3070
3071 // Validate data types for MANDATORY weights tensors (all should match each other)
3072 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3073
3074 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3075 "inputToForgetWeights", "inputToCellWeights");
3076 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3077 "inputToForgetWeights", "inputToOutputWeights");
3078
3079 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3080 "inputToForgetWeights", "recurrentToForgeteights");
3081 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3082 "inputToForgetWeights", "recurrentToCellWeights");
3083 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3084 "inputToForgetWeights", "recurrentToOutputWeights");
3085
3086 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3087 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3088 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3089 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3090
3091 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3092 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3093 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3094
3095 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3096 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3097 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3098
3099 // Validate data types for MANDATORY bias tensors
3100 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3101
3102 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3103 "forgetGateBias", "cellBias");
3104 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3105 "forgetGateBias", "outputGateBias");
3106
3107 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3108 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3109 !m_Parameters.m_CifgEnabled) ||
3110 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3111 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3112
3113 if (!allCifgParamsPresentOrNot)
3114 {
3115 throw InvalidArgumentException(descriptorName +
3116 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3117 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3118 "set appropriately.");
3119 }
3120
3121 if (!m_Parameters.m_CifgEnabled)
3122 {
3123 // Validate number of dimensions and number of elements
3124 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3125 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3126
3127 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3128 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3129 " RecurrentToInputWeights");
3130
3131 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3132 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3133
3134 // Validate data types
3135 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3136 "inputToForgetWeights", "inputToInputWeights");
3137 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3138 "inputToForgetWeights", "recurrentToInputWeights");
3139 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3140 "forgetGateBias", "inputGateBias");
3141 }
3142
3143 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3144 bool allPeepholeWeightsPresentOrNot =
3145 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3146 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3147 || (!m_CellToInputWeights && !m_CellToForgetWeights
3148 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3149
3150 if (!allPeepholeWeightsPresentOrNot)
3151 {
3152 throw InvalidArgumentException(descriptorName +
3153 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3154 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3155 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3156 "appropriately.");
3157 }
3158
3159 if (m_Parameters.m_PeepholeEnabled)
3160 {
3161 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3162 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3163 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3164
3165 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3166 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3167 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3168 "cellToForgetWeight", "cellToOutputWeights");
3169
3170 if (!m_Parameters.m_CifgEnabled)
3171 {
3172 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3173 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3174 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3175 "cellToForgetWeights", "cellToInputWeights");
3176 }
3177 }
3178
3179 // Validate OPTIONAL params: Layer Norm Weights
3180 bool allLayerNormWeightsPresentOrNot =
3181 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3182 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3183 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3184 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3185
3186 if (!allLayerNormWeightsPresentOrNot)
3187 {
3188 throw InvalidArgumentException(descriptorName +
3189 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3190 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3191 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3192 "only be present when Layer Norm is enabled and CIFG is disabled. "
3193 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3194 }
3195
3196 if (m_Parameters.m_LayerNormEnabled)
3197 {
3198 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3199 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3200 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3201
3202 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3203 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3204 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3205 "forgetLayerNormWeights", "cellLayerNormWeights");
3206
3207 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3208 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3209 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3210 "forgetLayerNormWeights", "outputLayerNormWeights");
3211
3212 if (!m_Parameters.m_CifgEnabled)
3213 {
3214 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3215 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3216 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3217 "forgetLayerNormWeights", "inputLayerNormWeights");
3218 }
3219 }
3220
3221 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3222 bool correctProjectionTensorsPresent =
3223 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3224 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3225 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3226
3227 if (!correctProjectionTensorsPresent)
3228 {
3229 throw InvalidArgumentException(descriptorName +
3230 ": If projection is enabled, ProjectionWeights should be present and "
3231 "ProjectionBias is optional. If projection is disabled, neither "
3232 "ProjectionWeights nor ProjectionBias should be present.");
3233 }
3234
3235 if (m_Parameters.m_ProjectionEnabled)
3236 {
3237 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3238 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3239 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3240
3241 if (m_ProjectionBias)
3242 {
3243 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003244 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003245 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3246 }
3247
3248 }
3249 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3250 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3251 throw InvalidArgumentException(descriptorName +
3252 ": If projection is disabled, output quantization info (scale, offset) "
3253 "should match HiddenStateScale and HiddenStateZeroPoint.");
3254 }
3255
3256}
3257
James Conroy9c3cae82019-08-01 16:01:48 +01003258void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3259{
3260 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3261
3262 // Validate number of inputs/outputs
3263 ValidateNumInputs(workloadInfo, descriptorName, 3);
3264 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3265
3266 // Input/output tensor infos
3267 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3268 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3269 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3270
3271 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3272 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3273
3274 std::vector<DataType> inputOutputSupportedTypes =
3275 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003276 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003277 };
3278
3279 std::vector<DataType> cellStateSupportedTypes =
3280 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003281 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003282 };
3283
3284 std::vector<DataType> weightsSupportedTypes =
3285 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003286 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003287 };
3288
3289 std::vector<DataType> biasSupportedTypes =
3290 {
3291 DataType::Signed32
3292 };
3293
3294 // Validate types of input/output tensors
3295 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3296 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3297 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3298
3299 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3300 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3301
3302 // Validate matching types of input/output tensors
3303 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3304 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3305 "outputStateIn", "outputStateOut");
3306 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3307
3308 // Validate matching quantization info for input/output tensors
3309 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3310 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3311 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003312
James Conroy9c3cae82019-08-01 16:01:48 +01003313 // Infer number of batches, input size and output size from tensor dimensions
3314 const uint32_t numBatches = inputInfo.GetShape()[0];
3315 const uint32_t inputSize = inputInfo.GetShape()[1];
3316 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3317
3318 // Validate number of dimensions and number of elements for input/output tensors
3319 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3320 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3321 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3322 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3323 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3324
3325 // Validate number of dimensions and number of elements for weights tensors
3326 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3327 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3328 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3329
3330 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3331 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3332 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3333
3334 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3335 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3336 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3337
3338 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3339 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3340 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3341
3342 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3343 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3344 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3345
3346 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3347 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3348 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3349 " RecurrentToForgetWeights");
3350
3351 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3352 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3353 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3354
3355 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3356 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3357 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3358
3359 // Validate data types for weights tensors (all should match each other)
3360 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3361
3362 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3363 "inputToInputWeights", "inputToForgetWeights");
3364 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3365 "inputToInputWeights", "inputToCellWeights");
3366 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3367 "inputToInputWeights", "inputToOutputWeights");
3368
3369 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3370 "inputToInputWeights", "recurrentToInputWeights");
3371 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3372 "inputToInputWeights", "recurrentToForgeteights");
3373 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3374 "inputToInputWeights", "recurrentToCellWeights");
3375 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3376 "inputToInputWeights", "recurrentToOutputWeights");
3377
3378 // Validate matching quantization info for weight tensors (all should match each other)
3379 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3380 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3381 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3382 descriptorName, "inputToInputWeights", "inputToCellWeights");
3383 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3384 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3385
3386 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3387 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3388 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3389 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3390 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3391 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3392 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3393 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3394
3395 // Validate number of dimensions and number of elements in bias tensors
3396 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3397 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3398 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3399
3400 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3401 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3402 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3403
3404 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3405 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3406 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3407
3408 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3409 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3410 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3411
3412 // Validate data types for bias tensors (all should match each other)
3413 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3414
3415 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3416 "inputGateBias", "forgetGateBias");
3417 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3418 "inputGateBias", "cellBias");
3419 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3420 "inputGateBias", "outputGateBias");
3421
3422 // Validate bias tensor quantization info
3423 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3424 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3425 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3426 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3427}
3428
Kevin May868eb142019-09-04 17:29:31 +01003429void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3430{
3431 const std::string descriptorName{"AbsQueueDescriptor"};
3432
3433 ValidateNumInputs(workloadInfo, descriptorName, 1);
3434 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3435
3436 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3437 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3438
3439 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3440
3441 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003442 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003443 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003444 DataType::Float16,
3445 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003446 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003447 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003448 DataType::QSymmS16,
3449 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003450 };
Kevin May868eb142019-09-04 17:29:31 +01003451
3452 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3453 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3454}
3455
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003456void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3457{
3458 const std::string descriptorName{"SliceQueueDescriptor"};
3459
3460 ValidateNumInputs(workloadInfo, descriptorName, 1);
3461 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3462
3463 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3464 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3465
3466 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3467
3468 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3469 if (rank > 4)
3470 {
3471 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3472 }
3473
3474 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3475
3476 // Check if m_Begin and m_Size have the expected length
3477 if (m_Parameters.m_Begin.size() != rank)
3478 {
3479 throw InvalidArgumentException(descriptorName +
3480 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3481 }
3482 if (m_Parameters.m_Size.size() != rank)
3483 {
3484 throw InvalidArgumentException(descriptorName +
3485 ": Length of size descriptor must equal rank " + std::to_string(rank));
3486 }
3487
3488 // Check if the shape of the output tensor matches m_Size
3489 const TensorShape& outputShape = outputTensorInfo.GetShape();
3490 for (unsigned int i = 0u; i < rank; ++i)
3491 {
3492 if (m_Parameters.m_Size[i] != outputShape[i])
3493 {
3494 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3495 }
3496 }
3497
3498 // Check if the sum of begin offset and size in a given dimension
3499 // does not exceed the size of corresponding input
3500 const TensorShape& inputShape = inputTensorInfo.GetShape();
3501 for(unsigned int i = 0u; i < rank; ++i)
3502 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003503 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003504 {
3505 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3506 std::to_string(i) + " exceeds input size.");
3507 }
3508 }
3509}
3510
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003511void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3512{
3513 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3514
3515 ValidateNumInputs(workloadInfo, descriptorName, 1);
3516 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3517
3518 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3519 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3520
3521 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3522 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3523
3524 std::vector<DataType> supportedTypes =
3525 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003526 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003527 DataType::Float32,
3528 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003529 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003530 DataType::QAsymmU8,
3531 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003532 };
3533
3534 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3535 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3536
3537 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3538
3539 if (m_Parameters.m_BlockSize == 0)
3540 {
3541 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3542 }
3543
3544 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3545 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3546 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3547 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3548
3549 const TensorShape& outputShape = outputInfo.GetShape();
3550 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3551 {
3552 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3553 "must be divisible by block size.");
3554 }
3555
3556 const TensorShape& inputShape = inputInfo.GetShape();
3557 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3558 {
3559 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3560 "must be divisible by the square of block size." );
3561 }
3562}
3563
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003564void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3565{
3566 const std::string descriptorName{"ComparisonQueueDescriptor"};
3567
3568 ValidateNumInputs(workloadInfo, descriptorName, 2);
3569 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3570
3571 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3572 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3573 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3574
3575 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3576 inputTensorInfo1,
3577 outputTensorInfo,
3578 descriptorName,
3579 "input_0",
3580 "input_1");
3581
3582 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3583 {
3584 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3585 }
3586}
3587
josh minor4a3c6102020-01-06 16:40:46 -06003588void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3589{
3590 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3591
3592 ValidateNumInputs(workloadInfo, descriptorName, 1);
3593 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3594
3595 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3596 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3597
3598 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3599
3600 std::vector<DataType> supportedTypes =
3601 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003602 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003603 DataType::Float16,
3604 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003605 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003606 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003607 DataType::QSymmS16,
3608 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003609 };
3610
James Conroyaba90cd2020-11-06 16:28:18 +00003611 std::vector<DataType> logicalSupportedTypes =
3612 {
3613 DataType::Boolean
3614 };
3615
3616 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3617 {
3618 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3619 }
3620 else
3621 {
3622 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3623 }
3624
3625
josh minor4a3c6102020-01-06 16:40:46 -06003626 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3627}
3628
Finn Williams2605b232020-06-10 15:53:46 +01003629void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3630{
3631 const std::string descriptorName{"RankQueueDescriptor"};
3632
3633 ValidateNumInputs(workloadInfo, descriptorName, 1);
3634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3635
3636 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3638
3639 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3640 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3641
3642 std::vector<DataType> supportedTypes =
3643 {
3644 DataType::BFloat16,
3645 DataType::Float16,
3646 DataType::Float32,
3647 DataType::QAsymmS8,
3648 DataType::QAsymmU8,
3649 DataType::QSymmS8,
3650 DataType::QSymmS16,
3651 DataType::Signed32
3652 };
3653
3654 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3655 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3656}
3657
James Conroyaba90cd2020-11-06 16:28:18 +00003658void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3659{
3660 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3661
3662 ValidateNumInputs(workloadInfo, descriptorName, 2);
3663 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3664
3665 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3666 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3667 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3668
3669 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3670 inputTensorInfo1,
3671 outputTensorInfo,
3672 descriptorName,
3673 "input_0",
3674 "input_1");
3675
3676 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3677 {
3678 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3679 }
3680
3681 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3682 {
3683 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3684 }
3685
3686 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3687 {
3688 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3689 }
3690}
3691
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003692void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3693{
3694 const std::string descriptorName{"ReduceQueueDescriptor"};
3695
3696 ValidateNumInputs(workloadInfo, descriptorName, 1);
3697 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3698
3699 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3700 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3701
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003702 std::vector<DataType> supportedTypes =
3703 {
3704 DataType::BFloat16,
3705 DataType::Float16,
3706 DataType::Float32,
3707 DataType::QAsymmS8,
3708 DataType::QAsymmU8,
3709 DataType::QSymmS16,
3710 DataType::Signed32
3711 };
3712
3713 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3714 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3715}
3716
mathad01df9a3222021-04-28 11:42:57 +01003717} // namespace armnn