blob: 2c5303c019e2b24f6c64610dbb80c4635296ea6c [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000011
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000014#include <string>
15#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000016
James Ward47fce872020-09-10 11:57:28 +010017#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000031 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000034 case DataType::QAsymmS8:
35 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000037 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000038 case DataType::QSymmS8:
39 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// Local replacement for std::to_string, which the Android NDK does not provide.
// Converts any value that can be streamed into a std::ostream to its textual form.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    return formatted.str();
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100108void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000109 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& tensorName)
112{
113 if (tensor.GetNumDimensions() != numDimensions)
114 {
115 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
116 to_string(tensor.GetNumDimensions()) + " dimensions for " +
117 tensorName + " tensor.");
118 }
119}
120
121//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100122void ValidateTensorNumElements(const TensorInfo& tensor,
123 std::string const& descName,
124 unsigned int numElements,
125 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126{
127 if (tensor.GetNumElements() != numElements)
128 {
129 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100130 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100131 tensorName + " tensor.");
132 }
133}
134
135//---------------------------------------------------------------
136void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 unsigned int numDimension,
138 unsigned int numElements,
139 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100141 const std::string functionName{"ValidateTensorNumDimNumElem"};
142 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
143 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100144}
145
146//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000147void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
148 const std::string& descName, std::string const& tensorName)
149{
150 if (tensor.GetDataType() != dataType)
151 {
152 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
153 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
154 }
155}
156
Derek Lambertid466a542020-01-22 15:37:29 +0000157void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
158{
159 ARMNN_NO_DEPRECATE_WARN_BEGIN
160 if (tensor.GetDataType() != DataType::QSymmS8 &&
161 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
167 ARMNN_NO_DEPRECATE_WARN_END
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100171void ValidateTensorQuantizationSpace(const TensorInfo& first,
172 const TensorInfo& second,
173 const std::string& descName,
174 std::string const& firstName,
175 std::string const& secondName)
176{
177 if (!first.IsQuantized() ||
178 !second.IsQuantized())
179 {
180 // Not a quantized type, ignore the validation
181 return;
182 }
183
184 DataType firstDataType = first.GetDataType();
185 DataType secondDataType = second.GetDataType();
186
187 if (firstDataType != secondDataType)
188 {
189 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
190 " must be of the same quantized type, " +
191 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
192 secondName + " is " + GetDataTypeName(secondDataType));
193 }
194
195 if (!first.IsTypeSpaceMatch(second))
196 {
197 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
198 " must have the same quantization space, " +
199 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
200 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
201 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
202 " and scale " + to_string(second.GetQuantizationScale()));
203 }
204}
205
206//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100207void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
208 const TensorInfo& inputTensorInfo,
209 const TensorInfo& weightsTensorInfo,
210 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000211{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000212 // Helper lambda function to validate a single bias quantization scale value
213 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
214 {
ricbur013f4d7102019-10-31 16:22:18 +0000215 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000216 if (std::abs(biasScale - expectedScale) > tolerance)
217 {
218 // Print the float values with extra precision to see very small differences
219 std::stringstream msg;
220 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
221 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
222 biasScale;
223 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
224 }
225 };
226
telsoa014fcda012018-03-09 14:13:49 +0000227 if (biasTensor.GetQuantizationOffset() != 0)
228 {
229 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
230 to_string(biasTensor.GetQuantizationOffset()));
231 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232
James Conroy8502ade2020-11-12 19:26:29 +0000233 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000234 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000235 // Validate per-axis quantization scales
236 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
237 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
238
239 if (weightScales.size() != biasScales.size())
240 {
241 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000242 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
243 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
244 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000245 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
246 }
247
248 for (size_t i = 0ul; i < biasScales.size(); ++i)
249 {
250 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
251 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
252 }
253 }
254 else
255 {
256 // Validate per-tensor quantization scale
257 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
258 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000259 }
260}
261
262//---------------------------------------------------------------
263void ValidateTensors(const std::vector<ITensorHandle*>& vec,
264 unsigned int numExpected,
265 const std::string& descName,
266 const std::string& varName)
267{
268 if (vec.empty() && numExpected > 0)
269 {
270 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
271 }
272
273 for (unsigned int i = 0; i < numExpected; ++i)
274 {
275 if (!vec[i])
276 {
277 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
278 }
279 }
280}
281
282//---------------------------------------------------------------
283void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
284 const TensorInfo& second,
285 const TensorInfo& output,
286 std::string const& descName,
287 std::string const& firstName,
288 std::string const& secondName)
289{
290 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
291 // broadcasted.
292 if (first.GetNumDimensions() != second.GetNumDimensions())
293 {
294 throw InvalidArgumentException(descName + ": Tensors "
295 + firstName + " & " + secondName
296 + " must have the same number of dimensions in order to be broadcasted");
297 }
298 uint32_t numDims = first.GetNumDimensions();
299 std::vector<uint32_t> outputDims(numDims, 0u);
300 for (uint32_t i = 0; i < numDims; i++)
301 {
302 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
303 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
304 if (dimsNotEqual && dimsNotOne)
305 {
306 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
307 }
308 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
309 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100310 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000311 if (broadcastShape != output.GetShape())
312 {
313 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
314 + firstName + " & " + secondName
315 + " does not match the output shape");
316 }
317}
318
319//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100320void ValidateDataTypes(const TensorInfo& info,
321 const std::vector<armnn::DataType>& supportedTypes,
322 std::string const& descName)
323{
324 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
325 if (iterator == supportedTypes.end())
326 {
327 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
328 }
329}
330
James Conroy4d1ff582019-06-10 17:06:39 +0100331//---------------------------------------------------------------
332void ValidateTensorDataTypesMatch(const TensorInfo& first,
333 const TensorInfo& second,
334 std::string const& descName,
335 std::string const& firstName,
336 std::string const& secondName)
337{
338 if (first.GetDataType() != second.GetDataType())
339 {
340 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
341 " must have identical data types.");
342 }
343}
344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100345//---------------------------------------------------------------
346void ValidateTensorNumElementsMatch(const TensorInfo& first,
347 const TensorInfo& second,
348 std::string const& descName,
349 std::string const& firstName,
350 std::string const& secondName)
351{
352 if (first.GetNumElements() != second.GetNumElements())
353 {
354 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
355 " must have the same number of elements.");
356 }
357}
358
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000359void ValidateWeightDataType(const TensorInfo& inputInfo,
360 const TensorInfo& weightInfo,
361 const std::string& descName)
362{
363 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000364 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365 {
Derek Lambertid466a542020-01-22 15:37:29 +0000366 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000367 const std::vector<DataType> validTypes =
368 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000369 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100370 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000371 DataType::QSymmS8,
372 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000373 };
Derek Lambertid466a542020-01-22 15:37:29 +0000374 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000375
376 ValidateDataTypes(weightInfo, validTypes, descName);
377 }
378 else
379 {
380 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
381 }
382}
383
384void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
385 const std::string& descName,
386 const std::string& tensorName)
387{
388 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
389 if (!quantizationDim.has_value())
390 {
James Ward47fce872020-09-10 11:57:28 +0100391 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
392 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000393 }
394
395 if (quantizationDim.value() != 0)
396 {
James Ward47fce872020-09-10 11:57:28 +0100397 throw InvalidArgumentException(fmt::format(
398 "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, "
399 "but got: {2}", descName, tensorName, quantizationDim.value()));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000400 }
401}
402
403void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
404 const std::string& descName,
405 const std::string& tensorName)
406{
407 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
408 if (quantizationOffset != 0)
409 {
James Ward47fce872020-09-10 11:57:28 +0100410 throw InvalidArgumentException(fmt::format(
411 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
412 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000413 }
414}
415
416void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
417 const TensorInfo& outputInfo,
418 const TensorInfo& weightInfo,
419 const Optional<TensorInfo>& optionalBiasInfo,
420 const std::string& descName)
421{
422 if (weightInfo.HasPerAxisQuantization())
423 {
424 const DataType inputDataType = inputInfo.GetDataType();
425 const DataType outputDataType = outputInfo.GetDataType();
426
Keith Davis0c2eeac2020-02-11 16:51:50 +0000427 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000428
429 if (!canHavePerAxisQuantization)
430 {
James Ward47fce872020-09-10 11:57:28 +0100431 throw InvalidArgumentException(fmt::format(
432 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
433 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000434 }
435
Derek Lambertid466a542020-01-22 15:37:29 +0000436
437 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000438 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
439 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
440
441 if (optionalBiasInfo.has_value())
442 {
443 const TensorInfo& biasInfo = optionalBiasInfo.value();
444 if (!biasInfo.HasPerAxisQuantization())
445 {
James Ward47fce872020-09-10 11:57:28 +0100446 throw InvalidArgumentException(fmt::format(
447 "{}: Per-axis quantization parameters not set on bias tensor, "
448 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000449 }
450
451 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
452 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
453 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
454 }
455 }
456}
457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100458} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000459
460void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
461 unsigned int numExpectedIn, unsigned int numExpectedOut) const
462{
463 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
464 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
465}
466
467//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100468void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
469{
470 const std::string descriptorName{"MapQueueDescriptor"};
471
472 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100473 ValidateNumOutputs(workloadInfo, descriptorName, 0);
474
475 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
476 {
477 if (!m_Inputs[i])
478 {
479 throw InvalidArgumentException(
480 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
481 }
482 }
483}
484
485//---------------------------------------------------------------
486void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
487{
488 const std::string descriptorName{"UnmapQueueDescriptor"};
489
490 ValidateNumInputs(workloadInfo, descriptorName, 1);
491 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100492
493 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
494 {
495 if (!m_Inputs[i])
496 {
497 throw InvalidArgumentException(
498 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
499 }
500 }
501}
502
503//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000504void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
505{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100506 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000507
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100508 ValidateNumInputs(workloadInfo, descriptorName, 1);
509 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000510
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100511 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
512 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
513
514 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
515 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000516
517 if (m_Inputs.size() != m_Outputs.size())
518 {
James Ward47fce872020-09-10 11:57:28 +0100519 throw InvalidArgumentException(fmt::format(
520 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
521 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000522 }
523
524 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
525 {
526 if (!m_Inputs[i])
527 {
James Ward47fce872020-09-10 11:57:28 +0100528 throw InvalidArgumentException(fmt::format(
529 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000530 }
531
532 if (!m_Outputs[i])
533 {
James Ward47fce872020-09-10 11:57:28 +0100534 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000535 }
536 }
537}
538
Derek Lambertif674aa02019-08-01 15:56:25 +0100539//---------------------------------------------------------------
540void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
541{
542 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
543 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
544
545 if (workloadInfo.m_InputTensorInfos.size() != 1)
546 {
James Ward47fce872020-09-10 11:57:28 +0100547 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
548 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100549
550 }
551
552 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
553 {
James Ward47fce872020-09-10 11:57:28 +0100554 throw InvalidArgumentException(fmt::format(
555 "Number of input infos ({0}) does not match the number of output infos ({1})",
556 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100557 }
558
559 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
560 {
561 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
562 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
563 {
James Ward47fce872020-09-10 11:57:28 +0100564 throw InvalidArgumentException(fmt::format(
565 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100566 }
567 }
568
569 if (m_Inputs.size() != 1)
570 {
James Ward47fce872020-09-10 11:57:28 +0100571 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573
574 if (m_Inputs.size() != m_Outputs.size())
575 {
James Ward47fce872020-09-10 11:57:28 +0100576 throw InvalidArgumentException(fmt::format(
577 "Number of inputs ({0}) does not match the number of outputs ({1})",
578 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100579 }
580
581 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
582 {
583 if (!m_Inputs[i])
584 {
James Ward47fce872020-09-10 11:57:28 +0100585 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100586 }
587
588 if (!m_Outputs[i])
589 {
James Ward47fce872020-09-10 11:57:28 +0100590 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100591 }
592 }
593}
594
595//---------------------------------------------------------------
596void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
597{
598 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
599 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
600
Derek Lambertif674aa02019-08-01 15:56:25 +0100601 if (m_Inputs.size() != 1)
602 {
James Ward47fce872020-09-10 11:57:28 +0100603 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100604 }
605
606 if (m_Outputs.size() != 0)
607 {
James Ward47fce872020-09-10 11:57:28 +0100608 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100609 }
610
611 if (!m_Inputs[0])
612 {
James Ward47fce872020-09-10 11:57:28 +0100613 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100614 }
615}
616
617//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000618void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
619{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100620 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100622 ValidateNumInputs(workloadInfo, descriptorName, 1);
623 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100625 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
626 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100627
628 std::vector<DataType> supportedTypes =
629 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000630 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100631 DataType::Float16,
632 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000633 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000634 DataType::QAsymmU8,
635 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100636 };
637
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100638 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
639 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
640 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000641}
642
Nikhil Rajee391d52019-09-05 17:50:44 +0100643void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
644{
645 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
646
647 ValidateNumInputs(workloadInfo, descriptorName, 1);
648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
649
650 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
651 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
652
Inki Daed4619e22020-09-10 15:33:54 +0900653 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
654 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100655 {
Inki Daed4619e22020-09-10 15:33:54 +0900656 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100657 }
658
James Conroyd47a0642019-09-17 14:22:06 +0100659 std::vector<DataType> supportedInputTypes =
660 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000661 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100662 DataType::Float16,
663 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100664 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000665 DataType::QAsymmU8,
666 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900667 DataType::Signed32,
668 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100669 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100670
James Conroyd47a0642019-09-17 14:22:06 +0100671 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100672
673 auto inputShape = inputTensorInfo.GetShape();
674 auto outputShape = outputTensorInfo.GetShape();
675
676 auto inputNumDimensions = inputShape.GetNumDimensions();
677 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
678
679 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
680
681 // 1D input shape results in scalar output shape
682 if (inputShape.GetNumDimensions() == 1)
683 {
684 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
685 {
686 throw InvalidArgumentException(descriptorName + outputShapeError);
687 }
688 }
689 else
690 {
691 for (unsigned int i = 0; i < unsignedAxis; ++i)
692 {
693 if (outputShape[i] != inputShape[i])
694 {
695 throw InvalidArgumentException(descriptorName + outputShapeError);
696 }
697 }
698
699 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
700 {
701 if (outputShape[i - 1] != inputShape[i])
702 {
703 throw InvalidArgumentException(descriptorName + outputShapeError);
704 }
705 }
706 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100707}
708
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100709void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
710{
711 const std::string descriptorName{"SoftmaxQueueDescriptor"};
712
713 ValidateNumInputs(workloadInfo, descriptorName, 1);
714 ValidateNumOutputs(workloadInfo, descriptorName, 1);
715
716 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
717 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
718
719 std::vector<DataType> supportedTypes =
720 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000721 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100722 DataType::Float16,
723 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000724 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000725 DataType::QAsymmU8,
726 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100727 };
728
729 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
730 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
731 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
732}
733
telsoa014fcda012018-03-09 14:13:49 +0000734void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
735{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100736 const std::string descriptorName{"SplitterQueueDescriptor"};
737
738 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000739
Ruomei Yan25339c32019-05-28 16:48:20 +0100740 // Check the supported data types
741 std::vector<DataType> supportedTypes =
742 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000743 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100744 DataType::Float32,
745 DataType::Float16,
746 DataType::Boolean,
747 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100748 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000749 DataType::QAsymmU8,
750 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100751 };
752
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100753 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
754 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100755 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100756 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
757 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
758
759 const std::string outputName = "output_" + std::to_string(i);
760 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100761 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100762
telsoa014fcda012018-03-09 14:13:49 +0000763 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
764 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100765 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000766 }
767
768 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
769 {
770 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100771 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000772 "has to match number of workloadInfo.m_OutputTensorInfos. "
773 "Number of windows: " +
774 to_string(m_ViewOrigins.size()) +
775 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
776 }
777
telsoa01c577f2c2018-08-31 09:22:23 +0100778 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000779 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
780 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
781 {
telsoa01c577f2c2018-08-31 09:22:23 +0100782 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000783 ViewOrigin const& e = m_ViewOrigins[w];
784 if (e.m_Origin.size() != inputDims)
785 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100786 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000787 "have the same dimensionality as the input tensor. "
788 "Window origin (index: " +
789 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
790 " dimensions, the input "
791 "tensor has " +
792 to_string(inputDims) + " dimensions.");
793 }
794 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
795 {
796 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
797 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100799 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000800 "be smaller or equal than the size of the input in that coord.");
801 }
802 }
803 }
804}
805
Jim Flynne242f2d2019-05-22 14:24:13 +0100806void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000807{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100808 const std::string descriptorName{"ConcatQueueDescriptor"};
809
810 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000811
812 if (m_Inputs.size() <= 0)
813 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100814 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000815 }
816 if (m_Outputs.size() <= 0)
817 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100818 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000819 }
820
821 if (workloadInfo.m_InputTensorInfos.size() <= 0)
822 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100823 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000824 }
825 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
826 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000828 }
829
Nikhil Raj8599a412018-11-19 14:51:07 +0000830 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
831 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100832 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000833 }
834
835 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
836 {
837 return;
838 }
839
telsoa014fcda012018-03-09 14:13:49 +0000840 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
841 {
842 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000844 "has to match number of workloadInfo.m_InputTensorInfos. "
845 "Number of windows: " +
846 to_string(m_ViewOrigins.size()) +
847 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
848 }
849
telsoa01c577f2c2018-08-31 09:22:23 +0100850 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000851 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
852 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
853 {
telsoa01c577f2c2018-08-31 09:22:23 +0100854 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000855 ViewOrigin const& e = m_ViewOrigins[w];
856 if (e.m_Origin.size() != outputDims)
857 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100858 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000859 "have the same dimensionality as the output tensor. "
860 "Window origin (index: " +
861 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
862 " dimensions, the output "
863 "tensor has " +
864 to_string(outputDims) + " dimensions.");
865 }
telsoa01c577f2c2018-08-31 09:22:23 +0100866 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000867 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
868 {
869 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
870 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
871 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000873 "be smaller or equal than the size of the output in that coord.");
874 }
875 }
876 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100877
878 // Check the supported data types
879 std::vector<DataType> supportedTypes =
880 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000881 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100882 DataType::Float32,
883 DataType::Float16,
884 DataType::Boolean,
885 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100886 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000887 DataType::QAsymmU8,
888 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100889 };
890
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100891 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
892 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100893 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100894 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
895 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
896
897 const std::string inputName = "input_" + std::to_string(i);
898 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100899 }
telsoa014fcda012018-03-09 14:13:49 +0000900}
901
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100902void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
903{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100904 const std::string descriptorName{"StackQueueDescriptor"};
905
906 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100907
908 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
909 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100910 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100911 }
912
913 // All inputs must have the same shape, which is defined in parameters
914 const TensorShape& inputShape = m_Parameters.m_InputShape;
915 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
916 {
917 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
918 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100919 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100920 }
921 }
922
Matthew Jacksondba634f2019-08-15 15:14:18 +0100923 if (inputShape.GetNumDimensions() > 4)
924 {
925 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
926 }
927
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100928 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
929 // since the output tensor has an additional dimension.
930 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
931 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100932 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100933 "than the number of input dimensions.");
934 }
935
936 // Output shape must be as inferred from the input shape
937 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
938 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
939 {
940 if (outputShape[i] != inputShape[i])
941 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100942 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100943 "match shape inferred from input tensor.");
944 }
945 }
946
947 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
948 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100949 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100950 "match shape inferred from input tensor.");
951 }
952
953 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
954 {
955 if (outputShape[i] != inputShape[i-1])
956 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100957 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100958 "match shape inferred from input tensor.");
959 }
960 }
961
Matthew Jacksondba634f2019-08-15 15:14:18 +0100962 if (outputShape.GetNumDimensions() > 5)
963 {
964 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
965 }
966
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100967 // Check the supported data types
968 std::vector<DataType> supportedTypes =
969 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000970 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100971 DataType::Float32,
972 DataType::Float16,
973 DataType::Boolean,
974 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100975 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000976 DataType::QAsymmU8,
977 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100978 };
979
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100980 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100981
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100982 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100983 {
984 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
985 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100986 descriptorName,
987 "input_0",
988 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100989 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100990
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100991 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
992 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100993 descriptorName,
994 "input_0",
995 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100996}
997
Ryan OSheaec6c6802020-06-05 17:17:06 +0100998void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
999{
1000 const std::string descriptorName{"FillQueueDescriptor"};
1001
1002 ValidateNumInputs(workloadInfo, descriptorName, 1);
1003 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1004
1005 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1006 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1007
1008 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1009
1010 std::vector<DataType> supportedTypes =
1011 {
1012 DataType::BFloat16,
1013 DataType::Float32,
1014 DataType::Float16,
1015 DataType::Signed32
1016 };
1017
1018 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1019}
1020
telsoa014fcda012018-03-09 14:13:49 +00001021void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1022{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001023 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001024
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001025 uint32_t numInputs = 1;
1026 if (!m_Parameters.m_ConstantWeights)
1027 {
1028 numInputs = 2;
1029 if (m_Parameters.m_BiasEnabled)
1030 {
1031 numInputs = 3;
1032 }
1033 }
1034 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001035 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1036
1037 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1038 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1039
1040 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1041
1042 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001043 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001045 }
1046
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001047 TensorInfo weightTensorInfo;
1048 if (m_Parameters.m_ConstantWeights)
1049 {
1050 ValidatePointer(m_Weight, descriptorName, "weight");
1051 weightTensorInfo = m_Weight->GetTensorInfo();
1052 }
1053 else
1054 {
1055 weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1056 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001057 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001058
1059 if (m_Parameters.m_BiasEnabled)
1060 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001061 TensorInfo biasTensorInfo;
1062 if (m_Parameters.m_ConstantWeights)
1063 {
1064 ValidatePointer(m_Bias, descriptorName, "bias");
1065 biasTensorInfo = m_Bias->GetTensorInfo();
1066 }
1067 else
1068 {
1069 biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
1070 }
telsoa01c577f2c2018-08-31 09:22:23 +01001071 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001072 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001073 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1074 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001075 }
1076
Francis Murtagh46c09d02019-05-28 08:15:28 +01001077 // Check the supported data types
1078 std::vector<DataType> supportedTypes =
1079 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001080 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001081 DataType::Float32,
1082 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001083 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001084 DataType::QAsymmU8,
1085 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001086 };
1087
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001088 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001089
1090 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1091 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1092 {
1093 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1094 {
1095 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1096 "for BFloat16 input.");
1097 }
1098 }
1099 else
1100 {
1101 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1102 }
telsoa014fcda012018-03-09 14:13:49 +00001103}
1104
telsoa014fcda012018-03-09 14:13:49 +00001105void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1106{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001107 const std::string descriptorName{"NormalizationQueueDescriptor"};
1108
1109 ValidateNumInputs(workloadInfo, descriptorName, 1);
1110 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1111
1112 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1113 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001114
1115 // Check the supported data types
1116 std::vector<DataType> supportedTypes =
1117 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001118 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001119 DataType::Float16,
1120 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001121 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001122 DataType::QAsymmU8,
1123 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001124 };
1125
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001127
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001129
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001130 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001131}
1132
1133void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1134{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001135 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001136
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001137 ValidateNumInputs(workloadInfo, descriptorName, 2);
1138 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1139
1140 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1141 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1142 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1143
1144 std::vector<DataType> supportedTypes =
1145 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001146 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001147 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001148 DataType::Float16,
1149 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001150 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001151 DataType::QSymmS16,
1152 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001153 };
1154
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1156 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1157 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001158
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001159 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1160 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001161
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001162 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1163 inputTensorInfo1,
1164 outputTensorInfo,
1165 descriptorName,
1166 "input_0",
1167 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001168}
1169
telsoa014fcda012018-03-09 14:13:49 +00001170void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1171{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001173
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001174 ValidateNumInputs(workloadInfo, descriptorName, 2);
1175 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1176
1177 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1178 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1179 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1180
1181 std::vector<DataType> supportedTypes =
1182 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001183 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001184 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001185 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001186 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001187 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001188 DataType::QSymmS16,
1189 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001190 };
1191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1193 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1194 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001195
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001196 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1197 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001198
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001199 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1200 inputTensorInfo1,
1201 outputTensorInfo,
1202 descriptorName,
1203 "input_0",
1204 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001205}
1206
1207void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1208{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001209 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001210
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001211 ValidateNumInputs(workloadInfo, descriptorName, 1);
1212 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1213
1214 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1215 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001216
1217 std::vector<DataType> supportedTypes =
1218 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001219 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001220 DataType::Float16,
1221 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001222 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001223 DataType::QAsymmU8,
1224 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001225 };
1226
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001227 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1228 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001229
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001230 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001231 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001232
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001233 ValidatePointer(m_Mean, descriptorName, "mean");
1234 ValidatePointer(m_Variance, descriptorName, "variance");
1235 ValidatePointer(m_Beta, descriptorName, "beta");
1236 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001237
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001238 const TensorInfo& mean = m_Mean->GetTensorInfo();
1239 const TensorInfo& variance = m_Variance->GetTensorInfo();
1240 const TensorInfo& beta = m_Beta->GetTensorInfo();
1241 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001242
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001243 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1244 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1245 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1246 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001247
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001248 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1249 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1250 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001251}
1252
1253void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1254{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001256
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 ValidateNumInputs(workloadInfo, descriptorName, 1);
1258 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1261 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001262
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001263 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1264 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001265
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001266 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1269 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001270
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001271 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001272
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001273 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001274 if (m_Parameters.m_BiasEnabled)
1275 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001276 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001277
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001278 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1279 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001280
1281 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1282 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001283 }
1284
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001285 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1286 {
1287 throw InvalidArgumentException(
1288 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1289 "cannot be either negative or 0.",
1290 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1291 }
1292
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001293 ValidatePerAxisQuantization(inputTensorInfo,
1294 outputTensorInfo,
1295 weightTensorInfo,
1296 optionalBiasTensorInfo,
1297 descriptorName);
1298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001299 std::vector<DataType> supportedTypes =
1300 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001301 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001302 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001303 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001304 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001305 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001306 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001307 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001308 };
1309
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001310 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001311
1312 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1313 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1314 {
1315 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1316 {
1317 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1318 "for BFloat16 input.");
1319 }
1320 }
1321 else
1322 {
1323 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1324 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001325}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001327void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1328{
1329 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1330
1331 ValidateNumInputs(workloadInfo, descriptorName, 1);
1332 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1333
1334 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1335 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1336
1337 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1338 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1339
1340 ValidatePointer(m_Weight, descriptorName, "weight");
1341
1342 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1343 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1344
1345 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1346 {
1347 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001348 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1349 "cannot be smaller than 1.",
1350 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001351 }
1352
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001353 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1354 {
1355 throw InvalidArgumentException(
1356 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1357 "cannot be either negative or 0.",
1358 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1359 }
1360
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001361 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1362
1363 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1364 // inputChannels * channelMultiplier should be equal to outputChannels.
1365 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1366 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1367 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1368 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1369 {
James Ward47fce872020-09-10 11:57:28 +01001370 throw InvalidArgumentException(fmt::format(
1371 "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) "
1372 "multiplied by channel_multiplier (provided {3}).",
1373 descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001374 }
1375
Teresa Charlind8df0262019-11-11 12:28:15 +00001376 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001377
Teresa Charlind8df0262019-11-11 12:28:15 +00001378 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001379 if (m_Parameters.m_BiasEnabled)
1380 {
1381 ValidatePointer(m_Bias, descriptorName, "bias");
1382
Teresa Charlind8df0262019-11-11 12:28:15 +00001383 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1384 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001385
1386 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1387 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1388 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001389 ValidatePerAxisQuantization(inputTensorInfo,
1390 outputTensorInfo,
1391 weightTensorInfo,
1392 optionalBiasTensorInfo,
1393 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001394
1395 std::vector<DataType> supportedTypes =
1396 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001397 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001398 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001399 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001400 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001401 DataType::QAsymmU8,
1402 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 };
1404
1405 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1406 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001407}
1408
1409void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1410{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001411 const std::string descriptorName{"PermuteQueueDescriptor"};
1412
1413 ValidateNumInputs(workloadInfo, descriptorName, 1);
1414 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001415
1416 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1417
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001418 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1419 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001420
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001421 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1422 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001423
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001424 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001425 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001426 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001427 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001428 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1429 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1430 "must match dst dimension " + to_string(mapping[i]) +
1431 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001432 }
1433 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001434
1435 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001436}
1437
1438void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1439{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001440 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001441
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 ValidateNumInputs(workloadInfo, descriptorName, 1);
1443 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1444
1445 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1446 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1447
1448 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1449 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001450
1451 std::vector<DataType> supportedTypes =
1452 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001453 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001454 DataType::Float32,
1455 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001456 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001457 DataType::QAsymmU8,
1458 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001459 };
1460
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001461 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1462 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001463}
1464
1465void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1466{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001467 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001469 ValidateNumInputs(workloadInfo, descriptorName, 1);
1470 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1471
1472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1474
1475 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1476 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001477
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001478 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001479 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001480 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001481 DataType::Float16,
1482 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001483 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001484 DataType::QAsymmU8,
1485 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001486 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001487
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001488 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1489 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001490
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001491 // ResizeBilinear only changes width and height: batch and channel count must match.
1492 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1493 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001494 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001495 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001496 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001497 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1498 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001499 }
1500
Teresa Charlin970f43b2019-07-01 13:51:07 +01001501 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001502 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1503 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001504 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001505 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001506 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001507 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1508 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001509 }
1510}
1511
1512void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1513{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001514 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001515
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001516 ValidateNumInputs(workloadInfo, descriptorName, 1);
1517 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1518
1519 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1520 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1521
1522 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1523 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001524
1525 std::vector<DataType> supportedTypes =
1526 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001527 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001528 DataType::Float16,
1529 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001530 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001531 DataType::QAsymmU8,
1532 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001533 };
1534
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001535 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1536 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001537
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001538 // Resize only changes width and height: batch and channel count must match.
1539 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1540 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001541 if (inputBatchSize != outputBatchSize)
1542 {
1543 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001544 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1545 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001546 }
1547
1548 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001549 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1550 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001551 if (inputChannelCount != outputChannelCount)
1552 {
1553 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001554 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1555 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001556 }
1557}
1558
1559void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1560{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001561 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001562
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001563 ValidateNumInputs(workloadInfo, descriptorName, 1);
1564 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1565
1566 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1567 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1568
1569 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1570 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1571
1572 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1573
telsoa014fcda012018-03-09 14:13:49 +00001574 if (m_Parameters.m_Min > m_Parameters.m_Max)
1575 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001576 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001577 }
telsoa014fcda012018-03-09 14:13:49 +00001578}
1579
Kevin Mayce5045a2019-10-02 14:07:47 +01001580void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1581{
1582 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1583
1584 ValidateNumInputs(workloadInfo, descriptorName, 1);
1585 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1586
1587 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1588 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1589
1590 if (inputTensorInfo.GetNumDimensions() > 4)
1591 {
1592 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1593 }
1594
1595 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1596
1597 // Check the supported data types
1598 std::vector<DataType> supportedTypes =
1599 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001600 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001601 DataType::Float32,
1602 DataType::Float16
1603 };
1604
1605 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001606 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001607}
1608
telsoa014fcda012018-03-09 14:13:49 +00001609void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1610{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001611 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001612
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001613 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001614 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1615
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001616 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1617 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1618
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001619 if (inputTensorInfo.GetNumDimensions() > 4)
1620 {
1621 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1622 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001623
1624 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001625
1626 // Check the supported data types
1627 std::vector<DataType> supportedTypes =
1628 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001629 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001630 DataType::Float32,
1631 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001632 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001633 DataType::QAsymmU8,
1634 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001635 };
1636
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001637 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001638 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1639}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001640
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001641void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1642{
1643 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1644
1645 ValidateNumInputs(workloadInfo, descriptorName, 1);
1646 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1647
1648 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1649 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1650
1651 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1652
1653 std::vector<DataType> supportedTypes =
1654 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001655 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001656 DataType::Float32,
1657 DataType::Float16,
1658 };
1659
1660 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001661 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001662}
1663
1664void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1665{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001666 const std::string descriptorName{"ConstantQueueDescriptor"};
1667
1668 ValidateNumInputs(workloadInfo, descriptorName, 0);
1669 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001670
1671 if (!m_LayerOutput)
1672 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001673 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001674 }
1675
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1677 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001678
1679 // Check the supported data types
1680 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001681 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001682 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001683 DataType::Float32,
1684 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001685 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001686 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001687 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001688 DataType::QSymmS16,
1689 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001690 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001691
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001692 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001693}
1694
1695void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1696{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001697 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001698
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001699 ValidateNumInputs(workloadInfo, descriptorName, 1);
1700 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1701
1702 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1703 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1704
1705 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001706
1707 // Check the supported data types
1708 std::vector<DataType> supportedTypes =
1709 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001710 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001711 DataType::Float32,
1712 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001713 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001714 DataType::QAsymmU8,
1715 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001716 DataType::Signed32,
1717 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001718 };
1719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001720 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1721 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001722}
1723
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001724void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1725{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001727
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001728 ValidateNumInputs(workloadInfo, descriptorName, 1);
1729 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1730
1731 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1732 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1733
1734 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1735 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001736
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001737 if (m_Parameters.m_BlockShape.size() != 2)
1738 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001740 }
1741
1742 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1743 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001744 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1745 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001746 }
1747
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001748 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001749
1750 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001751 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001752
Matthew Bentham8800c002018-11-19 13:19:28 +00001753 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001754
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001755 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1756 widthPad.first + widthPad.second;
1757 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1758 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001759
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001760 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1761 inputShape[dimensionIndices.GetChannelsIndex()];
1762 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001763
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001764 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001765 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001766 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001767 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001768 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001769 }
1770
1771 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001772 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001773 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1774 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001775 }
nikraj01120522a2019-05-31 11:33:07 +01001776
1777 std::vector<DataType> supportedTypes =
1778 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001779 DataType::BFloat16,
1780 DataType::Float16,
1781 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001782 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001783 DataType::QAsymmU8,
1784 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001785 };
1786
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001787 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1788 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001789}
1790
Keith Davisa57eccb2019-06-14 17:33:22 +01001791void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1792{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001793 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001794
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001795 ValidateNumInputs(workloadInfo, descriptorName, 1);
1796 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001797
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001798 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1799 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1800
1801 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1802 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001803
1804 std::vector<DataType> supportedTypes =
1805 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001806 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001807 DataType::Float32,
1808 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001809 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001810 DataType::QAsymmU8,
1811 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001812 };
1813
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001814 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1815 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001816
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001817 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1818
1819 if (m_Parameters.m_BlockSize == 0)
1820 {
1821 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1822 }
1823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001824 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1825 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1826 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1827 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001828
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001829 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001830 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001831 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001832 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1833 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001834 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001835
1836 const TensorShape& outputShape = outputTensorInfo.GetShape();
1837 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1838 {
1839 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1840 "must be divisible by the square of block size." );
1841 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001842}
1843
telsoa014fcda012018-03-09 14:13:49 +00001844void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1845{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001847
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001848 ValidateNumInputs(workloadInfo, descriptorName, 1);
1849 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1850
1851 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1852 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001853
1854 std::vector<DataType> supportedTypes =
1855 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001856 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001857 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001858 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001859 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001860 };
1861
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001862 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001863
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001864 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001865 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001866 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001867 }
1868}
1869
telsoa01c577f2c2018-08-31 09:22:23 +01001870void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1871{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001872 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1873
1874 const std::string descriptorName{"LstmQueueDescriptor"};
1875
1876 // check dimensions of all inputs and outputs
1877 if (workloadInfo.m_InputTensorInfos.size() != 3)
1878 {
1879 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1880 }
1881 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1882 {
1883 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1884 }
1885
1886 std::vector<DataType> supportedTypes =
1887 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001888 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001889 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001890 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001891 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001892 };
1893
Jan Eilers38e05bd2019-06-26 13:10:09 +01001894 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001895 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1896
Jan Eilers38e05bd2019-06-26 13:10:09 +01001897 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001899 {
1900 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1901 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001902 descriptorName,
1903 "input_0",
1904 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001905 }
1906 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001907 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001908 {
1909 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1910 workloadInfo.m_OutputTensorInfos[i],
1911 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001912 "input_0",
1913 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001914 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001915
janeil0117d8d852019-11-15 15:00:16 +00001916 // Making sure clipping parameters have valid values.
1917 // == 0 means no clipping
1918 // > 0 means clipping
1919 if (m_Parameters.m_ClippingThresCell < 0.0f)
1920 {
1921 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1922 }
1923 if (m_Parameters.m_ClippingThresProj < 0.0f)
1924 {
1925 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1926 }
1927
Jan Eilers38e05bd2019-06-26 13:10:09 +01001928
1929 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001930 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1931 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1932 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1933 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1934 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1935 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1936
Jan Eilers38e05bd2019-06-26 13:10:09 +01001937 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001938 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1939 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001940 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001941 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1942 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001943 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001944 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1945 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001946 // scratchBufferTensor
1947 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001948 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1949 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001950 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001951 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1952 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001953 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001954 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1955 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001956 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001957 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1958 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001959
1960
1961 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1962 if ( m_InputToInputWeights )
1963 {
1964 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1965 (n_cell * n_input), "InputLayerNormWeights");
1966 }
1967
1968 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1969 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1970 (n_cell * n_input), "InputToForgetWeights");
1971
1972 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1973 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1974 (n_cell * n_input), "InputToCellWeights");
1975
1976 if ( m_RecurrentToInputWeights )
1977 {
1978 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1979 (n_cell * n_output), "RecurrentToInputWeights");
1980 }
1981
1982 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1983 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1984 (n_cell * n_output), "RecurrentToForgetWeights");
1985
1986 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1987 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1988 (n_cell * n_output), "RecurrentToCellWeights");
1989
1990 // Make sure the input-gate's parameters are either both present (regular
1991 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1992 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1993 !m_Parameters.m_CifgEnabled) ||
1994 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1995 m_Parameters.m_CifgEnabled));
1996 if (!cifg_weights_all_or_none)
1997 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001998 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1999 "RecurrentToInputWeights must either both be present (regular LSTM) "
2000 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2001 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002002 }
2003
2004 if ( m_CellToInputWeights )
2005 {
2006 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2007 n_cell, "CellToInputWeights");
2008 }
2009 if ( m_CellToForgetWeights )
2010 {
2011 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2012 n_cell, "CellToForgetWeights");
2013 }
2014 if ( m_CellToOutputWeights )
2015 {
2016 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2017 n_cell, "CellToOutputWeights");
2018 }
2019
2020 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2021 bool peephole_weights_all_or_none =
2022 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2023 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2024 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2025 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2026 if (!peephole_weights_all_or_none)
2027 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002028 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002029 }
2030
2031 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2032 if (m_Parameters.m_CifgEnabled)
2033 {
2034 if (m_InputGateBias)
2035 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002036 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002037 }
2038 }
2039 else
2040 {
2041 if (!m_InputGateBias)
2042 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002043 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2044 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002045 }
2046 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2047 n_cell, "InputGateBias");
2048 }
2049
2050 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2051 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2052
2053 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2054 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2055
2056 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2057 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2058
2059 if (m_ProjectionWeights)
2060 {
2061 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2062 (n_cell * n_output), "ProjectionWeights");
2063 }
2064 if (m_ProjectionBias)
2065 {
2066 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2067 }
2068
2069 // Making sure the projection tensors are consistent:
2070 // 1) If projection weight is not present, then projection bias should not be
2071 // present.
2072 // 2) If projection weight is present, then projection bias is optional.
2073 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2074 !m_Parameters.m_ProjectionEnabled)
2075 || (m_ProjectionWeights && !m_ProjectionBias &&
2076 m_Parameters.m_ProjectionEnabled)
2077 || (m_ProjectionWeights && m_ProjectionBias &&
2078 m_Parameters.m_ProjectionEnabled));
2079 if (!projecton_tensors_consistent)
2080 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002081 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002082 }
2083
2084 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2085 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2086 // either all have values or none of them have values. Layer normalization is used when the values of all the
2087 // layer normalization weights are present
2088 if (m_InputLayerNormWeights)
2089 {
2090 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2091 }
2092 if (m_ForgetLayerNormWeights)
2093 {
2094 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2095 }
2096 if (m_CellLayerNormWeights)
2097 {
2098 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2099 }
2100 if (m_OutputLayerNormWeights)
2101 {
2102 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2103 }
2104
Jan Eilers38e05bd2019-06-26 13:10:09 +01002105 if (m_Parameters.m_LayerNormEnabled)
2106 {
2107 if (!m_Parameters.m_CifgEnabled)
2108 {
2109 if (!m_InputLayerNormWeights)
2110 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002111 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2112 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002113 }
2114 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2115 1, n_cell, "InputLayerNormWeights");
2116 }
2117 else if (m_InputLayerNormWeights)
2118 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002119 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2120 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002121 }
2122
2123 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2124 "ForgetLayerNormWeights");
2125 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2126
2127 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2128 "OutputLayerNormWeights");
2129 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2130
2131 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2132 "CellLayerNormWeights");
2133 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2134 }
2135 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2136 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002137 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2138 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002139 }
telsoa01c577f2c2018-08-31 09:22:23 +01002140}
2141
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002142void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2143{
2144 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2145
2146 ValidateNumInputs(workloadInfo, descriptorName, 1);
2147 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2148
2149 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2150 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2151
2152 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2153 {
2154 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2155 }
2156
2157 if (outputTensorInfo.GetDataType() != DataType::Float32)
2158 {
2159 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2160 }
2161
2162 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2163}
2164
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002165void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2166{
2167 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2168
2169 ValidateNumInputs(workloadInfo, descriptorName, 1);
2170 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2171
2172 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2173 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2174
2175 if (inputTensorInfo.GetDataType() != DataType::Float32)
2176 {
2177 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2178 }
2179
2180 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2181 {
2182 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2183 }
2184
2185 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2186}
2187
telsoa01c577f2c2018-08-31 09:22:23 +01002188void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2189{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002190 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002192 ValidateNumInputs(workloadInfo, descriptorName, 1);
2193 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2194
2195 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2196 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2197
2198 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002199 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002200 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002201 }
2202
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002203 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002204 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002205 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002206 }
2207
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002208 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002209}
2210
2211void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2212{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002213 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002214
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002215 ValidateNumInputs(workloadInfo, descriptorName, 1);
2216 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2217
2218 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2219 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2220
2221 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002222 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002223 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002224 }
2225
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002226 if (outputTensorInfo.GetDataType() != DataType::Float32)
2227 {
2228 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2229 }
2230
2231 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002232}
2233
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002234void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2235{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002237
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002238 ValidateNumInputs(workloadInfo, descriptorName, 2);
2239 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2240
2241 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2242 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2243 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2244
2245 std::vector<DataType> supportedTypes =
2246 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002247 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002248 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002249 DataType::Float32,
2250 DataType::QAsymmS8,
2251 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002252 DataType::QSymmS16,
2253 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002254 };
2255
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002256 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2257 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2258 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2261 inputTensorInfo1,
2262 outputTensorInfo,
2263 descriptorName,
2264 "input_0",
2265 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002266}
2267
David Beckc2044fe2018-09-05 15:00:38 +01002268void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2269{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002270 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002271
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002272 ValidateNumInputs(workloadInfo, descriptorName, 2);
2273 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2274
2275 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2276 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2277 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2278
2279 std::vector<DataType> supportedTypes =
2280 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002281 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002282 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002283 DataType::Float32,
2284 DataType::QAsymmS8,
2285 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002286 DataType::QSymmS16,
2287 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002288 };
2289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2291 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2292 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002294 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2295 inputTensorInfo1,
2296 outputTensorInfo,
2297 descriptorName,
2298 "input_0",
2299 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002300}
2301
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002302void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2303{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002304 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002305
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 ValidateNumInputs(workloadInfo, descriptorName, 2);
2307 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2308
2309 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2310 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2312
2313 std::vector<DataType> supportedTypes =
2314 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002315 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002316 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002317 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002318 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002319 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002320 DataType::QSymmS16,
2321 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002322 };
2323
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002324 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2325 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2326 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002327
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002328 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2329 inputTensorInfo1,
2330 outputTensorInfo,
2331 descriptorName,
2332 "input_0",
2333 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002334}
2335
narpra01a6bf9122018-09-10 09:50:09 +01002336void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2337{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002338 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002339
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002340 ValidateNumInputs(workloadInfo, descriptorName, 1);
2341 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2342
2343 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2344 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002345
2346 std::vector<DataType> supportedTypes =
2347 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002348 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002349 DataType::Float32,
2350 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002351 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002352 DataType::QAsymmU8,
2353 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002354 };
narpra01eb061912018-09-10 17:35:27 +01002355
James Conroy4d1ff582019-06-10 17:06:39 +01002356 // First check if input tensor data type is supported, then
2357 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002358 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2359 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002360
narpra0132b90462018-09-13 11:07:48 +01002361 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002362 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002363 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002364 }
narpra0132b90462018-09-13 11:07:48 +01002365 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002366 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002367 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002368 }
2369 else
2370 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002371 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002372 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002373 ValidateTensorNumDimensions(outputTensorInfo,
2374 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002375 outputDim > 0 ? outputDim : 1,
2376 "output");
2377 }
narpra01a6bf9122018-09-10 09:50:09 +01002378}
2379
jimfly012c9322a2018-09-19 10:59:49 +01002380void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2381{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002382 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002383
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002384 ValidateNumInputs(workloadInfo, descriptorName, 1);
2385 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2386
2387 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2388 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002389
jimfly012c9322a2018-09-19 10:59:49 +01002390 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002391 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2392
jimfly012c9322a2018-09-19 10:59:49 +01002393 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002394 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2395 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2396 "as there are dimensions in the input tensor that is " +
2397 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2398 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002399 }
2400}
2401
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002402void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2403{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002404 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002405
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002406 ValidateNumInputs(workloadInfo, descriptorName, 1);
2407 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002408
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002409 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2410 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2411
Sadik Armagan2208b602019-07-31 16:36:27 +01002412 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002413 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002414 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002415 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002416 DataType::Float16,
2417 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002418 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002419 DataType::QAsymmU8,
2420 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002421 };
2422
2423 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002424
Keith Davis0c2eeac2020-02-11 16:51:50 +00002425 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002426 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002427 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002428 }
2429}
2430
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002431void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2432{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002433 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002434
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002435 ValidateNumInputs(workloadInfo, descriptorName, 1);
2436 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002437
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2439 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002440
2441 std::vector<DataType> supportedTypes =
2442 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002443 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002444 DataType::Float32,
2445 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002446 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002447 DataType::QAsymmU8,
2448 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002449 };
2450
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002451 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2452 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002453}
2454
Conor Kennedy430b5d82018-11-14 15:28:28 +00002455void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2456{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002457 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002459 ValidateNumInputs(workloadInfo, descriptorName, 1);
2460 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2461
2462 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2463 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002464
2465 std::vector<DataType> supportedTypes =
2466 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002467 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002468 DataType::Float16,
2469 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002470 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002471 DataType::QAsymmU8,
2472 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002473 };
2474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002475 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2476 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002478 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002479
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002480 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002481 if (rank > 4)
2482 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002483 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002484 }
2485
Conor Kennedy430b5d82018-11-14 15:28:28 +00002486 // Begin, End & Stride length must be of rank(input0)
2487 if (m_Parameters.m_Begin.size() != rank)
2488 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002489 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002490 }
2491
2492 if (m_Parameters.m_End.size() != rank)
2493 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002494 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002495 }
2496
2497 if (m_Parameters.m_Stride.size() != rank)
2498 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002499 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002500 }
2501
2502 // Stride entries must be non-zero
2503 for (auto& stride : m_Parameters.m_Stride)
2504 {
2505 if (stride == 0)
2506 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002507 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002508 }
2509 }
2510}
2511
kevmay0190539692018-11-29 08:40:19 +00002512void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2513{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002514 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002515
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 ValidateNumInputs(workloadInfo, descriptorName, 2);
2517 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2518
2519 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2520 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2521 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2522
2523 std::vector<DataType> supportedTypes =
2524 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002525 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002526 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002527 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002528 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002529 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002530 DataType::QSymmS16,
2531 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002532 };
2533
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002534 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2535 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2536 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002537
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002538 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2539 inputTensorInfo1,
2540 outputTensorInfo,
2541 descriptorName,
2542 "input_0",
2543 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002544}
2545
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002546void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2547{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002548 const std::string descriptorName{"DebugQueueDescriptor"};
2549
2550 ValidateNumInputs(workloadInfo, descriptorName, 1);
2551 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002552}
2553
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002554void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2555{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002556 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002557
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002558 ValidateNumInputs(workloadInfo, descriptorName, 2);
2559 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002560
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002561 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2562 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2563 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2564
2565 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2566 inputTensorInfo1,
2567 outputTensorInfo,
2568 descriptorName,
2569 "input_0",
2570 "input_1");
2571
2572 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002573 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002574 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002575 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002576}
2577
FrancisMurtagh878f0232018-12-19 10:56:15 +00002578void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2579{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002581
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002582 ValidateNumInputs(workloadInfo, descriptorName, 2);
2583 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002584
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002585 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2586 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2587 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2588
2589 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2590 inputTensorInfo1,
2591 outputTensorInfo,
2592 descriptorName,
2593 "input_0",
2594 "input_1");
2595
2596 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002597 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002598 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002599 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002600}
2601
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002602void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2603{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002604 const std::string descriptorName{"RsqrtQueueDescriptor"};
2605
2606 ValidateNumInputs(workloadInfo, descriptorName, 1);
2607 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2608
2609 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2610 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2611
2612 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002613
2614 std::vector<DataType> supportedTypes =
2615 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002616 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002617 DataType::Float16,
2618 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002620 DataType::QAsymmU8,
2621 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002622 };
2623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002624 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2625 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002626}
2627
narpra01b89b05f2019-01-16 09:53:09 +00002628void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2629{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002630 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002631
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002632 ValidateNumInputs(workloadInfo, descriptorName, 2);
2633 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002634
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002635 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2636 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002637 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002638 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002639 }
2640
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002641 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2642 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2643
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002644 std::vector<DataType> supportedTypes =
2645 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002646 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002647 DataType::Float16,
2648 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002649 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002650 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002651 DataType::QSymmS16,
2652 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002653 };
2654
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002655 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002656
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002657 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002658
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002659 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2660 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002661}
2662
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002663void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2664{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002665 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2666
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002667 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002668
2669 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2670 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002671 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002672 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2673 }
2674
2675 if (m_Anchors == nullptr)
2676 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002677 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002678 }
2679
2680 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002681 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2682 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2683
2684 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002685 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002686 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2687 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002688
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002689 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2690 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2691 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002692
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002693 const std::vector<DataType> supportedInputTypes =
2694 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002695 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002696 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002697 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002698 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002699 DataType::QAsymmU8,
2700 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002701 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002702
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002703 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2704 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2705 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2706
2707 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2708 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2709 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2710 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2711
2712 // NOTE: Output is always Float32 regardless of input type
2713 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2714 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2715 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2716 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002717
2718 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2719 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002720 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002721 "must be positive and less than or equal to 1.");
2722 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002723
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002724 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2725 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002726 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002727 "should be equal to number of classes + 1.");
2728 }
2729}
2730
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002731void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2732{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002733 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002735 ValidateNumInputs(workloadInfo, descriptorName, 1);
2736 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2737
2738 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2739 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2740
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002741 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002742 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002743 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002744 }
2745
Sadik Armagan2208b602019-07-31 16:36:27 +01002746 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002747 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002748 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002749 DataType::Float32,
2750 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002751 };
2752
2753 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002754}
2755
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002756void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2757{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002758 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002759
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002760 ValidateNumInputs(workloadInfo, descriptorName, 2);
2761 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002762
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002763 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2764 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2765 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002766
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002767 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2768 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2769
2770 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2771 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002772}
2773
Sadik Armaganeff363d2019-04-05 15:25:46 +01002774void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2775{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002776 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002777
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002778 ValidateNumInputs(workloadInfo, descriptorName, 2);
2779 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2780
2781 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2782 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2783
2784 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2785 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2786
2787 std::vector<DataType> supportedTypes =
2788 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002789 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002790 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002791 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002792 DataType::QAsymmU8,
2793 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002794 };
2795
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002796 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2797 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002798
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002799 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2800 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002801
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002802 ValidateTensorShapesMatch(inputTensorInfo0,
2803 outputTensorInfo0,
2804 descriptorName,
2805 "input_0",
2806 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002807
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002808 ValidateTensorShapesMatch(inputTensorInfo0,
2809 outputTensorInfo1,
2810 descriptorName,
2811 "input_0",
2812 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002813}
2814
Derek Lamberti901ea112019-12-10 22:07:09 +00002815void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002816{
2817 // This is internally generated so it should not need validation.
2818}
2819
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002820void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2821{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002822 const std::string& descriptorName{"PreluQueueDescriptor"};
2823
2824 ValidateNumInputs(workloadInfo, descriptorName, 2);
2825 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2826
2827 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2828 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2829 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002830
2831 std::vector<DataType> supportedTypes
2832 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002833 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002834 DataType::Float16,
2835 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002836 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002837 DataType::QAsymmU8,
2838 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002839 };
2840
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002841 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2842 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002844 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002845
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002846 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2847 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002848
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002849 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2850 alphaTensorInfo,
2851 outputTensorInfo,
2852 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002853 "input",
2854 "alpha");
2855}
2856
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002857void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2858{
2859 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2860
2861 ValidateNumInputs(workloadInfo, descriptorName, 1);
2862 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2863
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002864 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2865 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2866
2867 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2868 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002869
2870 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002871
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002872 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2873 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002874
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002875 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2876
2877 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002878 if (m_Parameters.m_BiasEnabled)
2879 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002880 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002881
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002882 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2883 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002884
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002885 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002886 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002887 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002888
2889 ValidatePerAxisQuantization(inputTensorInfo,
2890 outputTensorInfo,
2891 weightTensorInfo,
2892 optionalBiasTensorInfo,
2893 descriptorName);
2894
2895 std::vector<DataType> supportedTypes =
2896 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002897 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002898 DataType::Float32,
2899 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002900 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002901 DataType::QAsymmU8,
2902 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002903 };
2904
2905 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2906 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002907}
2908
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002909void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2910{
2911 const std::string descriptorName{"TransposeQueueDescriptor"};
2912
2913 ValidateNumInputs(workloadInfo, descriptorName, 1);
2914 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2915
2916 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2917
2918 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2919 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2920
2921 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2922 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2923
2924 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2925 {
2926 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2927 {
2928 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2929 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2930 "must match dst dimension " + to_string(i) +
2931 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2932 }
2933 }
2934
2935 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2936}
2937
James Conroy4f1f8992020-04-29 20:01:10 +01002938void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2939{
2940 const std::string descriptorName{"QLstmQueueDescriptor"};
2941
2942 // Validate number of inputs/outputs
2943 ValidateNumInputs(workloadInfo, descriptorName, 3);
2944 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2945
2946 // Input/output tensor info
2947 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2948 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2949 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2950
2951 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2952 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2953 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2954
2955 // Supported types for various tensors in QLSTM
2956 std::vector<DataType> inputOutputSupportedTypes =
2957 {
2958 DataType::QAsymmS8
2959 };
2960
2961 std::vector<DataType> cellStateSupportedTypes =
2962 {
2963 DataType::QSymmS16
2964 };
2965
2966 std::vector<DataType> weightsSupportedTypes =
2967 {
2968 DataType::QSymmS8
2969 };
2970
2971 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2972 {
2973 DataType::QSymmS16
2974 };
2975
2976 std::vector<DataType> biasSupportedTypes =
2977 {
2978 DataType::Signed32
2979 };
2980
2981 // Validate types of input/output tensors
2982 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2983 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2984 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2985
2986 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2987 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2988 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
2989
2990 // Validate matching types of input/output tensors
2991 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2992 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2993 "outputStateIn", "outputStateOut");
2994 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2995
2996 // Infer number of batches, number of units, input size and output size from tensor dimensions
2997 const uint32_t numBatches = inputInfo.GetShape()[0];
2998 const uint32_t inputSize = inputInfo.GetShape()[1];
2999 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3000 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3001
3002 // Validate number of dimensions and number of elements for input/output tensors
3003 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3004 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3005 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3006
3007 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3008 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3009 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3010
3011 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3012 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3013 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3014 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3015
3016 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3017 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3018 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3019
3020 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3021 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3022 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3023
3024 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3025 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3026 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3027 " RecurrentToForgetWeights");
3028
3029 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3030 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3031 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3032
3033 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3034 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3035 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3036
3037 // Validate data types for MANDATORY weights tensors (all should match each other)
3038 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3039
3040 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3041 "inputToForgetWeights", "inputToCellWeights");
3042 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3043 "inputToForgetWeights", "inputToOutputWeights");
3044
3045 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3046 "inputToForgetWeights", "recurrentToForgeteights");
3047 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3048 "inputToForgetWeights", "recurrentToCellWeights");
3049 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3050 "inputToForgetWeights", "recurrentToOutputWeights");
3051
3052 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3053 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3054 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3055 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3056
3057 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3058 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3059 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3060
3061 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3062 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3063 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3064
3065 // Validate data types for MANDATORY bias tensors
3066 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3067
3068 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3069 "forgetGateBias", "cellBias");
3070 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3071 "forgetGateBias", "outputGateBias");
3072
3073 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3074 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3075 !m_Parameters.m_CifgEnabled) ||
3076 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3077 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3078
3079 if (!allCifgParamsPresentOrNot)
3080 {
3081 throw InvalidArgumentException(descriptorName +
3082 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3083 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3084 "set appropriately.");
3085 }
3086
3087 if (!m_Parameters.m_CifgEnabled)
3088 {
3089 // Validate number of dimensions and number of elements
3090 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3091 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3092
3093 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3094 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3095 " RecurrentToInputWeights");
3096
3097 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3098 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3099
3100 // Validate data types
3101 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3102 "inputToForgetWeights", "inputToInputWeights");
3103 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3104 "inputToForgetWeights", "recurrentToInputWeights");
3105 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3106 "forgetGateBias", "inputGateBias");
3107 }
3108
3109 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3110 bool allPeepholeWeightsPresentOrNot =
3111 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3112 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3113 || (!m_CellToInputWeights && !m_CellToForgetWeights
3114 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3115
3116 if (!allPeepholeWeightsPresentOrNot)
3117 {
3118 throw InvalidArgumentException(descriptorName +
3119 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3120 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3121 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3122 "appropriately.");
3123 }
3124
3125 if (m_Parameters.m_PeepholeEnabled)
3126 {
3127 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3128 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3129 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3130
3131 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3132 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3133 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3134 "cellToForgetWeight", "cellToOutputWeights");
3135
3136 if (!m_Parameters.m_CifgEnabled)
3137 {
3138 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3139 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3140 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3141 "cellToForgetWeights", "cellToInputWeights");
3142 }
3143 }
3144
3145 // Validate OPTIONAL params: Layer Norm Weights
3146 bool allLayerNormWeightsPresentOrNot =
3147 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3148 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3149 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3150 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3151
3152 if (!allLayerNormWeightsPresentOrNot)
3153 {
3154 throw InvalidArgumentException(descriptorName +
3155 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3156 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3157 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3158 "only be present when Layer Norm is enabled and CIFG is disabled. "
3159 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3160 }
3161
3162 if (m_Parameters.m_LayerNormEnabled)
3163 {
3164 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3165 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3166 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3167
3168 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3169 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3170 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3171 "forgetLayerNormWeights", "cellLayerNormWeights");
3172
3173 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3174 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3175 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3176 "forgetLayerNormWeights", "outputLayerNormWeights");
3177
3178 if (!m_Parameters.m_CifgEnabled)
3179 {
3180 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3181 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3182 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3183 "forgetLayerNormWeights", "inputLayerNormWeights");
3184 }
3185 }
3186
3187 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3188 bool correctProjectionTensorsPresent =
3189 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3190 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3191 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3192
3193 if (!correctProjectionTensorsPresent)
3194 {
3195 throw InvalidArgumentException(descriptorName +
3196 ": If projection is enabled, ProjectionWeights should be present and "
3197 "ProjectionBias is optional. If projection is disabled, neither "
3198 "ProjectionWeights nor ProjectionBias should be present.");
3199 }
3200
3201 if (m_Parameters.m_ProjectionEnabled)
3202 {
3203 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3204 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3205 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3206
3207 if (m_ProjectionBias)
3208 {
3209 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003210 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003211 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3212 }
3213
3214 }
3215 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3216 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3217 throw InvalidArgumentException(descriptorName +
3218 ": If projection is disabled, output quantization info (scale, offset) "
3219 "should match HiddenStateScale and HiddenStateZeroPoint.");
3220 }
3221
3222}
3223
James Conroy9c3cae82019-08-01 16:01:48 +01003224void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3225{
3226 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3227
3228 // Validate number of inputs/outputs
3229 ValidateNumInputs(workloadInfo, descriptorName, 3);
3230 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3231
3232 // Input/output tensor infos
3233 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3234 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3235 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3236
3237 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3238 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3239
3240 std::vector<DataType> inputOutputSupportedTypes =
3241 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003242 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003243 };
3244
3245 std::vector<DataType> cellStateSupportedTypes =
3246 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003247 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003248 };
3249
3250 std::vector<DataType> weightsSupportedTypes =
3251 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003252 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003253 };
3254
3255 std::vector<DataType> biasSupportedTypes =
3256 {
3257 DataType::Signed32
3258 };
3259
3260 // Validate types of input/output tensors
3261 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3262 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3263 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3264
3265 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3266 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3267
3268 // Validate matching types of input/output tensors
3269 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3270 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3271 "outputStateIn", "outputStateOut");
3272 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3273
3274 // Validate matching quantization info for input/output tensors
3275 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3276 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3277 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003278
James Conroy9c3cae82019-08-01 16:01:48 +01003279 // Infer number of batches, input size and output size from tensor dimensions
3280 const uint32_t numBatches = inputInfo.GetShape()[0];
3281 const uint32_t inputSize = inputInfo.GetShape()[1];
3282 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3283
3284 // Validate number of dimensions and number of elements for input/output tensors
3285 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3286 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3287 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3288 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3289 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3290
3291 // Validate number of dimensions and number of elements for weights tensors
3292 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3293 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3294 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3295
3296 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3297 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3298 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3299
3300 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3301 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3302 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3303
3304 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3305 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3306 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3307
3308 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3309 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3310 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3311
3312 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3313 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3314 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3315 " RecurrentToForgetWeights");
3316
3317 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3318 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3319 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3320
3321 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3322 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3323 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3324
3325 // Validate data types for weights tensors (all should match each other)
3326 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3327
3328 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3329 "inputToInputWeights", "inputToForgetWeights");
3330 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3331 "inputToInputWeights", "inputToCellWeights");
3332 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3333 "inputToInputWeights", "inputToOutputWeights");
3334
3335 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3336 "inputToInputWeights", "recurrentToInputWeights");
3337 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3338 "inputToInputWeights", "recurrentToForgeteights");
3339 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3340 "inputToInputWeights", "recurrentToCellWeights");
3341 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3342 "inputToInputWeights", "recurrentToOutputWeights");
3343
3344 // Validate matching quantization info for weight tensors (all should match each other)
3345 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3346 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3347 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3348 descriptorName, "inputToInputWeights", "inputToCellWeights");
3349 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3350 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3351
3352 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3353 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3354 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3355 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3356 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3357 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3358 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3359 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3360
3361 // Validate number of dimensions and number of elements in bias tensors
3362 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3363 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3364 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3365
3366 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3367 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3368 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3369
3370 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3371 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3372 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3373
3374 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3375 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3376 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3377
3378 // Validate data types for bias tensors (all should match each other)
3379 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3380
3381 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3382 "inputGateBias", "forgetGateBias");
3383 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3384 "inputGateBias", "cellBias");
3385 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3386 "inputGateBias", "outputGateBias");
3387
3388 // Validate bias tensor quantization info
3389 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3390 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3391 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3392 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3393}
3394
Kevin May868eb142019-09-04 17:29:31 +01003395void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3396{
3397 const std::string descriptorName{"AbsQueueDescriptor"};
3398
3399 ValidateNumInputs(workloadInfo, descriptorName, 1);
3400 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3401
3402 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3403 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3404
3405 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3406
3407 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003408 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003409 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003410 DataType::Float16,
3411 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003412 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003413 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003414 DataType::QSymmS16,
3415 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003416 };
Kevin May868eb142019-09-04 17:29:31 +01003417
3418 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3419 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3420}
3421
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003422void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3423{
3424 const std::string descriptorName{"SliceQueueDescriptor"};
3425
3426 ValidateNumInputs(workloadInfo, descriptorName, 1);
3427 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3428
3429 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3430 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3431
3432 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3433
3434 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3435 if (rank > 4)
3436 {
3437 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3438 }
3439
3440 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3441
3442 // Check if m_Begin and m_Size have the expected length
3443 if (m_Parameters.m_Begin.size() != rank)
3444 {
3445 throw InvalidArgumentException(descriptorName +
3446 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3447 }
3448 if (m_Parameters.m_Size.size() != rank)
3449 {
3450 throw InvalidArgumentException(descriptorName +
3451 ": Length of size descriptor must equal rank " + std::to_string(rank));
3452 }
3453
3454 // Check if the shape of the output tensor matches m_Size
3455 const TensorShape& outputShape = outputTensorInfo.GetShape();
3456 for (unsigned int i = 0u; i < rank; ++i)
3457 {
3458 if (m_Parameters.m_Size[i] != outputShape[i])
3459 {
3460 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3461 }
3462 }
3463
3464 // Check if the sum of begin offset and size in a given dimension
3465 // does not exceed the size of corresponding input
3466 const TensorShape& inputShape = inputTensorInfo.GetShape();
3467 for(unsigned int i = 0u; i < rank; ++i)
3468 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003469 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003470 {
3471 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3472 std::to_string(i) + " exceeds input size.");
3473 }
3474 }
3475}
3476
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003477void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3478{
3479 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3480
3481 ValidateNumInputs(workloadInfo, descriptorName, 1);
3482 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3483
3484 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3485 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3486
3487 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3488 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3489
3490 std::vector<DataType> supportedTypes =
3491 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003492 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003493 DataType::Float32,
3494 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003495 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003496 DataType::QAsymmU8,
3497 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003498 };
3499
3500 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3501 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3502
3503 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3504
3505 if (m_Parameters.m_BlockSize == 0)
3506 {
3507 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3508 }
3509
3510 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3511 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3512 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3513 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3514
3515 const TensorShape& outputShape = outputInfo.GetShape();
3516 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3517 {
3518 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3519 "must be divisible by block size.");
3520 }
3521
3522 const TensorShape& inputShape = inputInfo.GetShape();
3523 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3524 {
3525 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3526 "must be divisible by the square of block size." );
3527 }
3528}
3529
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003530void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3531{
3532 const std::string descriptorName{"ComparisonQueueDescriptor"};
3533
3534 ValidateNumInputs(workloadInfo, descriptorName, 2);
3535 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3536
3537 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3538 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3539 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3540
3541 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3542 inputTensorInfo1,
3543 outputTensorInfo,
3544 descriptorName,
3545 "input_0",
3546 "input_1");
3547
3548 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3549 {
3550 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3551 }
3552}
3553
josh minor4a3c6102020-01-06 16:40:46 -06003554void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3555{
3556 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3557
3558 ValidateNumInputs(workloadInfo, descriptorName, 1);
3559 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3560
3561 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3562 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3563
3564 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3565
3566 std::vector<DataType> supportedTypes =
3567 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003568 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003569 DataType::Float16,
3570 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003571 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003572 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003573 DataType::QSymmS16,
3574 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003575 };
3576
James Conroyaba90cd2020-11-06 16:28:18 +00003577 std::vector<DataType> logicalSupportedTypes =
3578 {
3579 DataType::Boolean
3580 };
3581
3582 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3583 {
3584 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3585 }
3586 else
3587 {
3588 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3589 }
3590
3591
josh minor4a3c6102020-01-06 16:40:46 -06003592 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3593}
3594
Finn Williams2605b232020-06-10 15:53:46 +01003595void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3596{
3597 const std::string descriptorName{"RankQueueDescriptor"};
3598
3599 ValidateNumInputs(workloadInfo, descriptorName, 1);
3600 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3601
3602 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3603 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3604
3605 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3606 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3607
3608 std::vector<DataType> supportedTypes =
3609 {
3610 DataType::BFloat16,
3611 DataType::Float16,
3612 DataType::Float32,
3613 DataType::QAsymmS8,
3614 DataType::QAsymmU8,
3615 DataType::QSymmS8,
3616 DataType::QSymmS16,
3617 DataType::Signed32
3618 };
3619
3620 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3621 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3622}
3623
James Conroyaba90cd2020-11-06 16:28:18 +00003624void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3625{
3626 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3627
3628 ValidateNumInputs(workloadInfo, descriptorName, 2);
3629 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3630
3631 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3632 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3633 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3634
3635 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3636 inputTensorInfo1,
3637 outputTensorInfo,
3638 descriptorName,
3639 "input_0",
3640 "input_1");
3641
3642 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3643 {
3644 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3645 }
3646
3647 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3648 {
3649 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3650 }
3651
3652 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3653 {
3654 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3655 }
3656}
3657
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003658void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3659{
3660 const std::string descriptorName{"ReduceQueueDescriptor"};
3661
3662 ValidateNumInputs(workloadInfo, descriptorName, 1);
3663 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3664
3665 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3666 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3667
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003668 std::vector<DataType> supportedTypes =
3669 {
3670 DataType::BFloat16,
3671 DataType::Float16,
3672 DataType::Float32,
3673 DataType::QAsymmS8,
3674 DataType::QAsymmU8,
3675 DataType::QSymmS16,
3676 DataType::Signed32
3677 };
3678
3679 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3680 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3681}
3682
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003683} // namespace armnn