blob: 100d23ee39cc0c548e20ede2cc50539f9db98b8e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000011
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000014#include <string>
15#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000016
James Ward47fce872020-09-10 11:57:28 +010017#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000031 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000034 case DataType::QAsymmS8:
35 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000037 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000038 case DataType::QSymmS8:
39 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// The Android NDK does not support the std::to_string function.
// Converts a value to text by inserting it into an std::ostringstream
// (default stream formatting) and returning the accumulated string.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream textStream;
    textStream << value;
    return textStream.str();
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100108void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000109 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& tensorName)
112{
113 if (tensor.GetNumDimensions() != numDimensions)
114 {
115 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
116 to_string(tensor.GetNumDimensions()) + " dimensions for " +
117 tensorName + " tensor.");
118 }
119}
120
121//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100122void ValidateTensorNumElements(const TensorInfo& tensor,
123 std::string const& descName,
124 unsigned int numElements,
125 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126{
127 if (tensor.GetNumElements() != numElements)
128 {
129 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100130 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100131 tensorName + " tensor.");
132 }
133}
134
135//---------------------------------------------------------------
136void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 unsigned int numDimension,
138 unsigned int numElements,
139 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100141 const std::string functionName{"ValidateTensorNumDimNumElem"};
142 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
143 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100144}
145
146//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000147void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
148 const std::string& descName, std::string const& tensorName)
149{
150 if (tensor.GetDataType() != dataType)
151 {
152 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
153 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
154 }
155}
156
Derek Lambertid466a542020-01-22 15:37:29 +0000157void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
158{
159 ARMNN_NO_DEPRECATE_WARN_BEGIN
160 if (tensor.GetDataType() != DataType::QSymmS8 &&
161 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
167 ARMNN_NO_DEPRECATE_WARN_END
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100171void ValidateTensorQuantizationSpace(const TensorInfo& first,
172 const TensorInfo& second,
173 const std::string& descName,
174 std::string const& firstName,
175 std::string const& secondName)
176{
177 if (!first.IsQuantized() ||
178 !second.IsQuantized())
179 {
180 // Not a quantized type, ignore the validation
181 return;
182 }
183
184 DataType firstDataType = first.GetDataType();
185 DataType secondDataType = second.GetDataType();
186
187 if (firstDataType != secondDataType)
188 {
189 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
190 " must be of the same quantized type, " +
191 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
192 secondName + " is " + GetDataTypeName(secondDataType));
193 }
194
195 if (!first.IsTypeSpaceMatch(second))
196 {
197 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
198 " must have the same quantization space, " +
199 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
200 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
201 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
202 " and scale " + to_string(second.GetQuantizationScale()));
203 }
204}
205
206//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100207void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
208 const TensorInfo& inputTensorInfo,
209 const TensorInfo& weightsTensorInfo,
210 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000211{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000212 // Helper lambda function to validate a single bias quantization scale value
213 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
214 {
ricbur013f4d7102019-10-31 16:22:18 +0000215 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000216 if (std::abs(biasScale - expectedScale) > tolerance)
217 {
218 // Print the float values with extra precision to see very small differences
219 std::stringstream msg;
220 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
221 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
222 biasScale;
223 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
224 }
225 };
226
telsoa014fcda012018-03-09 14:13:49 +0000227 if (biasTensor.GetQuantizationOffset() != 0)
228 {
229 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
230 to_string(biasTensor.GetQuantizationOffset()));
231 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232
James Conroy8502ade2020-11-12 19:26:29 +0000233 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000234 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000235 // Validate per-axis quantization scales
236 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
237 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
238
239 if (weightScales.size() != biasScales.size())
240 {
241 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000242 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
243 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
244 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000245 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
246 }
247
248 for (size_t i = 0ul; i < biasScales.size(); ++i)
249 {
250 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
251 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
252 }
253 }
254 else
255 {
256 // Validate per-tensor quantization scale
257 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
258 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000259 }
260}
261
262//---------------------------------------------------------------
263void ValidateTensors(const std::vector<ITensorHandle*>& vec,
264 unsigned int numExpected,
265 const std::string& descName,
266 const std::string& varName)
267{
268 if (vec.empty() && numExpected > 0)
269 {
270 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
271 }
272
273 for (unsigned int i = 0; i < numExpected; ++i)
274 {
275 if (!vec[i])
276 {
277 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
278 }
279 }
280}
281
282//---------------------------------------------------------------
283void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
284 const TensorInfo& second,
285 const TensorInfo& output,
286 std::string const& descName,
287 std::string const& firstName,
288 std::string const& secondName)
289{
290 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
291 // broadcasted.
292 if (first.GetNumDimensions() != second.GetNumDimensions())
293 {
294 throw InvalidArgumentException(descName + ": Tensors "
295 + firstName + " & " + secondName
296 + " must have the same number of dimensions in order to be broadcasted");
297 }
298 uint32_t numDims = first.GetNumDimensions();
299 std::vector<uint32_t> outputDims(numDims, 0u);
300 for (uint32_t i = 0; i < numDims; i++)
301 {
302 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
303 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
304 if (dimsNotEqual && dimsNotOne)
305 {
306 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
307 }
308 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
309 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100310 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000311 if (broadcastShape != output.GetShape())
312 {
313 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
314 + firstName + " & " + secondName
315 + " does not match the output shape");
316 }
317}
318
319//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100320void ValidateDataTypes(const TensorInfo& info,
321 const std::vector<armnn::DataType>& supportedTypes,
322 std::string const& descName)
323{
324 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
325 if (iterator == supportedTypes.end())
326 {
327 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
328 }
329}
330
James Conroy4d1ff582019-06-10 17:06:39 +0100331//---------------------------------------------------------------
332void ValidateTensorDataTypesMatch(const TensorInfo& first,
333 const TensorInfo& second,
334 std::string const& descName,
335 std::string const& firstName,
336 std::string const& secondName)
337{
338 if (first.GetDataType() != second.GetDataType())
339 {
340 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
341 " must have identical data types.");
342 }
343}
344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100345//---------------------------------------------------------------
346void ValidateTensorNumElementsMatch(const TensorInfo& first,
347 const TensorInfo& second,
348 std::string const& descName,
349 std::string const& firstName,
350 std::string const& secondName)
351{
352 if (first.GetNumElements() != second.GetNumElements())
353 {
354 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
355 " must have the same number of elements.");
356 }
357}
358
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000359void ValidateWeightDataType(const TensorInfo& inputInfo,
360 const TensorInfo& weightInfo,
361 const std::string& descName)
362{
363 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000364 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365 {
Derek Lambertid466a542020-01-22 15:37:29 +0000366 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000367 const std::vector<DataType> validTypes =
368 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000369 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100370 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000371 DataType::QSymmS8,
372 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000373 };
Derek Lambertid466a542020-01-22 15:37:29 +0000374 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000375
376 ValidateDataTypes(weightInfo, validTypes, descName);
377 }
378 else
379 {
380 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
381 }
382}
383
384void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
385 const std::string& descName,
386 const std::string& tensorName)
387{
388 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
389 if (!quantizationDim.has_value())
390 {
James Ward47fce872020-09-10 11:57:28 +0100391 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
392 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000393 }
394
395 if (quantizationDim.value() != 0)
396 {
James Ward47fce872020-09-10 11:57:28 +0100397 throw InvalidArgumentException(fmt::format(
398 "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, "
399 "but got: {2}", descName, tensorName, quantizationDim.value()));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000400 }
401}
402
403void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
404 const std::string& descName,
405 const std::string& tensorName)
406{
407 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
408 if (quantizationOffset != 0)
409 {
James Ward47fce872020-09-10 11:57:28 +0100410 throw InvalidArgumentException(fmt::format(
411 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
412 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000413 }
414}
415
416void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
417 const TensorInfo& outputInfo,
418 const TensorInfo& weightInfo,
419 const Optional<TensorInfo>& optionalBiasInfo,
420 const std::string& descName)
421{
422 if (weightInfo.HasPerAxisQuantization())
423 {
424 const DataType inputDataType = inputInfo.GetDataType();
425 const DataType outputDataType = outputInfo.GetDataType();
426
Keith Davis0c2eeac2020-02-11 16:51:50 +0000427 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000428
429 if (!canHavePerAxisQuantization)
430 {
James Ward47fce872020-09-10 11:57:28 +0100431 throw InvalidArgumentException(fmt::format(
432 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
433 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000434 }
435
Derek Lambertid466a542020-01-22 15:37:29 +0000436
437 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000438 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
439 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
440
441 if (optionalBiasInfo.has_value())
442 {
443 const TensorInfo& biasInfo = optionalBiasInfo.value();
444 if (!biasInfo.HasPerAxisQuantization())
445 {
James Ward47fce872020-09-10 11:57:28 +0100446 throw InvalidArgumentException(fmt::format(
447 "{}: Per-axis quantization parameters not set on bias tensor, "
448 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000449 }
450
451 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
452 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
453 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
454 }
455 }
456}
457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100458} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000459
460void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
461 unsigned int numExpectedIn, unsigned int numExpectedOut) const
462{
463 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
464 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
465}
466
467//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100468void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
469{
470 const std::string descriptorName{"MapQueueDescriptor"};
471
472 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100473 ValidateNumOutputs(workloadInfo, descriptorName, 0);
474
475 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
476 {
477 if (!m_Inputs[i])
478 {
479 throw InvalidArgumentException(
480 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
481 }
482 }
483}
484
485//---------------------------------------------------------------
486void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
487{
488 const std::string descriptorName{"UnmapQueueDescriptor"};
489
490 ValidateNumInputs(workloadInfo, descriptorName, 1);
491 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100492
493 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
494 {
495 if (!m_Inputs[i])
496 {
497 throw InvalidArgumentException(
498 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
499 }
500 }
501}
502
503//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000504void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
505{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100506 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000507
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100508 ValidateNumInputs(workloadInfo, descriptorName, 1);
509 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000510
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100511 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
512 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
513
514 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
515 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000516
517 if (m_Inputs.size() != m_Outputs.size())
518 {
James Ward47fce872020-09-10 11:57:28 +0100519 throw InvalidArgumentException(fmt::format(
520 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
521 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000522 }
523
524 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
525 {
526 if (!m_Inputs[i])
527 {
James Ward47fce872020-09-10 11:57:28 +0100528 throw InvalidArgumentException(fmt::format(
529 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000530 }
531
532 if (!m_Outputs[i])
533 {
James Ward47fce872020-09-10 11:57:28 +0100534 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000535 }
536 }
537}
538
Derek Lambertif674aa02019-08-01 15:56:25 +0100539//---------------------------------------------------------------
540void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
541{
542 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
543 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
544
545 if (workloadInfo.m_InputTensorInfos.size() != 1)
546 {
James Ward47fce872020-09-10 11:57:28 +0100547 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
548 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100549
550 }
551
552 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
553 {
James Ward47fce872020-09-10 11:57:28 +0100554 throw InvalidArgumentException(fmt::format(
555 "Number of input infos ({0}) does not match the number of output infos ({1})",
556 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100557 }
558
559 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
560 {
561 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
562 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
563 {
James Ward47fce872020-09-10 11:57:28 +0100564 throw InvalidArgumentException(fmt::format(
565 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100566 }
567 }
568
569 if (m_Inputs.size() != 1)
570 {
James Ward47fce872020-09-10 11:57:28 +0100571 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573
574 if (m_Inputs.size() != m_Outputs.size())
575 {
James Ward47fce872020-09-10 11:57:28 +0100576 throw InvalidArgumentException(fmt::format(
577 "Number of inputs ({0}) does not match the number of outputs ({1})",
578 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100579 }
580
581 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
582 {
583 if (!m_Inputs[i])
584 {
James Ward47fce872020-09-10 11:57:28 +0100585 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100586 }
587
588 if (!m_Outputs[i])
589 {
James Ward47fce872020-09-10 11:57:28 +0100590 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100591 }
592 }
593}
594
595//---------------------------------------------------------------
596void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
597{
598 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
599 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
600
Derek Lambertif674aa02019-08-01 15:56:25 +0100601 if (m_Inputs.size() != 1)
602 {
James Ward47fce872020-09-10 11:57:28 +0100603 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100604 }
605
606 if (m_Outputs.size() != 0)
607 {
James Ward47fce872020-09-10 11:57:28 +0100608 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100609 }
610
611 if (!m_Inputs[0])
612 {
James Ward47fce872020-09-10 11:57:28 +0100613 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100614 }
615}
616
617//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000618void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
619{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100620 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100622 ValidateNumInputs(workloadInfo, descriptorName, 1);
623 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100625 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
626 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100627
628 std::vector<DataType> supportedTypes =
629 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000630 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100631 DataType::Float16,
632 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000633 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000634 DataType::QAsymmU8,
635 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100636 };
637
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100638 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
639 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
640 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000641}
642
Nikhil Rajee391d52019-09-05 17:50:44 +0100643void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
644{
645 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
646
647 ValidateNumInputs(workloadInfo, descriptorName, 1);
648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
649
650 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
651 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
652
Inki Daed4619e22020-09-10 15:33:54 +0900653 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
654 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100655 {
Inki Daed4619e22020-09-10 15:33:54 +0900656 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100657 }
658
James Conroyd47a0642019-09-17 14:22:06 +0100659 std::vector<DataType> supportedInputTypes =
660 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000661 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100662 DataType::Float16,
663 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100664 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000665 DataType::QAsymmU8,
666 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900667 DataType::Signed32,
668 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100669 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100670
James Conroyd47a0642019-09-17 14:22:06 +0100671 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100672
673 auto inputShape = inputTensorInfo.GetShape();
674 auto outputShape = outputTensorInfo.GetShape();
675
676 auto inputNumDimensions = inputShape.GetNumDimensions();
677 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
678
679 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
680
681 // 1D input shape results in scalar output shape
682 if (inputShape.GetNumDimensions() == 1)
683 {
684 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
685 {
686 throw InvalidArgumentException(descriptorName + outputShapeError);
687 }
688 }
689 else
690 {
691 for (unsigned int i = 0; i < unsignedAxis; ++i)
692 {
693 if (outputShape[i] != inputShape[i])
694 {
695 throw InvalidArgumentException(descriptorName + outputShapeError);
696 }
697 }
698
699 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
700 {
701 if (outputShape[i - 1] != inputShape[i])
702 {
703 throw InvalidArgumentException(descriptorName + outputShapeError);
704 }
705 }
706 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100707}
708
mathad01b392e982021-04-07 12:07:30 +0100709void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
710{
711 const std::string descriptorName{"CastQueueDescriptor"};
712
713 ValidateNumInputs(workloadInfo, descriptorName, 1);
714 ValidateNumOutputs(workloadInfo, descriptorName, 1);
715
716 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
717 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
718
719 std::vector<DataType> supportedTypes =
720 {
721 DataType::BFloat16,
722 DataType::Float16,
723 DataType::Float32,
724 DataType::QAsymmS8,
725 DataType::QAsymmU8,
726 DataType::QSymmS8,
727 DataType::QSymmS16,
728 DataType::Signed32,
729 DataType::Signed64
730 };
731
732 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
733 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
734}
735
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100736void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
737{
738 const std::string descriptorName{"SoftmaxQueueDescriptor"};
739
740 ValidateNumInputs(workloadInfo, descriptorName, 1);
741 ValidateNumOutputs(workloadInfo, descriptorName, 1);
742
743 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
744 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
745
746 std::vector<DataType> supportedTypes =
747 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000748 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100749 DataType::Float16,
750 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000751 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000752 DataType::QAsymmU8,
753 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100754 };
755
756 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
757 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
758 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
759}
760
telsoa014fcda012018-03-09 14:13:49 +0000761void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
762{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100763 const std::string descriptorName{"SplitterQueueDescriptor"};
764
765 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000766
Ruomei Yan25339c32019-05-28 16:48:20 +0100767 // Check the supported data types
768 std::vector<DataType> supportedTypes =
769 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000770 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100771 DataType::Float32,
772 DataType::Float16,
773 DataType::Boolean,
774 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100775 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000776 DataType::QAsymmU8,
777 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100778 };
779
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100780 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
781 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100782 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100783 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
784 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
785
786 const std::string outputName = "output_" + std::to_string(i);
787 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100788 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100789
telsoa014fcda012018-03-09 14:13:49 +0000790 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
791 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100792 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000793 }
794
795 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
796 {
797 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100798 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000799 "has to match number of workloadInfo.m_OutputTensorInfos. "
800 "Number of windows: " +
801 to_string(m_ViewOrigins.size()) +
802 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
803 }
804
telsoa01c577f2c2018-08-31 09:22:23 +0100805 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000806 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
807 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
808 {
telsoa01c577f2c2018-08-31 09:22:23 +0100809 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000810 ViewOrigin const& e = m_ViewOrigins[w];
811 if (e.m_Origin.size() != inputDims)
812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000814 "have the same dimensionality as the input tensor. "
815 "Window origin (index: " +
816 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
817 " dimensions, the input "
818 "tensor has " +
819 to_string(inputDims) + " dimensions.");
820 }
821 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
822 {
823 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
824 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100826 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000827 "be smaller or equal than the size of the input in that coord.");
828 }
829 }
830 }
831}
832
Jim Flynne242f2d2019-05-22 14:24:13 +0100833void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000834{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100835 const std::string descriptorName{"ConcatQueueDescriptor"};
836
837 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000838
839 if (m_Inputs.size() <= 0)
840 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000842 }
843 if (m_Outputs.size() <= 0)
844 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100845 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000846 }
847
848 if (workloadInfo.m_InputTensorInfos.size() <= 0)
849 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100850 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000851 }
852 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
853 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100854 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000855 }
856
Nikhil Raj8599a412018-11-19 14:51:07 +0000857 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
858 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100859 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000860 }
861
862 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
863 {
864 return;
865 }
866
telsoa014fcda012018-03-09 14:13:49 +0000867 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
868 {
869 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100870 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000871 "has to match number of workloadInfo.m_InputTensorInfos. "
872 "Number of windows: " +
873 to_string(m_ViewOrigins.size()) +
874 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
875 }
876
telsoa01c577f2c2018-08-31 09:22:23 +0100877 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000878 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
879 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
880 {
telsoa01c577f2c2018-08-31 09:22:23 +0100881 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000882 ViewOrigin const& e = m_ViewOrigins[w];
883 if (e.m_Origin.size() != outputDims)
884 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100885 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000886 "have the same dimensionality as the output tensor. "
887 "Window origin (index: " +
888 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
889 " dimensions, the output "
890 "tensor has " +
891 to_string(outputDims) + " dimensions.");
892 }
telsoa01c577f2c2018-08-31 09:22:23 +0100893 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000894 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
895 {
896 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
897 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
898 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100899 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000900 "be smaller or equal than the size of the output in that coord.");
901 }
902 }
903 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100904
905 // Check the supported data types
906 std::vector<DataType> supportedTypes =
907 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000908 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100909 DataType::Float32,
910 DataType::Float16,
911 DataType::Boolean,
912 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100913 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000914 DataType::QAsymmU8,
915 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100916 };
917
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
919 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100920 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100921 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
922 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
923
924 const std::string inputName = "input_" + std::to_string(i);
925 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100926 }
telsoa014fcda012018-03-09 14:13:49 +0000927}
928
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100929void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
930{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100931 const std::string descriptorName{"StackQueueDescriptor"};
932
933 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100934
935 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
936 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100937 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100938 }
939
940 // All inputs must have the same shape, which is defined in parameters
941 const TensorShape& inputShape = m_Parameters.m_InputShape;
942 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
943 {
944 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
945 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100946 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100947 }
948 }
949
Matthew Jacksondba634f2019-08-15 15:14:18 +0100950 if (inputShape.GetNumDimensions() > 4)
951 {
952 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
953 }
954
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100955 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
956 // since the output tensor has an additional dimension.
957 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
958 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100960 "than the number of input dimensions.");
961 }
962
963 // Output shape must be as inferred from the input shape
964 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
965 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
966 {
967 if (outputShape[i] != inputShape[i])
968 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100970 "match shape inferred from input tensor.");
971 }
972 }
973
974 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
975 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100976 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100977 "match shape inferred from input tensor.");
978 }
979
980 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
981 {
982 if (outputShape[i] != inputShape[i-1])
983 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100984 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100985 "match shape inferred from input tensor.");
986 }
987 }
988
Matthew Jacksondba634f2019-08-15 15:14:18 +0100989 if (outputShape.GetNumDimensions() > 5)
990 {
991 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
992 }
993
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100994 // Check the supported data types
995 std::vector<DataType> supportedTypes =
996 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000997 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100998 DataType::Float32,
999 DataType::Float16,
1000 DataType::Boolean,
1001 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001002 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001003 DataType::QAsymmU8,
1004 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001005 };
1006
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001007 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001008
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001009 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001010 {
1011 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1012 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001013 descriptorName,
1014 "input_0",
1015 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001016 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001017
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001018 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1019 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001020 descriptorName,
1021 "input_0",
1022 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001023}
1024
Ryan OSheaec6c6802020-06-05 17:17:06 +01001025void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1026{
1027 const std::string descriptorName{"FillQueueDescriptor"};
1028
1029 ValidateNumInputs(workloadInfo, descriptorName, 1);
1030 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1031
1032 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1033 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1034
1035 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1036
1037 std::vector<DataType> supportedTypes =
1038 {
1039 DataType::BFloat16,
1040 DataType::Float32,
1041 DataType::Float16,
1042 DataType::Signed32
1043 };
1044
1045 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1046}
1047
telsoa014fcda012018-03-09 14:13:49 +00001048void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1049{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001050 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001051
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001052 uint32_t numInputs = 1;
1053 if (!m_Parameters.m_ConstantWeights)
1054 {
1055 numInputs = 2;
1056 if (m_Parameters.m_BiasEnabled)
1057 {
1058 numInputs = 3;
1059 }
1060 }
1061 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001062 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1063
1064 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1065 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1066
1067 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1068
1069 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001070 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001071 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001072 }
1073
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001074 TensorInfo weightTensorInfo;
1075 if (m_Parameters.m_ConstantWeights)
1076 {
1077 ValidatePointer(m_Weight, descriptorName, "weight");
1078 weightTensorInfo = m_Weight->GetTensorInfo();
1079 }
1080 else
1081 {
1082 weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1083 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001084 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001085
1086 if (m_Parameters.m_BiasEnabled)
1087 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001088 TensorInfo biasTensorInfo;
1089 if (m_Parameters.m_ConstantWeights)
1090 {
1091 ValidatePointer(m_Bias, descriptorName, "bias");
1092 biasTensorInfo = m_Bias->GetTensorInfo();
1093 }
1094 else
1095 {
1096 biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
1097 }
telsoa01c577f2c2018-08-31 09:22:23 +01001098 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001100 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1101 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001102 }
1103
Francis Murtagh46c09d02019-05-28 08:15:28 +01001104 // Check the supported data types
1105 std::vector<DataType> supportedTypes =
1106 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001107 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001108 DataType::Float32,
1109 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001110 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001111 DataType::QAsymmU8,
1112 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001113 };
1114
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001115 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001116
1117 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1118 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1119 {
1120 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1121 {
1122 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1123 "for BFloat16 input.");
1124 }
1125 }
1126 else
1127 {
1128 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1129 }
telsoa014fcda012018-03-09 14:13:49 +00001130}
1131
telsoa014fcda012018-03-09 14:13:49 +00001132void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1133{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001134 const std::string descriptorName{"NormalizationQueueDescriptor"};
1135
1136 ValidateNumInputs(workloadInfo, descriptorName, 1);
1137 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1138
1139 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1140 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001141
1142 // Check the supported data types
1143 std::vector<DataType> supportedTypes =
1144 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001145 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001146 DataType::Float16,
1147 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001148 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001149 DataType::QAsymmU8,
1150 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001151 };
1152
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001153 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001154
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001158}
1159
1160void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1161{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001162 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001163
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001164 ValidateNumInputs(workloadInfo, descriptorName, 2);
1165 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1166
1167 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1168 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1169 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1170
1171 std::vector<DataType> supportedTypes =
1172 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001173 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001174 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001175 DataType::Float16,
1176 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001177 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001178 DataType::QSymmS16,
1179 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001180 };
1181
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001182 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1183 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1184 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001185
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001186 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1187 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001188
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001189 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1190 inputTensorInfo1,
1191 outputTensorInfo,
1192 descriptorName,
1193 "input_0",
1194 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001195}
1196
telsoa014fcda012018-03-09 14:13:49 +00001197void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1198{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001199 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001200
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001201 ValidateNumInputs(workloadInfo, descriptorName, 2);
1202 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1203
1204 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1205 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1206 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1207
1208 std::vector<DataType> supportedTypes =
1209 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001210 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001211 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001212 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001213 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001214 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001215 DataType::QSymmS16,
1216 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001217 };
1218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1220 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1221 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001222
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001223 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1224 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001225
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001226 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1227 inputTensorInfo1,
1228 outputTensorInfo,
1229 descriptorName,
1230 "input_0",
1231 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001232}
1233
1234void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1235{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001236 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001237
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001238 ValidateNumInputs(workloadInfo, descriptorName, 1);
1239 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1240
1241 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1242 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001243
1244 std::vector<DataType> supportedTypes =
1245 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001246 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001247 DataType::Float16,
1248 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001249 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001250 DataType::QAsymmU8,
1251 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001252 };
1253
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001254 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1255 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001256
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001258 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 ValidatePointer(m_Mean, descriptorName, "mean");
1261 ValidatePointer(m_Variance, descriptorName, "variance");
1262 ValidatePointer(m_Beta, descriptorName, "beta");
1263 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001264
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001265 const TensorInfo& mean = m_Mean->GetTensorInfo();
1266 const TensorInfo& variance = m_Variance->GetTensorInfo();
1267 const TensorInfo& beta = m_Beta->GetTensorInfo();
1268 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001269
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001270 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1271 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1272 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1273 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001274
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001275 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1276 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1277 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001278}
1279
1280void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1281{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001282 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001284 ValidateNumInputs(workloadInfo, descriptorName, 1);
1285 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001287 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1288 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001290 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1291 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001292
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001293 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001294
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001295 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1296 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001297
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001298 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001299
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001300 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001301 if (m_Parameters.m_BiasEnabled)
1302 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001303 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001304
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001305 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1306 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001307
1308 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1309 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001310 }
1311
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001312 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1313 {
1314 throw InvalidArgumentException(
1315 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1316 "cannot be either negative or 0.",
1317 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1318 }
1319
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001320 ValidatePerAxisQuantization(inputTensorInfo,
1321 outputTensorInfo,
1322 weightTensorInfo,
1323 optionalBiasTensorInfo,
1324 descriptorName);
1325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001326 std::vector<DataType> supportedTypes =
1327 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001328 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001329 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001330 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001331 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001332 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001333 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001334 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001335 };
1336
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001337 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001338
1339 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1340 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1341 {
1342 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1343 {
1344 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1345 "for BFloat16 input.");
1346 }
1347 }
1348 else
1349 {
1350 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1351 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001352}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001353
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001354void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1355{
1356 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1357
1358 ValidateNumInputs(workloadInfo, descriptorName, 1);
1359 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1360
1361 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1362 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1363
1364 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1365 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1366
1367 ValidatePointer(m_Weight, descriptorName, "weight");
1368
1369 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1370 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1371
1372 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1373 {
1374 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001375 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1376 "cannot be smaller than 1.",
1377 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001378 }
1379
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001380 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1381 {
1382 throw InvalidArgumentException(
1383 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1384 "cannot be either negative or 0.",
1385 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1386 }
1387
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001388 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1389
1390 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1391 // inputChannels * channelMultiplier should be equal to outputChannels.
1392 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1393 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1394 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1395 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1396 {
James Ward47fce872020-09-10 11:57:28 +01001397 throw InvalidArgumentException(fmt::format(
1398 "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) "
1399 "multiplied by channel_multiplier (provided {3}).",
1400 descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001401 }
1402
Teresa Charlind8df0262019-11-11 12:28:15 +00001403 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001404
Teresa Charlind8df0262019-11-11 12:28:15 +00001405 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001406 if (m_Parameters.m_BiasEnabled)
1407 {
1408 ValidatePointer(m_Bias, descriptorName, "bias");
1409
Teresa Charlind8df0262019-11-11 12:28:15 +00001410 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1411 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001412
1413 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1414 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1415 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001416 ValidatePerAxisQuantization(inputTensorInfo,
1417 outputTensorInfo,
1418 weightTensorInfo,
1419 optionalBiasTensorInfo,
1420 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001421
1422 std::vector<DataType> supportedTypes =
1423 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001424 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001425 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001426 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001427 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001428 DataType::QAsymmU8,
1429 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001430 };
1431
1432 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1433 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001434}
1435
1436void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1437{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001438 const std::string descriptorName{"PermuteQueueDescriptor"};
1439
1440 ValidateNumInputs(workloadInfo, descriptorName, 1);
1441 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001442
1443 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1444
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001445 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1446 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001447
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001448 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1449 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001450
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001451 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001452 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001453 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001454 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1456 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1457 "must match dst dimension " + to_string(mapping[i]) +
1458 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001459 }
1460 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001461
1462 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001463}
1464
1465void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1466{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001467 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001469 ValidateNumInputs(workloadInfo, descriptorName, 1);
1470 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1471
1472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1474
1475 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1476 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001477
1478 std::vector<DataType> supportedTypes =
1479 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001480 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001481 DataType::Float32,
1482 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001483 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001484 DataType::QAsymmU8,
1485 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001486 };
1487
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001488 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1489 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001490}
1491
1492void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1493{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001495
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 ValidateNumInputs(workloadInfo, descriptorName, 1);
1497 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1498
1499 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1500 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1501
1502 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1503 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001504
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001505 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001506 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001507 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001508 DataType::Float16,
1509 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001510 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001511 DataType::QAsymmU8,
1512 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001513 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001514
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001515 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1516 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001517
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 // ResizeBilinear only changes width and height: batch and channel count must match.
1519 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1520 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001521 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001522 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001523 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001524 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1525 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001526 }
1527
Teresa Charlin970f43b2019-07-01 13:51:07 +01001528 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001529 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1530 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001531 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001532 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001533 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001534 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1535 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001536 }
1537}
1538
1539void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1540{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001541 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001543 ValidateNumInputs(workloadInfo, descriptorName, 1);
1544 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1545
1546 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1547 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1548
1549 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1550 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001551
1552 std::vector<DataType> supportedTypes =
1553 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001554 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001555 DataType::Float16,
1556 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001557 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001558 DataType::QAsymmU8,
1559 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001560 };
1561
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001562 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1563 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001564
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001565 // Resize only changes width and height: batch and channel count must match.
1566 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1567 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001568 if (inputBatchSize != outputBatchSize)
1569 {
1570 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001571 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1572 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001573 }
1574
1575 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001576 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1577 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001578 if (inputChannelCount != outputChannelCount)
1579 {
1580 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001581 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1582 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001583 }
1584}
1585
1586void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1587{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001588 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001589
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001590 ValidateNumInputs(workloadInfo, descriptorName, 1);
1591 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1592
1593 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1594 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1595
1596 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1597 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1598
1599 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1600
telsoa014fcda012018-03-09 14:13:49 +00001601 if (m_Parameters.m_Min > m_Parameters.m_Max)
1602 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001603 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001604 }
telsoa014fcda012018-03-09 14:13:49 +00001605}
1606
Kevin Mayce5045a2019-10-02 14:07:47 +01001607void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1608{
1609 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1610
1611 ValidateNumInputs(workloadInfo, descriptorName, 1);
1612 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1613
1614 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1615 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1616
1617 if (inputTensorInfo.GetNumDimensions() > 4)
1618 {
1619 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1620 }
1621
1622 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1623
1624 // Check the supported data types
1625 std::vector<DataType> supportedTypes =
1626 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001627 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001628 DataType::Float32,
1629 DataType::Float16
1630 };
1631
1632 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001633 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001634}
1635
telsoa014fcda012018-03-09 14:13:49 +00001636void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1637{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001638 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001639
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001640 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001641 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1642
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001643 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1644 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1645
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001646 if (inputTensorInfo.GetNumDimensions() > 4)
1647 {
1648 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1649 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001650
1651 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001652
1653 // Check the supported data types
1654 std::vector<DataType> supportedTypes =
1655 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001656 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001657 DataType::Float32,
1658 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001659 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001660 DataType::QAsymmU8,
1661 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001662 };
1663
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001664 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001665 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1666}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001667
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001668void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1669{
1670 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1671
1672 ValidateNumInputs(workloadInfo, descriptorName, 1);
1673 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1674
1675 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1676 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1677
1678 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1679
1680 std::vector<DataType> supportedTypes =
1681 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001682 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001683 DataType::Float32,
1684 DataType::Float16,
1685 };
1686
1687 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001688 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001689}
1690
1691void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1692{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001693 const std::string descriptorName{"ConstantQueueDescriptor"};
1694
1695 ValidateNumInputs(workloadInfo, descriptorName, 0);
1696 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001697
1698 if (!m_LayerOutput)
1699 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001701 }
1702
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001703 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1704 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001705
1706 // Check the supported data types
1707 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001708 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001709 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001710 DataType::Float32,
1711 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001712 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001713 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001714 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001715 DataType::QSymmS16,
1716 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001717 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001718
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001719 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001720}
1721
1722void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1723{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001724 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001725
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 ValidateNumInputs(workloadInfo, descriptorName, 1);
1727 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1728
1729 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1730 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1731
1732 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001733
1734 // Check the supported data types
1735 std::vector<DataType> supportedTypes =
1736 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001737 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001738 DataType::Float32,
1739 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001740 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001741 DataType::QAsymmU8,
1742 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001743 DataType::Signed32,
1744 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001745 };
1746
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001747 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1748 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001749}
1750
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001751void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1752{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001753 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001754
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001755 ValidateNumInputs(workloadInfo, descriptorName, 1);
1756 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1757
1758 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1759 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1760
1761 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1762 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001763
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001764 if (m_Parameters.m_BlockShape.size() != 2)
1765 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001766 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001767 }
1768
1769 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1770 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001771 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1772 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001773 }
1774
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001775 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001776
1777 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001778 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001779
Matthew Bentham8800c002018-11-19 13:19:28 +00001780 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001781
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001782 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1783 widthPad.first + widthPad.second;
1784 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1785 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001786
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001787 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1788 inputShape[dimensionIndices.GetChannelsIndex()];
1789 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001790
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001791 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001792 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001793 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001794 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001795 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001796 }
1797
1798 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001799 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1801 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001802 }
nikraj01120522a2019-05-31 11:33:07 +01001803
1804 std::vector<DataType> supportedTypes =
1805 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001806 DataType::BFloat16,
1807 DataType::Float16,
1808 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001809 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001810 DataType::QAsymmU8,
1811 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001812 };
1813
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001814 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1815 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001816}
1817
Keith Davisa57eccb2019-06-14 17:33:22 +01001818void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1819{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001820 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001822 ValidateNumInputs(workloadInfo, descriptorName, 1);
1823 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001824
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001825 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1826 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1827
1828 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1829 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001830
1831 std::vector<DataType> supportedTypes =
1832 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001833 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001834 DataType::Float32,
1835 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001836 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001837 DataType::QAsymmU8,
1838 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001839 };
1840
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001841 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1842 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001843
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001844 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1845
1846 if (m_Parameters.m_BlockSize == 0)
1847 {
1848 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1849 }
1850
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001851 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1852 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1853 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1854 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001855
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001856 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001857 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001858 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001859 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1860 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001861 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001862
1863 const TensorShape& outputShape = outputTensorInfo.GetShape();
1864 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1865 {
1866 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1867 "must be divisible by the square of block size." );
1868 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001869}
1870
telsoa014fcda012018-03-09 14:13:49 +00001871void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1872{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001873 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001874
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001875 ValidateNumInputs(workloadInfo, descriptorName, 1);
1876 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1877
1878 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1879 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001880
1881 std::vector<DataType> supportedTypes =
1882 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001883 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001884 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001885 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001886 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001887 };
1888
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001889 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001890
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001891 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001892 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001894 }
1895}
1896
telsoa01c577f2c2018-08-31 09:22:23 +01001897void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1898{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001899 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1900
1901 const std::string descriptorName{"LstmQueueDescriptor"};
1902
1903 // check dimensions of all inputs and outputs
1904 if (workloadInfo.m_InputTensorInfos.size() != 3)
1905 {
1906 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1907 }
1908 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1909 {
1910 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1911 }
1912
1913 std::vector<DataType> supportedTypes =
1914 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001915 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001916 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001917 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001918 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001919 };
1920
Jan Eilers38e05bd2019-06-26 13:10:09 +01001921 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001922 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1923
Jan Eilers38e05bd2019-06-26 13:10:09 +01001924 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001925 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001926 {
1927 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1928 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001929 descriptorName,
1930 "input_0",
1931 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001932 }
1933 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001934 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001935 {
1936 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1937 workloadInfo.m_OutputTensorInfos[i],
1938 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001939 "input_0",
1940 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001941 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001942
janeil0117d8d852019-11-15 15:00:16 +00001943 // Making sure clipping parameters have valid values.
1944 // == 0 means no clipping
1945 // > 0 means clipping
1946 if (m_Parameters.m_ClippingThresCell < 0.0f)
1947 {
1948 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1949 }
1950 if (m_Parameters.m_ClippingThresProj < 0.0f)
1951 {
1952 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1953 }
1954
Jan Eilers38e05bd2019-06-26 13:10:09 +01001955
1956 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001957 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1958 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1959 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1960 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1961 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1962 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1963
Jan Eilers38e05bd2019-06-26 13:10:09 +01001964 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001965 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1966 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001967 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001968 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1969 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001970 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001971 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1972 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001973 // scratchBufferTensor
1974 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001975 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1976 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001977 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001978 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1979 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001980 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001981 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1982 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001983 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001984 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1985 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001986
1987
1988 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1989 if ( m_InputToInputWeights )
1990 {
1991 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1992 (n_cell * n_input), "InputLayerNormWeights");
1993 }
1994
1995 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1996 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1997 (n_cell * n_input), "InputToForgetWeights");
1998
1999 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2000 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2001 (n_cell * n_input), "InputToCellWeights");
2002
2003 if ( m_RecurrentToInputWeights )
2004 {
2005 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2006 (n_cell * n_output), "RecurrentToInputWeights");
2007 }
2008
2009 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2010 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2011 (n_cell * n_output), "RecurrentToForgetWeights");
2012
2013 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2014 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2015 (n_cell * n_output), "RecurrentToCellWeights");
2016
2017 // Make sure the input-gate's parameters are either both present (regular
2018 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2019 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2020 !m_Parameters.m_CifgEnabled) ||
2021 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2022 m_Parameters.m_CifgEnabled));
2023 if (!cifg_weights_all_or_none)
2024 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002025 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2026 "RecurrentToInputWeights must either both be present (regular LSTM) "
2027 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2028 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002029 }
2030
2031 if ( m_CellToInputWeights )
2032 {
2033 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2034 n_cell, "CellToInputWeights");
2035 }
2036 if ( m_CellToForgetWeights )
2037 {
2038 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2039 n_cell, "CellToForgetWeights");
2040 }
2041 if ( m_CellToOutputWeights )
2042 {
2043 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2044 n_cell, "CellToOutputWeights");
2045 }
2046
2047 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2048 bool peephole_weights_all_or_none =
2049 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2050 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2051 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2052 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2053 if (!peephole_weights_all_or_none)
2054 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002055 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002056 }
2057
2058 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2059 if (m_Parameters.m_CifgEnabled)
2060 {
2061 if (m_InputGateBias)
2062 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002063 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002064 }
2065 }
2066 else
2067 {
2068 if (!m_InputGateBias)
2069 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002070 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2071 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002072 }
2073 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2074 n_cell, "InputGateBias");
2075 }
2076
2077 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2078 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2079
2080 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2081 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2082
2083 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2084 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2085
2086 if (m_ProjectionWeights)
2087 {
2088 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2089 (n_cell * n_output), "ProjectionWeights");
2090 }
2091 if (m_ProjectionBias)
2092 {
2093 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2094 }
2095
2096 // Making sure the projection tensors are consistent:
2097 // 1) If projection weight is not present, then projection bias should not be
2098 // present.
2099 // 2) If projection weight is present, then projection bias is optional.
2100 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2101 !m_Parameters.m_ProjectionEnabled)
2102 || (m_ProjectionWeights && !m_ProjectionBias &&
2103 m_Parameters.m_ProjectionEnabled)
2104 || (m_ProjectionWeights && m_ProjectionBias &&
2105 m_Parameters.m_ProjectionEnabled));
2106 if (!projecton_tensors_consistent)
2107 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002108 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002109 }
2110
2111 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2112 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2113 // either all have values or none of them have values. Layer normalization is used when the values of all the
2114 // layer normalization weights are present
2115 if (m_InputLayerNormWeights)
2116 {
2117 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2118 }
2119 if (m_ForgetLayerNormWeights)
2120 {
2121 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2122 }
2123 if (m_CellLayerNormWeights)
2124 {
2125 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2126 }
2127 if (m_OutputLayerNormWeights)
2128 {
2129 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2130 }
2131
Jan Eilers38e05bd2019-06-26 13:10:09 +01002132 if (m_Parameters.m_LayerNormEnabled)
2133 {
2134 if (!m_Parameters.m_CifgEnabled)
2135 {
2136 if (!m_InputLayerNormWeights)
2137 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002138 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2139 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002140 }
2141 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2142 1, n_cell, "InputLayerNormWeights");
2143 }
2144 else if (m_InputLayerNormWeights)
2145 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2147 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002148 }
2149
2150 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2151 "ForgetLayerNormWeights");
2152 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2153
2154 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2155 "OutputLayerNormWeights");
2156 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2157
2158 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2159 "CellLayerNormWeights");
2160 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2161 }
2162 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2163 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002164 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2165 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002166 }
telsoa01c577f2c2018-08-31 09:22:23 +01002167}
2168
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002169void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2170{
2171 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2172
2173 ValidateNumInputs(workloadInfo, descriptorName, 1);
2174 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2175
2176 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2177 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2178
2179 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2180 {
2181 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2182 }
2183
2184 if (outputTensorInfo.GetDataType() != DataType::Float32)
2185 {
2186 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2187 }
2188
2189 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2190}
2191
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002192void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2193{
2194 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2195
2196 ValidateNumInputs(workloadInfo, descriptorName, 1);
2197 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2198
2199 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2200 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2201
2202 if (inputTensorInfo.GetDataType() != DataType::Float32)
2203 {
2204 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2205 }
2206
2207 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2208 {
2209 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2210 }
2211
2212 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2213}
2214
telsoa01c577f2c2018-08-31 09:22:23 +01002215void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2216{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002217 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002219 ValidateNumInputs(workloadInfo, descriptorName, 1);
2220 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2221
2222 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2223 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2224
2225 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002226 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002227 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002228 }
2229
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002230 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002231 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002232 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002233 }
2234
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002235 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002236}
2237
2238void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2239{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002240 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002241
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002242 ValidateNumInputs(workloadInfo, descriptorName, 1);
2243 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2244
2245 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2246 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2247
2248 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002249 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002250 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002251 }
2252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002253 if (outputTensorInfo.GetDataType() != DataType::Float32)
2254 {
2255 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2256 }
2257
2258 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002259}
2260
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002261void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2262{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002263 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002265 ValidateNumInputs(workloadInfo, descriptorName, 2);
2266 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2267
2268 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2269 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2270 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2271
2272 std::vector<DataType> supportedTypes =
2273 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002274 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002275 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002276 DataType::Float32,
2277 DataType::QAsymmS8,
2278 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002279 DataType::QSymmS16,
2280 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002281 };
2282
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002283 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2284 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2285 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002287 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2288 inputTensorInfo1,
2289 outputTensorInfo,
2290 descriptorName,
2291 "input_0",
2292 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002293}
2294
David Beckc2044fe2018-09-05 15:00:38 +01002295void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2296{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002297 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002299 ValidateNumInputs(workloadInfo, descriptorName, 2);
2300 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2301
2302 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2303 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2304 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2305
2306 std::vector<DataType> supportedTypes =
2307 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002308 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002309 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002310 DataType::Float32,
2311 DataType::QAsymmS8,
2312 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002313 DataType::QSymmS16,
2314 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002315 };
2316
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002317 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2318 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2319 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002321 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2322 inputTensorInfo1,
2323 outputTensorInfo,
2324 descriptorName,
2325 "input_0",
2326 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002327}
2328
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002329void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2330{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002331 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002332
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 ValidateNumInputs(workloadInfo, descriptorName, 2);
2334 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2335
2336 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2337 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2338 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2339
2340 std::vector<DataType> supportedTypes =
2341 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002342 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002343 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002344 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002345 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002346 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002347 DataType::QSymmS16,
2348 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002349 };
2350
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002351 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2352 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2353 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002355 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2356 inputTensorInfo1,
2357 outputTensorInfo,
2358 descriptorName,
2359 "input_0",
2360 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002361}
2362
narpra01a6bf9122018-09-10 09:50:09 +01002363void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2364{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002365 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002366
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002367 ValidateNumInputs(workloadInfo, descriptorName, 1);
2368 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2369
2370 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2371 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002372
2373 std::vector<DataType> supportedTypes =
2374 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002375 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002376 DataType::Float32,
2377 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002378 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002379 DataType::QAsymmU8,
2380 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002381 };
narpra01eb061912018-09-10 17:35:27 +01002382
James Conroy4d1ff582019-06-10 17:06:39 +01002383 // First check if input tensor data type is supported, then
2384 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002385 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2386 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002387
narpra0132b90462018-09-13 11:07:48 +01002388 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002389 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002390 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002391 }
narpra0132b90462018-09-13 11:07:48 +01002392 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002393 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002394 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002395 }
2396 else
2397 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002398 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002399 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002400 ValidateTensorNumDimensions(outputTensorInfo,
2401 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002402 outputDim > 0 ? outputDim : 1,
2403 "output");
2404 }
narpra01a6bf9122018-09-10 09:50:09 +01002405}
2406
jimfly012c9322a2018-09-19 10:59:49 +01002407void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2408{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002409 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002410
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002411 ValidateNumInputs(workloadInfo, descriptorName, 1);
2412 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2413
2414 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2415 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002416
jimfly012c9322a2018-09-19 10:59:49 +01002417 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002418 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2419
jimfly012c9322a2018-09-19 10:59:49 +01002420 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002421 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2422 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2423 "as there are dimensions in the input tensor that is " +
2424 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2425 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002426 }
2427}
2428
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002429void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2430{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002432
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002433 ValidateNumInputs(workloadInfo, descriptorName, 1);
2434 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002435
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002436 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2437 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2438
Sadik Armagan2208b602019-07-31 16:36:27 +01002439 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002440 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002441 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002442 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002443 DataType::Float16,
2444 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002445 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002446 DataType::QAsymmU8,
2447 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002448 };
2449
2450 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002451
Keith Davis0c2eeac2020-02-11 16:51:50 +00002452 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002453 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002454 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002455 }
2456}
2457
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002458void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2459{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002460 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002461
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002462 ValidateNumInputs(workloadInfo, descriptorName, 1);
2463 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002464
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002465 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2466 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002467
2468 std::vector<DataType> supportedTypes =
2469 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002470 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002471 DataType::Float32,
2472 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002473 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002474 DataType::QAsymmU8,
2475 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002476 };
2477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002478 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2479 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002480}
2481
Conor Kennedy430b5d82018-11-14 15:28:28 +00002482void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2483{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002484 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002485
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002486 ValidateNumInputs(workloadInfo, descriptorName, 1);
2487 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2488
2489 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2490 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002491
2492 std::vector<DataType> supportedTypes =
2493 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002494 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002495 DataType::Float16,
2496 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002497 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002498 DataType::QAsymmU8,
2499 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002500 };
2501
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002502 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2503 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002504
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002505 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002506
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002507 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002508 if (rank > 4)
2509 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002510 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002511 }
2512
Conor Kennedy430b5d82018-11-14 15:28:28 +00002513 // Begin, End & Stride length must be of rank(input0)
2514 if (m_Parameters.m_Begin.size() != rank)
2515 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002517 }
2518
2519 if (m_Parameters.m_End.size() != rank)
2520 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002521 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002522 }
2523
2524 if (m_Parameters.m_Stride.size() != rank)
2525 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002526 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002527 }
2528
2529 // Stride entries must be non-zero
2530 for (auto& stride : m_Parameters.m_Stride)
2531 {
2532 if (stride == 0)
2533 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002534 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002535 }
2536 }
2537}
2538
kevmay0190539692018-11-29 08:40:19 +00002539void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2540{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002541 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002543 ValidateNumInputs(workloadInfo, descriptorName, 2);
2544 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2545
2546 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2547 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2548 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2549
2550 std::vector<DataType> supportedTypes =
2551 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002552 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002553 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002554 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002555 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002556 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002557 DataType::QSymmS16,
2558 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002559 };
2560
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002561 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2562 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2563 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002564
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002565 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2566 inputTensorInfo1,
2567 outputTensorInfo,
2568 descriptorName,
2569 "input_0",
2570 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002571}
2572
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002573void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2574{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002575 const std::string descriptorName{"DebugQueueDescriptor"};
2576
2577 ValidateNumInputs(workloadInfo, descriptorName, 1);
2578 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002579}
2580
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002581void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2582{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002583 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002584
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002585 ValidateNumInputs(workloadInfo, descriptorName, 2);
2586 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002587
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002588 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2589 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2590 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2591
2592 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2593 inputTensorInfo1,
2594 outputTensorInfo,
2595 descriptorName,
2596 "input_0",
2597 "input_1");
2598
2599 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002600 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002601 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002602 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002603}
2604
FrancisMurtagh878f0232018-12-19 10:56:15 +00002605void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2606{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002607 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002609 ValidateNumInputs(workloadInfo, descriptorName, 2);
2610 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002611
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002612 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2613 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2614 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2615
2616 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2617 inputTensorInfo1,
2618 outputTensorInfo,
2619 descriptorName,
2620 "input_0",
2621 "input_1");
2622
2623 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002624 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002625 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002626 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002627}
2628
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002629void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2630{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002631 const std::string descriptorName{"RsqrtQueueDescriptor"};
2632
2633 ValidateNumInputs(workloadInfo, descriptorName, 1);
2634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2635
2636 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2638
2639 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002640
2641 std::vector<DataType> supportedTypes =
2642 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002643 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002644 DataType::Float16,
2645 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002646 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002647 DataType::QAsymmU8,
2648 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002649 };
2650
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002651 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2652 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002653}
2654
narpra01b89b05f2019-01-16 09:53:09 +00002655void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2656{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002657 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002658
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002659 ValidateNumInputs(workloadInfo, descriptorName, 2);
2660 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002661
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002662 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2663 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002664 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002665 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002666 }
2667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002668 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2669 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2670
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002671 std::vector<DataType> supportedTypes =
2672 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002673 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002674 DataType::Float16,
2675 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002676 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002677 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002678 DataType::QSymmS16,
2679 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002680 };
2681
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002682 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002683
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002684 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002685
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002686 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2687 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002688}
2689
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002690void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2691{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002692 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2693
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002694 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002695
2696 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2697 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002698 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002699 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2700 }
2701
2702 if (m_Anchors == nullptr)
2703 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002704 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002705 }
2706
2707 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002708 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2709 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2710
2711 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002712 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002713 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2714 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002715
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002716 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2717 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2718 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002719
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002720 const std::vector<DataType> supportedInputTypes =
2721 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002722 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002723 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002724 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002725 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002726 DataType::QAsymmU8,
2727 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002728 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002729
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002730 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2731 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2732 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2733
2734 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2735 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2736 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2737 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2738
2739 // NOTE: Output is always Float32 regardless of input type
2740 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2741 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2742 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2743 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002744
2745 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2746 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002747 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002748 "must be positive and less than or equal to 1.");
2749 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002750
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002751 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2752 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002753 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002754 "should be equal to number of classes + 1.");
2755 }
2756}
2757
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002758void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2759{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002760 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002761
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002762 ValidateNumInputs(workloadInfo, descriptorName, 1);
2763 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2764
2765 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2766 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2767
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002768 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002769 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002770 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002771 }
2772
Sadik Armagan2208b602019-07-31 16:36:27 +01002773 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002774 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002775 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002776 DataType::Float32,
2777 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002778 };
2779
2780 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002781}
2782
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002783void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2784{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002785 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002786
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002787 ValidateNumInputs(workloadInfo, descriptorName, 2);
2788 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002790 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2791 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2792 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002794 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2795 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2796
2797 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2798 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002799}
2800
Sadik Armaganeff363d2019-04-05 15:25:46 +01002801void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2802{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002803 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002804
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002805 ValidateNumInputs(workloadInfo, descriptorName, 2);
2806 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2807
2808 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2809 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2810
2811 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2812 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2813
2814 std::vector<DataType> supportedTypes =
2815 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002816 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002817 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002818 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002819 DataType::QAsymmU8,
2820 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002821 };
2822
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002823 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2824 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002825
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002826 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2827 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002828
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002829 ValidateTensorShapesMatch(inputTensorInfo0,
2830 outputTensorInfo0,
2831 descriptorName,
2832 "input_0",
2833 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002834
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002835 ValidateTensorShapesMatch(inputTensorInfo0,
2836 outputTensorInfo1,
2837 descriptorName,
2838 "input_0",
2839 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002840}
2841
Derek Lamberti901ea112019-12-10 22:07:09 +00002842void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002843{
2844 // This is internally generated so it should not need validation.
2845}
2846
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002847void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2848{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002849 const std::string& descriptorName{"PreluQueueDescriptor"};
2850
2851 ValidateNumInputs(workloadInfo, descriptorName, 2);
2852 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2853
2854 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2855 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2856 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002857
2858 std::vector<DataType> supportedTypes
2859 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002860 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002861 DataType::Float16,
2862 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002863 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002864 DataType::QAsymmU8,
2865 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002866 };
2867
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002868 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2869 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002870
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002871 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002872
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002873 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2874 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002875
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002876 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2877 alphaTensorInfo,
2878 outputTensorInfo,
2879 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002880 "input",
2881 "alpha");
2882}
2883
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002884void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2885{
2886 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2887
2888 ValidateNumInputs(workloadInfo, descriptorName, 1);
2889 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2890
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002891 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2892 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2893
2894 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2895 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002896
2897 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002898
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002899 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2900 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002901
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002902 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2903
2904 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002905 if (m_Parameters.m_BiasEnabled)
2906 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002907 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002908
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002909 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2910 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002911
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002912 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002913 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002914 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002915
2916 ValidatePerAxisQuantization(inputTensorInfo,
2917 outputTensorInfo,
2918 weightTensorInfo,
2919 optionalBiasTensorInfo,
2920 descriptorName);
2921
2922 std::vector<DataType> supportedTypes =
2923 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002924 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002925 DataType::Float32,
2926 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002927 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002928 DataType::QAsymmU8,
2929 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002930 };
2931
2932 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2933 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002934}
2935
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002936void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2937{
2938 const std::string descriptorName{"TransposeQueueDescriptor"};
2939
2940 ValidateNumInputs(workloadInfo, descriptorName, 1);
2941 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2942
2943 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2944
2945 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2946 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2947
2948 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2949 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2950
2951 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2952 {
2953 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2954 {
2955 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2956 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2957 "must match dst dimension " + to_string(i) +
2958 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2959 }
2960 }
2961
2962 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2963}
2964
James Conroy4f1f8992020-04-29 20:01:10 +01002965void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2966{
2967 const std::string descriptorName{"QLstmQueueDescriptor"};
2968
2969 // Validate number of inputs/outputs
2970 ValidateNumInputs(workloadInfo, descriptorName, 3);
2971 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2972
2973 // Input/output tensor info
2974 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2975 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2976 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2977
2978 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2979 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2980 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2981
2982 // Supported types for various tensors in QLSTM
2983 std::vector<DataType> inputOutputSupportedTypes =
2984 {
2985 DataType::QAsymmS8
2986 };
2987
2988 std::vector<DataType> cellStateSupportedTypes =
2989 {
2990 DataType::QSymmS16
2991 };
2992
2993 std::vector<DataType> weightsSupportedTypes =
2994 {
2995 DataType::QSymmS8
2996 };
2997
2998 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2999 {
3000 DataType::QSymmS16
3001 };
3002
3003 std::vector<DataType> biasSupportedTypes =
3004 {
3005 DataType::Signed32
3006 };
3007
3008 // Validate types of input/output tensors
3009 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3010 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3011 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3012
3013 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3014 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3015 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3016
3017 // Validate matching types of input/output tensors
3018 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3019 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3020 "outputStateIn", "outputStateOut");
3021 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3022
3023 // Infer number of batches, number of units, input size and output size from tensor dimensions
3024 const uint32_t numBatches = inputInfo.GetShape()[0];
3025 const uint32_t inputSize = inputInfo.GetShape()[1];
3026 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3027 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3028
3029 // Validate number of dimensions and number of elements for input/output tensors
3030 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3031 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3032 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3033
3034 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3035 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3036 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3037
3038 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3039 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3040 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3041 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3042
3043 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3044 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3045 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3046
3047 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3048 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3049 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3050
3051 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3052 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3053 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3054 " RecurrentToForgetWeights");
3055
3056 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3057 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3058 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3059
3060 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3061 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3062 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3063
3064 // Validate data types for MANDATORY weights tensors (all should match each other)
3065 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3066
3067 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3068 "inputToForgetWeights", "inputToCellWeights");
3069 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3070 "inputToForgetWeights", "inputToOutputWeights");
3071
3072 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3073 "inputToForgetWeights", "recurrentToForgeteights");
3074 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3075 "inputToForgetWeights", "recurrentToCellWeights");
3076 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3077 "inputToForgetWeights", "recurrentToOutputWeights");
3078
3079 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3080 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3081 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3082 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3083
3084 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3085 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3086 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3087
3088 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3089 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3090 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3091
3092 // Validate data types for MANDATORY bias tensors
3093 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3094
3095 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3096 "forgetGateBias", "cellBias");
3097 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3098 "forgetGateBias", "outputGateBias");
3099
3100 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3101 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3102 !m_Parameters.m_CifgEnabled) ||
3103 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3104 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3105
3106 if (!allCifgParamsPresentOrNot)
3107 {
3108 throw InvalidArgumentException(descriptorName +
3109 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3110 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3111 "set appropriately.");
3112 }
3113
3114 if (!m_Parameters.m_CifgEnabled)
3115 {
3116 // Validate number of dimensions and number of elements
3117 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3118 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3119
3120 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3121 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3122 " RecurrentToInputWeights");
3123
3124 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3125 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3126
3127 // Validate data types
3128 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3129 "inputToForgetWeights", "inputToInputWeights");
3130 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3131 "inputToForgetWeights", "recurrentToInputWeights");
3132 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3133 "forgetGateBias", "inputGateBias");
3134 }
3135
3136 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3137 bool allPeepholeWeightsPresentOrNot =
3138 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3139 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3140 || (!m_CellToInputWeights && !m_CellToForgetWeights
3141 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3142
3143 if (!allPeepholeWeightsPresentOrNot)
3144 {
3145 throw InvalidArgumentException(descriptorName +
3146 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3147 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3148 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3149 "appropriately.");
3150 }
3151
3152 if (m_Parameters.m_PeepholeEnabled)
3153 {
3154 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3155 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3156 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3157
3158 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3159 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3160 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3161 "cellToForgetWeight", "cellToOutputWeights");
3162
3163 if (!m_Parameters.m_CifgEnabled)
3164 {
3165 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3166 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3167 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3168 "cellToForgetWeights", "cellToInputWeights");
3169 }
3170 }
3171
3172 // Validate OPTIONAL params: Layer Norm Weights
3173 bool allLayerNormWeightsPresentOrNot =
3174 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3175 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3176 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3177 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3178
3179 if (!allLayerNormWeightsPresentOrNot)
3180 {
3181 throw InvalidArgumentException(descriptorName +
3182 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3183 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3184 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3185 "only be present when Layer Norm is enabled and CIFG is disabled. "
3186 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3187 }
3188
3189 if (m_Parameters.m_LayerNormEnabled)
3190 {
3191 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3192 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3193 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3194
3195 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3196 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3197 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3198 "forgetLayerNormWeights", "cellLayerNormWeights");
3199
3200 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3201 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3202 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3203 "forgetLayerNormWeights", "outputLayerNormWeights");
3204
3205 if (!m_Parameters.m_CifgEnabled)
3206 {
3207 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3208 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3209 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3210 "forgetLayerNormWeights", "inputLayerNormWeights");
3211 }
3212 }
3213
3214 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3215 bool correctProjectionTensorsPresent =
3216 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3217 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3218 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3219
3220 if (!correctProjectionTensorsPresent)
3221 {
3222 throw InvalidArgumentException(descriptorName +
3223 ": If projection is enabled, ProjectionWeights should be present and "
3224 "ProjectionBias is optional. If projection is disabled, neither "
3225 "ProjectionWeights nor ProjectionBias should be present.");
3226 }
3227
3228 if (m_Parameters.m_ProjectionEnabled)
3229 {
3230 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3231 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3232 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3233
3234 if (m_ProjectionBias)
3235 {
3236 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003237 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003238 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3239 }
3240
3241 }
3242 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3243 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3244 throw InvalidArgumentException(descriptorName +
3245 ": If projection is disabled, output quantization info (scale, offset) "
3246 "should match HiddenStateScale and HiddenStateZeroPoint.");
3247 }
3248
3249}
3250
James Conroy9c3cae82019-08-01 16:01:48 +01003251void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3252{
3253 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3254
3255 // Validate number of inputs/outputs
3256 ValidateNumInputs(workloadInfo, descriptorName, 3);
3257 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3258
3259 // Input/output tensor infos
3260 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3261 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3262 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3263
3264 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3265 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3266
3267 std::vector<DataType> inputOutputSupportedTypes =
3268 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003269 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003270 };
3271
3272 std::vector<DataType> cellStateSupportedTypes =
3273 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003274 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003275 };
3276
3277 std::vector<DataType> weightsSupportedTypes =
3278 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003279 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003280 };
3281
3282 std::vector<DataType> biasSupportedTypes =
3283 {
3284 DataType::Signed32
3285 };
3286
3287 // Validate types of input/output tensors
3288 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3289 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3290 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3291
3292 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3293 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3294
3295 // Validate matching types of input/output tensors
3296 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3297 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3298 "outputStateIn", "outputStateOut");
3299 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3300
3301 // Validate matching quantization info for input/output tensors
3302 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3303 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3304 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003305
James Conroy9c3cae82019-08-01 16:01:48 +01003306 // Infer number of batches, input size and output size from tensor dimensions
3307 const uint32_t numBatches = inputInfo.GetShape()[0];
3308 const uint32_t inputSize = inputInfo.GetShape()[1];
3309 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3310
3311 // Validate number of dimensions and number of elements for input/output tensors
3312 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3313 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3314 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3315 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3316 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3317
3318 // Validate number of dimensions and number of elements for weights tensors
3319 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3320 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3321 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3322
3323 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3324 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3325 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3326
3327 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3328 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3329 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3330
3331 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3332 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3333 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3334
3335 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3336 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3337 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3338
3339 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3340 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3341 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3342 " RecurrentToForgetWeights");
3343
3344 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3345 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3346 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3347
3348 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3349 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3350 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3351
3352 // Validate data types for weights tensors (all should match each other)
3353 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3354
3355 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3356 "inputToInputWeights", "inputToForgetWeights");
3357 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3358 "inputToInputWeights", "inputToCellWeights");
3359 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3360 "inputToInputWeights", "inputToOutputWeights");
3361
3362 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3363 "inputToInputWeights", "recurrentToInputWeights");
3364 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3365 "inputToInputWeights", "recurrentToForgeteights");
3366 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3367 "inputToInputWeights", "recurrentToCellWeights");
3368 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3369 "inputToInputWeights", "recurrentToOutputWeights");
3370
3371 // Validate matching quantization info for weight tensors (all should match each other)
3372 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3373 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3374 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3375 descriptorName, "inputToInputWeights", "inputToCellWeights");
3376 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3377 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3378
3379 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3380 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3381 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3382 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3383 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3384 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3385 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3386 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3387
3388 // Validate number of dimensions and number of elements in bias tensors
3389 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3390 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3391 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3392
3393 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3394 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3395 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3396
3397 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3398 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3399 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3400
3401 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3402 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3403 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3404
3405 // Validate data types for bias tensors (all should match each other)
3406 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3407
3408 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3409 "inputGateBias", "forgetGateBias");
3410 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3411 "inputGateBias", "cellBias");
3412 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3413 "inputGateBias", "outputGateBias");
3414
3415 // Validate bias tensor quantization info
3416 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3417 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3418 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3419 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3420}
3421
Kevin May868eb142019-09-04 17:29:31 +01003422void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3423{
3424 const std::string descriptorName{"AbsQueueDescriptor"};
3425
3426 ValidateNumInputs(workloadInfo, descriptorName, 1);
3427 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3428
3429 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3430 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3431
3432 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3433
3434 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003435 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003436 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003437 DataType::Float16,
3438 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003439 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003440 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003441 DataType::QSymmS16,
3442 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003443 };
Kevin May868eb142019-09-04 17:29:31 +01003444
3445 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3446 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3447}
3448
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003449void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3450{
3451 const std::string descriptorName{"SliceQueueDescriptor"};
3452
3453 ValidateNumInputs(workloadInfo, descriptorName, 1);
3454 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3455
3456 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3457 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3458
3459 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3460
3461 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3462 if (rank > 4)
3463 {
3464 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3465 }
3466
3467 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3468
3469 // Check if m_Begin and m_Size have the expected length
3470 if (m_Parameters.m_Begin.size() != rank)
3471 {
3472 throw InvalidArgumentException(descriptorName +
3473 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3474 }
3475 if (m_Parameters.m_Size.size() != rank)
3476 {
3477 throw InvalidArgumentException(descriptorName +
3478 ": Length of size descriptor must equal rank " + std::to_string(rank));
3479 }
3480
3481 // Check if the shape of the output tensor matches m_Size
3482 const TensorShape& outputShape = outputTensorInfo.GetShape();
3483 for (unsigned int i = 0u; i < rank; ++i)
3484 {
3485 if (m_Parameters.m_Size[i] != outputShape[i])
3486 {
3487 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3488 }
3489 }
3490
3491 // Check if the sum of begin offset and size in a given dimension
3492 // does not exceed the size of corresponding input
3493 const TensorShape& inputShape = inputTensorInfo.GetShape();
3494 for(unsigned int i = 0u; i < rank; ++i)
3495 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003496 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003497 {
3498 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3499 std::to_string(i) + " exceeds input size.");
3500 }
3501 }
3502}
3503
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003504void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3505{
3506 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3507
3508 ValidateNumInputs(workloadInfo, descriptorName, 1);
3509 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3510
3511 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3512 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3513
3514 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3515 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3516
3517 std::vector<DataType> supportedTypes =
3518 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003519 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003520 DataType::Float32,
3521 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003522 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003523 DataType::QAsymmU8,
3524 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003525 };
3526
3527 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3528 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3529
3530 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3531
3532 if (m_Parameters.m_BlockSize == 0)
3533 {
3534 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3535 }
3536
3537 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3538 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3539 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3540 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3541
3542 const TensorShape& outputShape = outputInfo.GetShape();
3543 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3544 {
3545 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3546 "must be divisible by block size.");
3547 }
3548
3549 const TensorShape& inputShape = inputInfo.GetShape();
3550 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3551 {
3552 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3553 "must be divisible by the square of block size." );
3554 }
3555}
3556
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003557void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3558{
3559 const std::string descriptorName{"ComparisonQueueDescriptor"};
3560
3561 ValidateNumInputs(workloadInfo, descriptorName, 2);
3562 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3563
3564 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3565 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3566 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3567
3568 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3569 inputTensorInfo1,
3570 outputTensorInfo,
3571 descriptorName,
3572 "input_0",
3573 "input_1");
3574
3575 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3576 {
3577 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3578 }
3579}
3580
josh minor4a3c6102020-01-06 16:40:46 -06003581void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3582{
3583 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3584
3585 ValidateNumInputs(workloadInfo, descriptorName, 1);
3586 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3587
3588 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3589 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3590
3591 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3592
3593 std::vector<DataType> supportedTypes =
3594 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003595 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003596 DataType::Float16,
3597 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003598 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003599 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003600 DataType::QSymmS16,
3601 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003602 };
3603
James Conroyaba90cd2020-11-06 16:28:18 +00003604 std::vector<DataType> logicalSupportedTypes =
3605 {
3606 DataType::Boolean
3607 };
3608
3609 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3610 {
3611 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3612 }
3613 else
3614 {
3615 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3616 }
3617
3618
josh minor4a3c6102020-01-06 16:40:46 -06003619 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3620}
3621
Finn Williams2605b232020-06-10 15:53:46 +01003622void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3623{
3624 const std::string descriptorName{"RankQueueDescriptor"};
3625
3626 ValidateNumInputs(workloadInfo, descriptorName, 1);
3627 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3628
3629 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3630 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3631
3632 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3633 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3634
3635 std::vector<DataType> supportedTypes =
3636 {
3637 DataType::BFloat16,
3638 DataType::Float16,
3639 DataType::Float32,
3640 DataType::QAsymmS8,
3641 DataType::QAsymmU8,
3642 DataType::QSymmS8,
3643 DataType::QSymmS16,
3644 DataType::Signed32
3645 };
3646
3647 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3648 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3649}
3650
James Conroyaba90cd2020-11-06 16:28:18 +00003651void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3652{
3653 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3654
3655 ValidateNumInputs(workloadInfo, descriptorName, 2);
3656 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3657
3658 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3659 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3660 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3661
3662 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3663 inputTensorInfo1,
3664 outputTensorInfo,
3665 descriptorName,
3666 "input_0",
3667 "input_1");
3668
3669 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3670 {
3671 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3672 }
3673
3674 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3675 {
3676 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3677 }
3678
3679 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3680 {
3681 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3682 }
3683}
3684
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003685void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3686{
3687 const std::string descriptorName{"ReduceQueueDescriptor"};
3688
3689 ValidateNumInputs(workloadInfo, descriptorName, 1);
3690 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3691
3692 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3693 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3694
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003695 std::vector<DataType> supportedTypes =
3696 {
3697 DataType::BFloat16,
3698 DataType::Float16,
3699 DataType::Float32,
3700 DataType::QAsymmS8,
3701 DataType::QAsymmU8,
3702 DataType::QSymmS16,
3703 DataType::Signed32
3704 };
3705
3706 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3707 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3708}
3709
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003710} // namespace armnn