blob: ff97fc7f417401ceaa3bc7fe082d4aad6d87a358 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000011
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000014#include <string>
15#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000016
17#include <boost/format.hpp>
18
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000031 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000034 case DataType::QAsymmS8:
35 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000037 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000038 case DataType::QSymmS8:
39 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// Stream-based replacement for std::to_string, which the Android NDK does not provide.
// Works for any type with an operator<< overload; the text is whatever that operator writes.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream stream;
    stream << value;
    const std::string text = stream.str();
    return text;
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100108void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000109 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& tensorName)
112{
113 if (tensor.GetNumDimensions() != numDimensions)
114 {
115 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
116 to_string(tensor.GetNumDimensions()) + " dimensions for " +
117 tensorName + " tensor.");
118 }
119}
120
121//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100122void ValidateTensorNumElements(const TensorInfo& tensor,
123 std::string const& descName,
124 unsigned int numElements,
125 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126{
127 if (tensor.GetNumElements() != numElements)
128 {
129 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100130 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100131 tensorName + " tensor.");
132 }
133}
134
135//---------------------------------------------------------------
136void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 unsigned int numDimension,
138 unsigned int numElements,
139 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100141 const std::string functionName{"ValidateTensorNumDimNumElem"};
142 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
143 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100144}
145
146//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000147void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
148 const std::string& descName, std::string const& tensorName)
149{
150 if (tensor.GetDataType() != dataType)
151 {
152 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
153 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
154 }
155}
156
Derek Lambertid466a542020-01-22 15:37:29 +0000157void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
158{
159 ARMNN_NO_DEPRECATE_WARN_BEGIN
160 if (tensor.GetDataType() != DataType::QSymmS8 &&
161 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
167 ARMNN_NO_DEPRECATE_WARN_END
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100171void ValidateTensorQuantizationSpace(const TensorInfo& first,
172 const TensorInfo& second,
173 const std::string& descName,
174 std::string const& firstName,
175 std::string const& secondName)
176{
177 if (!first.IsQuantized() ||
178 !second.IsQuantized())
179 {
180 // Not a quantized type, ignore the validation
181 return;
182 }
183
184 DataType firstDataType = first.GetDataType();
185 DataType secondDataType = second.GetDataType();
186
187 if (firstDataType != secondDataType)
188 {
189 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
190 " must be of the same quantized type, " +
191 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
192 secondName + " is " + GetDataTypeName(secondDataType));
193 }
194
195 if (!first.IsTypeSpaceMatch(second))
196 {
197 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
198 " must have the same quantization space, " +
199 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
200 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
201 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
202 " and scale " + to_string(second.GetQuantizationScale()));
203 }
204}
205
206//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100207void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
208 const TensorInfo& inputTensorInfo,
209 const TensorInfo& weightsTensorInfo,
210 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000211{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000212 // Helper lambda function to validate a single bias quantization scale value
213 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
214 {
ricbur013f4d7102019-10-31 16:22:18 +0000215 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000216 if (std::abs(biasScale - expectedScale) > tolerance)
217 {
218 // Print the float values with extra precision to see very small differences
219 std::stringstream msg;
220 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
221 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
222 biasScale;
223 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
224 }
225 };
226
telsoa014fcda012018-03-09 14:13:49 +0000227 if (biasTensor.GetQuantizationOffset() != 0)
228 {
229 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
230 to_string(biasTensor.GetQuantizationOffset()));
231 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232
233 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000234 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000235 // Validate per-axis quantization scales
236 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
237 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
238
239 if (weightScales.size() != biasScales.size())
240 {
241 std::stringstream msg;
242 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
243 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100309 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100369 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
390 throw InvalidArgumentException(boost::str(
391 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
392 % descName % tensorName));
393 }
394
395 if (quantizationDim.value() != 0)
396 {
397 throw InvalidArgumentException(boost::str(
398 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
399 "but got: %3%") % descName % tensorName % quantizationDim.value()));
400 }
401}
402
403void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
404 const std::string& descName,
405 const std::string& tensorName)
406{
407 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
408 if (quantizationOffset != 0)
409 {
410 throw InvalidArgumentException(boost::str(
411 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
412 "but got: %3%") % descName % tensorName % quantizationOffset));
413 }
414}
415
416void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
417 const TensorInfo& outputInfo,
418 const TensorInfo& weightInfo,
419 const Optional<TensorInfo>& optionalBiasInfo,
420 const std::string& descName)
421{
422 if (weightInfo.HasPerAxisQuantization())
423 {
424 const DataType inputDataType = inputInfo.GetDataType();
425 const DataType outputDataType = outputInfo.GetDataType();
426
Keith Davis0c2eeac2020-02-11 16:51:50 +0000427 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000428
429 if (!canHavePerAxisQuantization)
430 {
431 throw InvalidArgumentException(boost::str(
432 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
433 "but data type does not support per-axis quantization.") % descName % "weight"));
434 }
435
Derek Lambertid466a542020-01-22 15:37:29 +0000436
437 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000438 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
439 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
440
441 if (optionalBiasInfo.has_value())
442 {
443 const TensorInfo& biasInfo = optionalBiasInfo.value();
444 if (!biasInfo.HasPerAxisQuantization())
445 {
446 throw InvalidArgumentException(boost::str(
447 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
448 "weight tensor.") % descName));
449 }
450
451 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
452 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
453 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
454 }
455 }
456}
457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100458} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000459
460void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
461 unsigned int numExpectedIn, unsigned int numExpectedOut) const
462{
463 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
464 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
465}
466
467//---------------------------------------------------------------
468void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
469{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100470 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000471
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100472 ValidateNumInputs(workloadInfo, descriptorName, 1);
473 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100475 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
476 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
477
478 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
479 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000480
481 if (m_Inputs.size() != m_Outputs.size())
482 {
483 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100484 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
485 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000486 }
487
488 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
489 {
490 if (!m_Inputs[i])
491 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
493 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000494 }
495
496 if (!m_Outputs[i])
497 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
499 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000500 }
501 }
502}
503
Derek Lambertif674aa02019-08-01 15:56:25 +0100504//---------------------------------------------------------------
505void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
506{
507 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
508 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
509
510 if (workloadInfo.m_InputTensorInfos.size() != 1)
511 {
512 throw InvalidArgumentException(boost::str(
513 boost::format("Number of input infos (%1%) is not 1.")
514 % workloadInfo.m_InputTensorInfos.size()));
515
516 }
517
518 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
519 {
520 throw InvalidArgumentException(boost::str(
521 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
522 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
523 }
524
525 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
526 {
527 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
528 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
529 {
530 throw InvalidArgumentException(boost::str(
531 boost::format("Number of elements for tensor input and output %1% does not match")
532 % i ));
533 }
534 }
535
536 if (m_Inputs.size() != 1)
537 {
538 throw InvalidArgumentException(boost::str(
539 boost::format("Number of inputs (%1%) is not 1.")
540 % m_Inputs.size()));
541 }
542
543 if (m_Inputs.size() != m_Outputs.size())
544 {
545 throw InvalidArgumentException(boost::str(
546 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
547 % m_Inputs.size() % m_Outputs.size()));
548 }
549
550 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
551 {
552 if (!m_Inputs[i])
553 {
554 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
555 }
556
557 if (!m_Outputs[i])
558 {
559 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
560 }
561 }
562}
563
564//---------------------------------------------------------------
565void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
566{
567 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
568 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
569
Derek Lambertif674aa02019-08-01 15:56:25 +0100570 if (m_Inputs.size() != 1)
571 {
572 throw InvalidArgumentException(boost::str(
573 boost::format("Number of inputs (%1%) is not 1.")
574 % m_Inputs.size()));
575 }
576
577 if (m_Outputs.size() != 0)
578 {
579 throw InvalidArgumentException(boost::str(
580 boost::format("Number of outputs (%1%) is not 0.")
581 % m_Inputs.size() % m_Outputs.size()));
582 }
583
584 if (!m_Inputs[0])
585 {
586 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
587 }
588}
589
590//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000591void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
592{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100593 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100595 ValidateNumInputs(workloadInfo, descriptorName, 1);
596 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100598 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
599 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100600
601 std::vector<DataType> supportedTypes =
602 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000603 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100604 DataType::Float16,
605 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000606 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000607 DataType::QAsymmU8,
608 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100609 };
610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100611 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
612 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
613 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000614}
615
Nikhil Rajee391d52019-09-05 17:50:44 +0100616void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
617{
618 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
619
620 ValidateNumInputs(workloadInfo, descriptorName, 1);
621 ValidateNumOutputs(workloadInfo, descriptorName, 1);
622
623 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
624 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
625
Inki Daed4619e22020-09-10 15:33:54 +0900626 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
627 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100628 {
Inki Daed4619e22020-09-10 15:33:54 +0900629 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100630 }
631
James Conroyd47a0642019-09-17 14:22:06 +0100632 std::vector<DataType> supportedInputTypes =
633 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000634 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100635 DataType::Float16,
636 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100637 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000638 DataType::QAsymmU8,
639 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900640 DataType::Signed32,
641 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100642 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100643
James Conroyd47a0642019-09-17 14:22:06 +0100644 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100645
646 auto inputShape = inputTensorInfo.GetShape();
647 auto outputShape = outputTensorInfo.GetShape();
648
649 auto inputNumDimensions = inputShape.GetNumDimensions();
650 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
651
652 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
653
654 // 1D input shape results in scalar output shape
655 if (inputShape.GetNumDimensions() == 1)
656 {
657 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
658 {
659 throw InvalidArgumentException(descriptorName + outputShapeError);
660 }
661 }
662 else
663 {
664 for (unsigned int i = 0; i < unsignedAxis; ++i)
665 {
666 if (outputShape[i] != inputShape[i])
667 {
668 throw InvalidArgumentException(descriptorName + outputShapeError);
669 }
670 }
671
672 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
673 {
674 if (outputShape[i - 1] != inputShape[i])
675 {
676 throw InvalidArgumentException(descriptorName + outputShapeError);
677 }
678 }
679 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100680}
681
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100682void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
683{
684 const std::string descriptorName{"SoftmaxQueueDescriptor"};
685
686 ValidateNumInputs(workloadInfo, descriptorName, 1);
687 ValidateNumOutputs(workloadInfo, descriptorName, 1);
688
689 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
690 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
691
692 std::vector<DataType> supportedTypes =
693 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000694 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100695 DataType::Float16,
696 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000697 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000698 DataType::QAsymmU8,
699 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100700 };
701
702 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
703 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
704 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
705}
706
telsoa014fcda012018-03-09 14:13:49 +0000707void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
708{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100709 const std::string descriptorName{"SplitterQueueDescriptor"};
710
711 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000712
Ruomei Yan25339c32019-05-28 16:48:20 +0100713 // Check the supported data types
714 std::vector<DataType> supportedTypes =
715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000716 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100717 DataType::Float32,
718 DataType::Float16,
719 DataType::Boolean,
720 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100721 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000722 DataType::QAsymmU8,
723 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100724 };
725
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100726 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
727 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100728 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100729 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
730 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
731
732 const std::string outputName = "output_" + std::to_string(i);
733 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100734 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100735
telsoa014fcda012018-03-09 14:13:49 +0000736 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
737 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100738 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000739 }
740
741 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
742 {
743 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100744 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000745 "has to match number of workloadInfo.m_OutputTensorInfos. "
746 "Number of windows: " +
747 to_string(m_ViewOrigins.size()) +
748 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
749 }
750
telsoa01c577f2c2018-08-31 09:22:23 +0100751 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000752 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
753 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
754 {
telsoa01c577f2c2018-08-31 09:22:23 +0100755 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000756 ViewOrigin const& e = m_ViewOrigins[w];
757 if (e.m_Origin.size() != inputDims)
758 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100759 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000760 "have the same dimensionality as the input tensor. "
761 "Window origin (index: " +
762 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
763 " dimensions, the input "
764 "tensor has " +
765 to_string(inputDims) + " dimensions.");
766 }
767 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
768 {
769 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
770 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
771 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100772 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000773 "be smaller or equal than the size of the input in that coord.");
774 }
775 }
776 }
777}
778
Jim Flynne242f2d2019-05-22 14:24:13 +0100779void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000780{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100781 const std::string descriptorName{"ConcatQueueDescriptor"};
782
783 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000784
785 if (m_Inputs.size() <= 0)
786 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100787 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000788 }
789 if (m_Outputs.size() <= 0)
790 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100791 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000792 }
793
794 if (workloadInfo.m_InputTensorInfos.size() <= 0)
795 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100796 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000797 }
798 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
799 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100800 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000801 }
802
Nikhil Raj8599a412018-11-19 14:51:07 +0000803 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
804 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100805 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000806 }
807
808 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
809 {
810 return;
811 }
812
telsoa014fcda012018-03-09 14:13:49 +0000813 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
814 {
815 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100816 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000817 "has to match number of workloadInfo.m_InputTensorInfos. "
818 "Number of windows: " +
819 to_string(m_ViewOrigins.size()) +
820 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
821 }
822
telsoa01c577f2c2018-08-31 09:22:23 +0100823 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000824 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
825 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
826 {
telsoa01c577f2c2018-08-31 09:22:23 +0100827 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000828 ViewOrigin const& e = m_ViewOrigins[w];
829 if (e.m_Origin.size() != outputDims)
830 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100831 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000832 "have the same dimensionality as the output tensor. "
833 "Window origin (index: " +
834 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
835 " dimensions, the output "
836 "tensor has " +
837 to_string(outputDims) + " dimensions.");
838 }
telsoa01c577f2c2018-08-31 09:22:23 +0100839 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000840 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
841 {
842 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
843 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
844 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100845 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000846 "be smaller or equal than the size of the output in that coord.");
847 }
848 }
849 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100850
851 // Check the supported data types
852 std::vector<DataType> supportedTypes =
853 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000854 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100855 DataType::Float32,
856 DataType::Float16,
857 DataType::Boolean,
858 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100859 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000860 DataType::QAsymmU8,
861 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100862 };
863
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100864 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
865 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100866 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100867 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
868 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
869
870 const std::string inputName = "input_" + std::to_string(i);
871 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100872 }
telsoa014fcda012018-03-09 14:13:49 +0000873}
874
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100875void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
876{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877 const std::string descriptorName{"StackQueueDescriptor"};
878
879 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100880
881 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
882 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100883 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100884 }
885
886 // All inputs must have the same shape, which is defined in parameters
887 const TensorShape& inputShape = m_Parameters.m_InputShape;
888 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
889 {
890 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
891 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100892 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100893 }
894 }
895
Matthew Jacksondba634f2019-08-15 15:14:18 +0100896 if (inputShape.GetNumDimensions() > 4)
897 {
898 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
899 }
900
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100901 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
902 // since the output tensor has an additional dimension.
903 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
904 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100905 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100906 "than the number of input dimensions.");
907 }
908
909 // Output shape must be as inferred from the input shape
910 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
911 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
912 {
913 if (outputShape[i] != inputShape[i])
914 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100915 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100916 "match shape inferred from input tensor.");
917 }
918 }
919
920 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
921 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100922 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100923 "match shape inferred from input tensor.");
924 }
925
926 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
927 {
928 if (outputShape[i] != inputShape[i-1])
929 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100931 "match shape inferred from input tensor.");
932 }
933 }
934
Matthew Jacksondba634f2019-08-15 15:14:18 +0100935 if (outputShape.GetNumDimensions() > 5)
936 {
937 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
938 }
939
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100940 // Check the supported data types
941 std::vector<DataType> supportedTypes =
942 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000943 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100944 DataType::Float32,
945 DataType::Float16,
946 DataType::Boolean,
947 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100948 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000949 DataType::QAsymmU8,
950 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100951 };
952
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100954
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100955 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100956 {
957 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
958 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 descriptorName,
960 "input_0",
961 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100962 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100963
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100964 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
965 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100966 descriptorName,
967 "input_0",
968 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100969}
970
Ryan OSheaec6c6802020-06-05 17:17:06 +0100971void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
972{
973 const std::string descriptorName{"FillQueueDescriptor"};
974
975 ValidateNumInputs(workloadInfo, descriptorName, 1);
976 ValidateNumOutputs(workloadInfo, descriptorName, 1);
977
978 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
979 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
980
981 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
982
983 std::vector<DataType> supportedTypes =
984 {
985 DataType::BFloat16,
986 DataType::Float32,
987 DataType::Float16,
988 DataType::Signed32
989 };
990
991 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
992}
993
telsoa014fcda012018-03-09 14:13:49 +0000994void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
995{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100996 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000997
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100998 ValidateNumInputs(workloadInfo, descriptorName, 1);
999 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1000
1001 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1002 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1003
1004 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1005
1006 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001007 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001008 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001009 }
1010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001011 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001012
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001013 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1014 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001015
1016 if (m_Parameters.m_BiasEnabled)
1017 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001018 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001019
telsoa01c577f2c2018-08-31 09:22:23 +01001020 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001021 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1022 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001023
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001024 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1025 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001026 }
1027
Francis Murtagh46c09d02019-05-28 08:15:28 +01001028 // Check the supported data types
1029 std::vector<DataType> supportedTypes =
1030 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001031 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001032 DataType::Float32,
1033 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001034 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001035 DataType::QAsymmU8,
1036 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001037 };
1038
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001039 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001040
1041 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1042 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1043 {
1044 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1045 {
1046 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1047 "for BFloat16 input.");
1048 }
1049 }
1050 else
1051 {
1052 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1053 }
telsoa014fcda012018-03-09 14:13:49 +00001054}
1055
telsoa014fcda012018-03-09 14:13:49 +00001056void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1057{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001058 const std::string descriptorName{"NormalizationQueueDescriptor"};
1059
1060 ValidateNumInputs(workloadInfo, descriptorName, 1);
1061 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1062
1063 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1064 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001065
1066 // Check the supported data types
1067 std::vector<DataType> supportedTypes =
1068 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001069 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001070 DataType::Float16,
1071 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001072 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001073 DataType::QAsymmU8,
1074 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001075 };
1076
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001077 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001078
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001080
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001081 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001082}
1083
1084void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1085{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001087
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001088 ValidateNumInputs(workloadInfo, descriptorName, 2);
1089 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1090
1091 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1092 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1093 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1094
1095 std::vector<DataType> supportedTypes =
1096 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001097 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001098 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001099 DataType::Float16,
1100 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001101 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001102 DataType::QSymmS16,
1103 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001104 };
1105
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001106 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1107 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1108 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001109
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001110 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1111 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001112
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001113 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1114 inputTensorInfo1,
1115 outputTensorInfo,
1116 descriptorName,
1117 "input_0",
1118 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001119}
1120
telsoa014fcda012018-03-09 14:13:49 +00001121void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1122{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001125 ValidateNumInputs(workloadInfo, descriptorName, 2);
1126 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1127
1128 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1129 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1130 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1131
1132 std::vector<DataType> supportedTypes =
1133 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001134 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001135 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001136 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001137 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001138 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001139 DataType::QSymmS16,
1140 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001141 };
1142
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001143 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1144 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1145 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001147 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1148 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1151 inputTensorInfo1,
1152 outputTensorInfo,
1153 descriptorName,
1154 "input_0",
1155 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001156}
1157
1158void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1159{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001160 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001161
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001162 ValidateNumInputs(workloadInfo, descriptorName, 1);
1163 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1164
1165 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1166 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001167
1168 std::vector<DataType> supportedTypes =
1169 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001170 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001171 DataType::Float16,
1172 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001173 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001174 DataType::QAsymmU8,
1175 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001176 };
1177
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001178 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1179 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001180
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001181 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001182 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001183
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001184 ValidatePointer(m_Mean, descriptorName, "mean");
1185 ValidatePointer(m_Variance, descriptorName, "variance");
1186 ValidatePointer(m_Beta, descriptorName, "beta");
1187 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001188
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001189 const TensorInfo& mean = m_Mean->GetTensorInfo();
1190 const TensorInfo& variance = m_Variance->GetTensorInfo();
1191 const TensorInfo& beta = m_Beta->GetTensorInfo();
1192 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1195 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1196 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1197 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001198
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001199 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1200 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1201 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001202}
1203
1204void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1205{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001206 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001207
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001208 ValidateNumInputs(workloadInfo, descriptorName, 1);
1209 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001210
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001211 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1212 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001213
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001214 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1215 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001216
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001217 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1220 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001221
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001222 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001223
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001224 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001225 if (m_Parameters.m_BiasEnabled)
1226 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001227 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001228
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001229 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1230 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001231
1232 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1233 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001234 }
1235
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001236 ValidatePerAxisQuantization(inputTensorInfo,
1237 outputTensorInfo,
1238 weightTensorInfo,
1239 optionalBiasTensorInfo,
1240 descriptorName);
1241
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001242 std::vector<DataType> supportedTypes =
1243 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001244 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001245 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001246 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001247 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001248 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001249 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001250 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001251 };
1252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001254
1255 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1256 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1257 {
1258 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1259 {
1260 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1261 "for BFloat16 input.");
1262 }
1263 }
1264 else
1265 {
1266 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1267 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001269
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001270void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1271{
1272 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1273
1274 ValidateNumInputs(workloadInfo, descriptorName, 1);
1275 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1276
1277 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1278 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1279
1280 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1281 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1282
1283 ValidatePointer(m_Weight, descriptorName, "weight");
1284
1285 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1286 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1287
1288 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1289 {
1290 throw InvalidArgumentException(
1291 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1292 "cannot be smaller than 1.") % descriptorName %
1293 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1294 }
1295
1296 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1297
1298 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1299 // inputChannels * channelMultiplier should be equal to outputChannels.
1300 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1301 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1302 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1303 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1304 {
1305 throw InvalidArgumentException(
1306 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1307 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1308 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1309 numWeightInputChannels % numWeightChannelMultiplier));
1310 }
1311
Teresa Charlind8df0262019-11-11 12:28:15 +00001312 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001313
Teresa Charlind8df0262019-11-11 12:28:15 +00001314 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001315 if (m_Parameters.m_BiasEnabled)
1316 {
1317 ValidatePointer(m_Bias, descriptorName, "bias");
1318
Teresa Charlind8df0262019-11-11 12:28:15 +00001319 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1320 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001321
1322 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1323 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1324 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001325 ValidatePerAxisQuantization(inputTensorInfo,
1326 outputTensorInfo,
1327 weightTensorInfo,
1328 optionalBiasTensorInfo,
1329 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001330
1331 std::vector<DataType> supportedTypes =
1332 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001333 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001334 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001335 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001336 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001337 DataType::QAsymmU8,
1338 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001339 };
1340
1341 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1342 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001343}
1344
1345void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1346{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001347 const std::string descriptorName{"PermuteQueueDescriptor"};
1348
1349 ValidateNumInputs(workloadInfo, descriptorName, 1);
1350 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001351
1352 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1353
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001354 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1355 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001356
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001357 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1358 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001359
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001360 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001361 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001362 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001363 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001364 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1365 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1366 "must match dst dimension " + to_string(mapping[i]) +
1367 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001368 }
1369 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001370
1371 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001372}
1373
1374void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1375{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001376 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001377
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001378 ValidateNumInputs(workloadInfo, descriptorName, 1);
1379 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1380
1381 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1382 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1383
1384 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1385 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001386
1387 std::vector<DataType> supportedTypes =
1388 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001389 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001390 DataType::Float32,
1391 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001392 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001393 DataType::QAsymmU8,
1394 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001395 };
1396
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001397 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1398 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001399}
1400
1401void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1402{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001404
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001405 ValidateNumInputs(workloadInfo, descriptorName, 1);
1406 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1407
1408 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1409 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1410
1411 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1412 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001413
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001414 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001415 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001416 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001417 DataType::Float16,
1418 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001419 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001420 DataType::QAsymmU8,
1421 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001422 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001423
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001424 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1425 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001426
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001427 // ResizeBilinear only changes width and height: batch and channel count must match.
1428 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1429 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001430 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001431 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001432 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001433 boost::str(boost::format("%1%: Input batch size (%2%) "
1434 "does not match output batch size (%3%)") %
1435 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001436 }
1437
Teresa Charlin970f43b2019-07-01 13:51:07 +01001438 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001439 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1440 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001441 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001442 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001443 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001444 boost::str(boost::format("%1%: Input channel count (%2%) "
1445 "does not match output channel count (%3%)") %
1446 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001447 }
1448}
1449
1450void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1451{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001452 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001453
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001454 ValidateNumInputs(workloadInfo, descriptorName, 1);
1455 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1456
1457 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1458 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1459
1460 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1461 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001462
1463 std::vector<DataType> supportedTypes =
1464 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001465 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001466 DataType::Float16,
1467 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001468 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001469 DataType::QAsymmU8,
1470 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001471 };
1472
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001473 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1474 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001476 // Resize only changes width and height: batch and channel count must match.
1477 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1478 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001479 if (inputBatchSize != outputBatchSize)
1480 {
1481 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001482 boost::str(boost::format("%1%: Input batch size (%2%) "
1483 "does not match output batch size (%3%)") %
1484 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001485 }
1486
1487 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001488 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1489 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001490 if (inputChannelCount != outputChannelCount)
1491 {
1492 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001493 boost::str(boost::format("%1%: Input channel count (%2%) "
1494 "does not match output channel count (%3%)") %
1495 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001496 }
1497}
1498
1499void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1500{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001501 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503 ValidateNumInputs(workloadInfo, descriptorName, 1);
1504 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1505
1506 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1507 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1508
1509 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1510 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1511
1512 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1513
telsoa014fcda012018-03-09 14:13:49 +00001514 if (m_Parameters.m_Min > m_Parameters.m_Max)
1515 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001516 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001517 }
telsoa014fcda012018-03-09 14:13:49 +00001518}
1519
Kevin Mayce5045a2019-10-02 14:07:47 +01001520void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1521{
1522 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1523
1524 ValidateNumInputs(workloadInfo, descriptorName, 1);
1525 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1526
1527 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1528 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1529
1530 if (inputTensorInfo.GetNumDimensions() > 4)
1531 {
1532 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1533 }
1534
1535 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1536
1537 // Check the supported data types
1538 std::vector<DataType> supportedTypes =
1539 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001540 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001541 DataType::Float32,
1542 DataType::Float16
1543 };
1544
1545 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001546 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001547}
1548
telsoa014fcda012018-03-09 14:13:49 +00001549void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1550{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001551 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001552
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001553 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001554 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1555
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001556 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1557 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1558
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001559 if (inputTensorInfo.GetNumDimensions() > 4)
1560 {
1561 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1562 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001563
1564 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001565
1566 // Check the supported data types
1567 std::vector<DataType> supportedTypes =
1568 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001569 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001570 DataType::Float32,
1571 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001572 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001573 DataType::QAsymmU8,
1574 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001575 };
1576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001577 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001578 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1579}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001580
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001581void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1582{
1583 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1584
1585 ValidateNumInputs(workloadInfo, descriptorName, 1);
1586 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1587
1588 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1589 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1590
1591 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1592
1593 std::vector<DataType> supportedTypes =
1594 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001595 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001596 DataType::Float32,
1597 DataType::Float16,
1598 };
1599
1600 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001602}
1603
1604void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1605{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001606 const std::string descriptorName{"ConstantQueueDescriptor"};
1607
1608 ValidateNumInputs(workloadInfo, descriptorName, 0);
1609 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001610
1611 if (!m_LayerOutput)
1612 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001613 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001614 }
1615
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001616 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1617 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001618
1619 // Check the supported data types
1620 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001621 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001622 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001623 DataType::Float32,
1624 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001625 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001626 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001627 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001628 DataType::QSymmS16,
1629 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001630 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001631
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001632 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001633}
1634
1635void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1636{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001637 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 ValidateNumInputs(workloadInfo, descriptorName, 1);
1640 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1641
1642 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1643 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1644
1645 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001646
1647 // Check the supported data types
1648 std::vector<DataType> supportedTypes =
1649 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001650 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001651 DataType::Float32,
1652 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001653 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001654 DataType::QAsymmU8,
1655 DataType::QSymmS16,
1656 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001657 };
1658
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001659 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1660 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001661}
1662
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001663void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1664{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001665 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001666
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001667 ValidateNumInputs(workloadInfo, descriptorName, 1);
1668 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1669
1670 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1671 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1672
1673 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1674 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001675
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001676 if (m_Parameters.m_BlockShape.size() != 2)
1677 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001678 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001679 }
1680
1681 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1682 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001683 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1684 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001685 }
1686
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001687 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001688
1689 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001690 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001691
Matthew Bentham8800c002018-11-19 13:19:28 +00001692 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001693
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001694 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1695 widthPad.first + widthPad.second;
1696 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1697 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001698
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001699 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1700 inputShape[dimensionIndices.GetChannelsIndex()];
1701 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001702
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001703 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001704 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001705 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001706 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001707 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001708 }
1709
1710 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001711 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001712 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1713 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001714 }
nikraj01120522a2019-05-31 11:33:07 +01001715
1716 std::vector<DataType> supportedTypes =
1717 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001718 DataType::BFloat16,
1719 DataType::Float16,
1720 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001721 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001722 DataType::QAsymmU8,
1723 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001724 };
1725
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1727 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001728}
1729
Keith Davisa57eccb2019-06-14 17:33:22 +01001730void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1731{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001732 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001733
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 ValidateNumInputs(workloadInfo, descriptorName, 1);
1735 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001736
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001737 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1738 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1739
1740 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1741 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001742
1743 std::vector<DataType> supportedTypes =
1744 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001745 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001746 DataType::Float32,
1747 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001748 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001749 DataType::QAsymmU8,
1750 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001751 };
1752
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001753 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1754 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001755
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001756 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1757
1758 if (m_Parameters.m_BlockSize == 0)
1759 {
1760 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1761 }
1762
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001763 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1764 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1765 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1766 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001767
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001768 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001769 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001770 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001771 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1772 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001773 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001774
1775 const TensorShape& outputShape = outputTensorInfo.GetShape();
1776 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1777 {
1778 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1779 "must be divisible by the square of block size." );
1780 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001781}
1782
telsoa014fcda012018-03-09 14:13:49 +00001783void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1784{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001785 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001786
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001787 ValidateNumInputs(workloadInfo, descriptorName, 1);
1788 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1789
1790 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1791 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001792
1793 std::vector<DataType> supportedTypes =
1794 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001795 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001796 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001797 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001798 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001799 };
1800
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001801 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001802
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001803 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001804 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001805 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001806 }
1807}
1808
telsoa01c577f2c2018-08-31 09:22:23 +01001809void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1810{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001811 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1812
1813 const std::string descriptorName{"LstmQueueDescriptor"};
1814
1815 // check dimensions of all inputs and outputs
1816 if (workloadInfo.m_InputTensorInfos.size() != 3)
1817 {
1818 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1819 }
1820 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1821 {
1822 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1823 }
1824
1825 std::vector<DataType> supportedTypes =
1826 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001827 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001828 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001829 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001830 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001831 };
1832
Jan Eilers38e05bd2019-06-26 13:10:09 +01001833 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001834 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1835
Jan Eilers38e05bd2019-06-26 13:10:09 +01001836 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001837 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001838 {
1839 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1840 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001841 descriptorName,
1842 "input_0",
1843 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001844 }
1845 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001847 {
1848 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1849 workloadInfo.m_OutputTensorInfos[i],
1850 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001851 "input_0",
1852 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001853 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001854
janeil0117d8d852019-11-15 15:00:16 +00001855 // Making sure clipping parameters have valid values.
1856 // == 0 means no clipping
1857 // > 0 means clipping
1858 if (m_Parameters.m_ClippingThresCell < 0.0f)
1859 {
1860 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1861 }
1862 if (m_Parameters.m_ClippingThresProj < 0.0f)
1863 {
1864 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1865 }
1866
Jan Eilers38e05bd2019-06-26 13:10:09 +01001867
1868 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001869 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1870 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1871 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1872 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1873 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1874 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1875
Jan Eilers38e05bd2019-06-26 13:10:09 +01001876 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001877 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1878 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001879 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001880 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1881 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001882 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001883 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1884 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001885 // scratchBufferTensor
1886 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001887 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1888 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001889 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001890 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1891 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001892 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1894 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001895 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001896 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1897 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001898
1899
1900 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1901 if ( m_InputToInputWeights )
1902 {
1903 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1904 (n_cell * n_input), "InputLayerNormWeights");
1905 }
1906
1907 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1908 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1909 (n_cell * n_input), "InputToForgetWeights");
1910
1911 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1912 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1913 (n_cell * n_input), "InputToCellWeights");
1914
1915 if ( m_RecurrentToInputWeights )
1916 {
1917 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1918 (n_cell * n_output), "RecurrentToInputWeights");
1919 }
1920
1921 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1922 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1923 (n_cell * n_output), "RecurrentToForgetWeights");
1924
1925 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1926 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1927 (n_cell * n_output), "RecurrentToCellWeights");
1928
1929 // Make sure the input-gate's parameters are either both present (regular
1930 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1931 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1932 !m_Parameters.m_CifgEnabled) ||
1933 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1934 m_Parameters.m_CifgEnabled));
1935 if (!cifg_weights_all_or_none)
1936 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001937 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1938 "RecurrentToInputWeights must either both be present (regular LSTM) "
1939 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1940 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001941 }
1942
1943 if ( m_CellToInputWeights )
1944 {
1945 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1946 n_cell, "CellToInputWeights");
1947 }
1948 if ( m_CellToForgetWeights )
1949 {
1950 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1951 n_cell, "CellToForgetWeights");
1952 }
1953 if ( m_CellToOutputWeights )
1954 {
1955 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1956 n_cell, "CellToOutputWeights");
1957 }
1958
1959 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1960 bool peephole_weights_all_or_none =
1961 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1962 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1963 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1964 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1965 if (!peephole_weights_all_or_none)
1966 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001967 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001968 }
1969
1970 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1971 if (m_Parameters.m_CifgEnabled)
1972 {
1973 if (m_InputGateBias)
1974 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001975 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001976 }
1977 }
1978 else
1979 {
1980 if (!m_InputGateBias)
1981 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001982 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1983 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001984 }
1985 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1986 n_cell, "InputGateBias");
1987 }
1988
1989 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1990 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1991
1992 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1993 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1994
1995 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1996 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1997
1998 if (m_ProjectionWeights)
1999 {
2000 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2001 (n_cell * n_output), "ProjectionWeights");
2002 }
2003 if (m_ProjectionBias)
2004 {
2005 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2006 }
2007
2008 // Making sure the projection tensors are consistent:
2009 // 1) If projection weight is not present, then projection bias should not be
2010 // present.
2011 // 2) If projection weight is present, then projection bias is optional.
2012 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2013 !m_Parameters.m_ProjectionEnabled)
2014 || (m_ProjectionWeights && !m_ProjectionBias &&
2015 m_Parameters.m_ProjectionEnabled)
2016 || (m_ProjectionWeights && m_ProjectionBias &&
2017 m_Parameters.m_ProjectionEnabled));
2018 if (!projecton_tensors_consistent)
2019 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002020 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002021 }
2022
2023 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2024 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2025 // either all have values or none of them have values. Layer normalization is used when the values of all the
2026 // layer normalization weights are present
2027 if (m_InputLayerNormWeights)
2028 {
2029 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2030 }
2031 if (m_ForgetLayerNormWeights)
2032 {
2033 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2034 }
2035 if (m_CellLayerNormWeights)
2036 {
2037 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2038 }
2039 if (m_OutputLayerNormWeights)
2040 {
2041 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2042 }
2043
Jan Eilers38e05bd2019-06-26 13:10:09 +01002044 if (m_Parameters.m_LayerNormEnabled)
2045 {
2046 if (!m_Parameters.m_CifgEnabled)
2047 {
2048 if (!m_InputLayerNormWeights)
2049 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002050 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2051 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002052 }
2053 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2054 1, n_cell, "InputLayerNormWeights");
2055 }
2056 else if (m_InputLayerNormWeights)
2057 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002058 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2059 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002060 }
2061
2062 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2063 "ForgetLayerNormWeights");
2064 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2065
2066 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2067 "OutputLayerNormWeights");
2068 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2069
2070 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2071 "CellLayerNormWeights");
2072 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2073 }
2074 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2075 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002076 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2077 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002078 }
telsoa01c577f2c2018-08-31 09:22:23 +01002079}
2080
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002081void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2082{
2083 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2084
2085 ValidateNumInputs(workloadInfo, descriptorName, 1);
2086 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2087
2088 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2089 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2090
2091 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2092 {
2093 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2094 }
2095
2096 if (outputTensorInfo.GetDataType() != DataType::Float32)
2097 {
2098 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2099 }
2100
2101 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2102}
2103
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002104void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2105{
2106 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2107
2108 ValidateNumInputs(workloadInfo, descriptorName, 1);
2109 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2110
2111 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2112 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2113
2114 if (inputTensorInfo.GetDataType() != DataType::Float32)
2115 {
2116 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2117 }
2118
2119 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2120 {
2121 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2122 }
2123
2124 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2125}
2126
telsoa01c577f2c2018-08-31 09:22:23 +01002127void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2128{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002129 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002130
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002131 ValidateNumInputs(workloadInfo, descriptorName, 1);
2132 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2133
2134 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2135 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2136
2137 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002138 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002139 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002140 }
2141
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002142 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002143 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002144 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002145 }
2146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002147 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002148}
2149
2150void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2151{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002152 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002154 ValidateNumInputs(workloadInfo, descriptorName, 1);
2155 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2156
2157 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2158 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2159
2160 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002161 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002162 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002163 }
2164
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002165 if (outputTensorInfo.GetDataType() != DataType::Float32)
2166 {
2167 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2168 }
2169
2170 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002171}
2172
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002173void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2174{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002175 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002176
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002177 ValidateNumInputs(workloadInfo, descriptorName, 2);
2178 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2179
2180 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2181 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2182 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2183
2184 std::vector<DataType> supportedTypes =
2185 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002186 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002187 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002188 DataType::Float32,
2189 DataType::QAsymmS8,
2190 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002191 DataType::QSymmS16,
2192 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002193 };
2194
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002195 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2196 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2197 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002198
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002199 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2200 inputTensorInfo1,
2201 outputTensorInfo,
2202 descriptorName,
2203 "input_0",
2204 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002205}
2206
David Beckc2044fe2018-09-05 15:00:38 +01002207void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2208{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002209 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002210
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002211 ValidateNumInputs(workloadInfo, descriptorName, 2);
2212 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2213
2214 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2215 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2216 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2217
2218 std::vector<DataType> supportedTypes =
2219 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002220 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002221 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002222 DataType::Float32,
2223 DataType::QAsymmS8,
2224 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002225 DataType::QSymmS16,
2226 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002227 };
2228
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002229 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2230 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2231 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002232
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002233 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2234 inputTensorInfo1,
2235 outputTensorInfo,
2236 descriptorName,
2237 "input_0",
2238 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002239}
2240
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002241void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2242{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002244
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002245 ValidateNumInputs(workloadInfo, descriptorName, 2);
2246 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2247
2248 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2249 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2250 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2251
2252 std::vector<DataType> supportedTypes =
2253 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002254 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002255 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002256 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002257 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002258 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002259 DataType::QSymmS16,
2260 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002261 };
2262
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002263 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2264 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2265 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002266
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002267 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2268 inputTensorInfo1,
2269 outputTensorInfo,
2270 descriptorName,
2271 "input_0",
2272 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002273}
2274
narpra01a6bf9122018-09-10 09:50:09 +01002275void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2276{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002277 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002278
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002279 ValidateNumInputs(workloadInfo, descriptorName, 1);
2280 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2281
2282 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2283 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002284
2285 std::vector<DataType> supportedTypes =
2286 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002287 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002288 DataType::Float32,
2289 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002290 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002291 DataType::QAsymmU8,
2292 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002293 };
narpra01eb061912018-09-10 17:35:27 +01002294
James Conroy4d1ff582019-06-10 17:06:39 +01002295 // First check if input tensor data type is supported, then
2296 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002297 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2298 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002299
narpra0132b90462018-09-13 11:07:48 +01002300 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002301 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002302 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002303 }
narpra0132b90462018-09-13 11:07:48 +01002304 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002305 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002307 }
2308 else
2309 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002310 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002311 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002312 ValidateTensorNumDimensions(outputTensorInfo,
2313 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002314 outputDim > 0 ? outputDim : 1,
2315 "output");
2316 }
narpra01a6bf9122018-09-10 09:50:09 +01002317}
2318
jimfly012c9322a2018-09-19 10:59:49 +01002319void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2320{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002321 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002322
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002323 ValidateNumInputs(workloadInfo, descriptorName, 1);
2324 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2325
2326 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2327 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002328
jimfly012c9322a2018-09-19 10:59:49 +01002329 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2331
jimfly012c9322a2018-09-19 10:59:49 +01002332 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2334 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2335 "as there are dimensions in the input tensor that is " +
2336 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2337 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002338 }
2339}
2340
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002341void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2342{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002343 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002345 ValidateNumInputs(workloadInfo, descriptorName, 1);
2346 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002347
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002348 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2349 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2350
Sadik Armagan2208b602019-07-31 16:36:27 +01002351 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002352 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002353 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002354 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002355 DataType::Float16,
2356 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002357 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002358 DataType::QAsymmU8,
2359 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002360 };
2361
2362 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002363
Keith Davis0c2eeac2020-02-11 16:51:50 +00002364 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002365 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002366 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002367 }
2368}
2369
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002370void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2371{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002372 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002373
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002374 ValidateNumInputs(workloadInfo, descriptorName, 1);
2375 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2378 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002379
2380 std::vector<DataType> supportedTypes =
2381 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002382 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002383 DataType::Float32,
2384 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002385 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002386 DataType::QAsymmU8,
2387 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002388 };
2389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002390 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2391 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002392}
2393
Conor Kennedy430b5d82018-11-14 15:28:28 +00002394void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2395{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002396 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002397
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002398 ValidateNumInputs(workloadInfo, descriptorName, 1);
2399 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2400
2401 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2402 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002403
2404 std::vector<DataType> supportedTypes =
2405 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002406 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002407 DataType::Float16,
2408 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002409 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002410 DataType::QAsymmU8,
2411 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002412 };
2413
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002414 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2415 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002416
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002418
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002419 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002420 if (rank > 4)
2421 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002422 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002423 }
2424
Conor Kennedy430b5d82018-11-14 15:28:28 +00002425 // Begin, End & Stride length must be of rank(input0)
2426 if (m_Parameters.m_Begin.size() != rank)
2427 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002428 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002429 }
2430
2431 if (m_Parameters.m_End.size() != rank)
2432 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002433 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002434 }
2435
2436 if (m_Parameters.m_Stride.size() != rank)
2437 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002439 }
2440
2441 // Stride entries must be non-zero
2442 for (auto& stride : m_Parameters.m_Stride)
2443 {
2444 if (stride == 0)
2445 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002446 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002447 }
2448 }
2449}
2450
kevmay0190539692018-11-29 08:40:19 +00002451void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2452{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002453 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002454
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002455 ValidateNumInputs(workloadInfo, descriptorName, 2);
2456 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2457
2458 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2459 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2460 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2461
2462 std::vector<DataType> supportedTypes =
2463 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002464 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002465 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002466 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002467 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002468 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002469 DataType::QSymmS16,
2470 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002471 };
2472
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002473 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2474 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2475 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002476
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002477 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2478 inputTensorInfo1,
2479 outputTensorInfo,
2480 descriptorName,
2481 "input_0",
2482 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002483}
2484
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002485void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2486{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002487 const std::string descriptorName{"DebugQueueDescriptor"};
2488
2489 ValidateNumInputs(workloadInfo, descriptorName, 1);
2490 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002491}
2492
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002493void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2494{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002495 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002497 ValidateNumInputs(workloadInfo, descriptorName, 2);
2498 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002500 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2501 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2502 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2503
2504 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2505 inputTensorInfo1,
2506 outputTensorInfo,
2507 descriptorName,
2508 "input_0",
2509 "input_1");
2510
2511 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002512 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002513 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002514 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002515}
2516
FrancisMurtagh878f0232018-12-19 10:56:15 +00002517void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2518{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002519 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002520
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002521 ValidateNumInputs(workloadInfo, descriptorName, 2);
2522 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002523
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002524 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2525 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2526 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2527
2528 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2529 inputTensorInfo1,
2530 outputTensorInfo,
2531 descriptorName,
2532 "input_0",
2533 "input_1");
2534
2535 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002536 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002537 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002538 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002539}
2540
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002541void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2542{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002543 const std::string descriptorName{"RsqrtQueueDescriptor"};
2544
2545 ValidateNumInputs(workloadInfo, descriptorName, 1);
2546 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2547
2548 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2549 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2550
2551 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002552
2553 std::vector<DataType> supportedTypes =
2554 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002555 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002556 DataType::Float16,
2557 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002558 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002559 DataType::QAsymmU8,
2560 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002561 };
2562
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002563 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2564 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002565}
2566
narpra01b89b05f2019-01-16 09:53:09 +00002567void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2568{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002569 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002570
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002571 ValidateNumInputs(workloadInfo, descriptorName, 2);
2572 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002574 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2575 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002576 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002577 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002578 }
2579
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2581 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2582
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002583 std::vector<DataType> supportedTypes =
2584 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002585 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002586 DataType::Float16,
2587 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002588 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002589 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002590 DataType::QSymmS16,
2591 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002592 };
2593
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002594 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002595
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002596 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002598 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2599 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002600}
2601
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002602void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2603{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002604 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2605
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002606 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002607
2608 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2609 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002610 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002611 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2612 }
2613
2614 if (m_Anchors == nullptr)
2615 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002616 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002617 }
2618
2619 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002620 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2621 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2622
2623 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002624 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002625 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2626 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002627
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002628 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2629 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2630 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002631
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002632 const std::vector<DataType> supportedInputTypes =
2633 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002634 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002635 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002636 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002637 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002638 DataType::QAsymmU8,
2639 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002640 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002641
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002642 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2643 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2644 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2645
2646 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2647 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2648 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2649 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2650
2651 // NOTE: Output is always Float32 regardless of input type
2652 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2653 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2654 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2655 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002656
2657 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2658 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002659 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002660 "must be positive and less than or equal to 1.");
2661 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002662
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002663 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2664 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002665 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002666 "should be equal to number of classes + 1.");
2667 }
2668}
2669
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002670void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2671{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002672 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002673
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002674 ValidateNumInputs(workloadInfo, descriptorName, 1);
2675 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2676
2677 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2678 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2679
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002680 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002681 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002682 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002683 }
2684
Sadik Armagan2208b602019-07-31 16:36:27 +01002685 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002686 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002687 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002688 DataType::Float32,
2689 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002690 };
2691
2692 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002693}
2694
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002695void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2696{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002697 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002698
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002699 ValidateNumInputs(workloadInfo, descriptorName, 2);
2700 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002701
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002702 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2703 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2704 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002705
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002706 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2707 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2708
2709 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2710 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002711}
2712
Sadik Armaganeff363d2019-04-05 15:25:46 +01002713void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2714{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002715 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002716
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002717 ValidateNumInputs(workloadInfo, descriptorName, 2);
2718 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2719
2720 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2721 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2722
2723 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2724 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2725
2726 std::vector<DataType> supportedTypes =
2727 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002728 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002729 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002730 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002731 DataType::QAsymmU8,
2732 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002733 };
2734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002735 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2736 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002737
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002738 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2739 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002740
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002741 ValidateTensorShapesMatch(inputTensorInfo0,
2742 outputTensorInfo0,
2743 descriptorName,
2744 "input_0",
2745 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002746
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002747 ValidateTensorShapesMatch(inputTensorInfo0,
2748 outputTensorInfo1,
2749 descriptorName,
2750 "input_0",
2751 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002752}
2753
Derek Lamberti901ea112019-12-10 22:07:09 +00002754void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002755{
2756 // This is internally generated so it should not need validation.
2757}
2758
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002759void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2760{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002761 const std::string& descriptorName{"PreluQueueDescriptor"};
2762
2763 ValidateNumInputs(workloadInfo, descriptorName, 2);
2764 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2765
2766 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2767 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2768 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002769
2770 std::vector<DataType> supportedTypes
2771 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002772 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002773 DataType::Float16,
2774 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002775 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002776 DataType::QAsymmU8,
2777 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002778 };
2779
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002780 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2781 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002782
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002783 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002784
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002785 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2786 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002787
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002788 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2789 alphaTensorInfo,
2790 outputTensorInfo,
2791 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002792 "input",
2793 "alpha");
2794}
2795
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002796void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2797{
2798 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2799
2800 ValidateNumInputs(workloadInfo, descriptorName, 1);
2801 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2802
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002803 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2804 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2805
2806 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2807 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002808
2809 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002810
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002811 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2812 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002813
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002814 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2815
2816 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002817 if (m_Parameters.m_BiasEnabled)
2818 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002819 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002820
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002821 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2822 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002823
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002824 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002825 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002826 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002827
2828 ValidatePerAxisQuantization(inputTensorInfo,
2829 outputTensorInfo,
2830 weightTensorInfo,
2831 optionalBiasTensorInfo,
2832 descriptorName);
2833
2834 std::vector<DataType> supportedTypes =
2835 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002836 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002837 DataType::Float32,
2838 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002839 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002840 DataType::QAsymmU8,
2841 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002842 };
2843
2844 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2845 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002846}
2847
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002848void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2849{
2850 const std::string descriptorName{"TransposeQueueDescriptor"};
2851
2852 ValidateNumInputs(workloadInfo, descriptorName, 1);
2853 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2854
2855 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2856
2857 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2858 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2859
2860 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2861 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2862
2863 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2864 {
2865 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2866 {
2867 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2868 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2869 "must match dst dimension " + to_string(i) +
2870 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2871 }
2872 }
2873
2874 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2875}
2876
James Conroy4f1f8992020-04-29 20:01:10 +01002877void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2878{
2879 const std::string descriptorName{"QLstmQueueDescriptor"};
2880
2881 // Validate number of inputs/outputs
2882 ValidateNumInputs(workloadInfo, descriptorName, 3);
2883 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2884
2885 // Input/output tensor info
2886 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2887 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2888 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2889
2890 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2891 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2892 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2893
2894 // Supported types for various tensors in QLSTM
2895 std::vector<DataType> inputOutputSupportedTypes =
2896 {
2897 DataType::QAsymmS8
2898 };
2899
2900 std::vector<DataType> cellStateSupportedTypes =
2901 {
2902 DataType::QSymmS16
2903 };
2904
2905 std::vector<DataType> weightsSupportedTypes =
2906 {
2907 DataType::QSymmS8
2908 };
2909
2910 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2911 {
2912 DataType::QSymmS16
2913 };
2914
2915 std::vector<DataType> biasSupportedTypes =
2916 {
2917 DataType::Signed32
2918 };
2919
2920 // Validate types of input/output tensors
2921 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2922 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2923 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2924
2925 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2926 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2927 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
2928
2929 // Validate matching types of input/output tensors
2930 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2931 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2932 "outputStateIn", "outputStateOut");
2933 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2934
2935 // Infer number of batches, number of units, input size and output size from tensor dimensions
2936 const uint32_t numBatches = inputInfo.GetShape()[0];
2937 const uint32_t inputSize = inputInfo.GetShape()[1];
2938 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
2939 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
2940
2941 // Validate number of dimensions and number of elements for input/output tensors
2942 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2943 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2944 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
2945
2946 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2947 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
2948 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
2949
2950 // Validate number of dimensions and number of elements for MANDATORY weight tensors
2951 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2952 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2953 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
2954
2955 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2956 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2957 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
2958
2959 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2960 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2961 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
2962
2963 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2964 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2965 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
2966 " RecurrentToForgetWeights");
2967
2968 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2969 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2970 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2971
2972 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2973 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2974 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2975
2976 // Validate data types for MANDATORY weights tensors (all should match each other)
2977 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
2978
2979 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
2980 "inputToForgetWeights", "inputToCellWeights");
2981 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2982 "inputToForgetWeights", "inputToOutputWeights");
2983
2984 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2985 "inputToForgetWeights", "recurrentToForgeteights");
2986 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2987 "inputToForgetWeights", "recurrentToCellWeights");
2988 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2989 "inputToForgetWeights", "recurrentToOutputWeights");
2990
2991 // Validate number of dimensions and number of elements for MANDATORY bias tensors
2992 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2993 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2994 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
2995
2996 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2997 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2998 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
2999
3000 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3001 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3002 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3003
3004 // Validate data types for MANDATORY bias tensors
3005 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3006
3007 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3008 "forgetGateBias", "cellBias");
3009 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3010 "forgetGateBias", "outputGateBias");
3011
3012 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3013 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3014 !m_Parameters.m_CifgEnabled) ||
3015 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3016 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3017
3018 if (!allCifgParamsPresentOrNot)
3019 {
3020 throw InvalidArgumentException(descriptorName +
3021 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3022 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3023 "set appropriately.");
3024 }
3025
3026 if (!m_Parameters.m_CifgEnabled)
3027 {
3028 // Validate number of dimensions and number of elements
3029 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3030 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3031
3032 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3033 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3034 " RecurrentToInputWeights");
3035
3036 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3037 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3038
3039 // Validate data types
3040 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3041 "inputToForgetWeights", "inputToInputWeights");
3042 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3043 "inputToForgetWeights", "recurrentToInputWeights");
3044 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3045 "forgetGateBias", "inputGateBias");
3046 }
3047
3048 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3049 bool allPeepholeWeightsPresentOrNot =
3050 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3051 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3052 || (!m_CellToInputWeights && !m_CellToForgetWeights
3053 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3054
3055 if (!allPeepholeWeightsPresentOrNot)
3056 {
3057 throw InvalidArgumentException(descriptorName +
3058 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3059 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3060 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3061 "appropriately.");
3062 }
3063
3064 if (m_Parameters.m_PeepholeEnabled)
3065 {
3066 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3067 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3068 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3069
3070 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3071 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3072 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3073 "cellToForgetWeight", "cellToOutputWeights");
3074
3075 if (!m_Parameters.m_CifgEnabled)
3076 {
3077 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3078 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3079 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3080 "cellToForgetWeights", "cellToInputWeights");
3081 }
3082 }
3083
3084 // Validate OPTIONAL params: Layer Norm Weights
3085 bool allLayerNormWeightsPresentOrNot =
3086 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3087 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3088 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3089 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3090
3091 if (!allLayerNormWeightsPresentOrNot)
3092 {
3093 throw InvalidArgumentException(descriptorName +
3094 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3095 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3096 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3097 "only be present when Layer Norm is enabled and CIFG is disabled. "
3098 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3099 }
3100
3101 if (m_Parameters.m_LayerNormEnabled)
3102 {
3103 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3104 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3105 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3106
3107 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3108 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3109 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3110 "forgetLayerNormWeights", "cellLayerNormWeights");
3111
3112 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3113 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3114 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3115 "forgetLayerNormWeights", "outputLayerNormWeights");
3116
3117 if (!m_Parameters.m_CifgEnabled)
3118 {
3119 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3120 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3121 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3122 "forgetLayerNormWeights", "inputLayerNormWeights");
3123 }
3124 }
3125
3126 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3127 bool correctProjectionTensorsPresent =
3128 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3129 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3130 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3131
3132 if (!correctProjectionTensorsPresent)
3133 {
3134 throw InvalidArgumentException(descriptorName +
3135 ": If projection is enabled, ProjectionWeights should be present and "
3136 "ProjectionBias is optional. If projection is disabled, neither "
3137 "ProjectionWeights nor ProjectionBias should be present.");
3138 }
3139
3140 if (m_Parameters.m_ProjectionEnabled)
3141 {
3142 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3143 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3144 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3145
3146 if (m_ProjectionBias)
3147 {
3148 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003149 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003150 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3151 }
3152
3153 }
3154 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3155 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3156 throw InvalidArgumentException(descriptorName +
3157 ": If projection is disabled, output quantization info (scale, offset) "
3158 "should match HiddenStateScale and HiddenStateZeroPoint.");
3159 }
3160
3161}
3162
James Conroy9c3cae82019-08-01 16:01:48 +01003163void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3164{
3165 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3166
3167 // Validate number of inputs/outputs
3168 ValidateNumInputs(workloadInfo, descriptorName, 3);
3169 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3170
3171 // Input/output tensor infos
3172 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3173 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3174 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3175
3176 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3177 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3178
3179 std::vector<DataType> inputOutputSupportedTypes =
3180 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003181 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003182 };
3183
3184 std::vector<DataType> cellStateSupportedTypes =
3185 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003186 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003187 };
3188
3189 std::vector<DataType> weightsSupportedTypes =
3190 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003191 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003192 };
3193
3194 std::vector<DataType> biasSupportedTypes =
3195 {
3196 DataType::Signed32
3197 };
3198
3199 // Validate types of input/output tensors
3200 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3201 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3202 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3203
3204 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3205 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3206
3207 // Validate matching types of input/output tensors
3208 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3209 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3210 "outputStateIn", "outputStateOut");
3211 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3212
3213 // Validate matching quantization info for input/output tensors
3214 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3215 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3216 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003217
James Conroy9c3cae82019-08-01 16:01:48 +01003218 // Infer number of batches, input size and output size from tensor dimensions
3219 const uint32_t numBatches = inputInfo.GetShape()[0];
3220 const uint32_t inputSize = inputInfo.GetShape()[1];
3221 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3222
3223 // Validate number of dimensions and number of elements for input/output tensors
3224 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3225 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3226 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3227 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3228 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3229
3230 // Validate number of dimensions and number of elements for weights tensors
3231 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3232 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3233 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3234
3235 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3236 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3237 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3238
3239 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3240 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3241 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3242
3243 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3244 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3245 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3246
3247 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3248 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3249 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3250
3251 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3252 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3253 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3254 " RecurrentToForgetWeights");
3255
3256 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3257 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3258 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3259
3260 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3261 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3262 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3263
3264 // Validate data types for weights tensors (all should match each other)
3265 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3266
3267 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3268 "inputToInputWeights", "inputToForgetWeights");
3269 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3270 "inputToInputWeights", "inputToCellWeights");
3271 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3272 "inputToInputWeights", "inputToOutputWeights");
3273
3274 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3275 "inputToInputWeights", "recurrentToInputWeights");
3276 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3277 "inputToInputWeights", "recurrentToForgeteights");
3278 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3279 "inputToInputWeights", "recurrentToCellWeights");
3280 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3281 "inputToInputWeights", "recurrentToOutputWeights");
3282
3283 // Validate matching quantization info for weight tensors (all should match each other)
3284 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3285 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3286 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3287 descriptorName, "inputToInputWeights", "inputToCellWeights");
3288 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3289 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3290
3291 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3292 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3293 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3294 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3295 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3296 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3297 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3298 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3299
3300 // Validate number of dimensions and number of elements in bias tensors
3301 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3302 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3303 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3304
3305 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3306 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3307 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3308
3309 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3310 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3311 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3312
3313 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3314 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3315 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3316
3317 // Validate data types for bias tensors (all should match each other)
3318 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3319
3320 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3321 "inputGateBias", "forgetGateBias");
3322 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3323 "inputGateBias", "cellBias");
3324 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3325 "inputGateBias", "outputGateBias");
3326
3327 // Validate bias tensor quantization info
3328 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3329 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3330 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3331 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3332}
3333
Kevin May868eb142019-09-04 17:29:31 +01003334void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3335{
3336 const std::string descriptorName{"AbsQueueDescriptor"};
3337
3338 ValidateNumInputs(workloadInfo, descriptorName, 1);
3339 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3340
3341 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3342 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3343
3344 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3345
3346 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003347 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003348 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003349 DataType::Float16,
3350 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003351 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003352 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003353 DataType::QSymmS16,
3354 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003355 };
Kevin May868eb142019-09-04 17:29:31 +01003356
3357 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3358 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3359}
3360
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003361void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3362{
3363 const std::string descriptorName{"SliceQueueDescriptor"};
3364
3365 ValidateNumInputs(workloadInfo, descriptorName, 1);
3366 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3367
3368 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3369 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3370
3371 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3372
3373 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3374 if (rank > 4)
3375 {
3376 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3377 }
3378
3379 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3380
3381 // Check if m_Begin and m_Size have the expected length
3382 if (m_Parameters.m_Begin.size() != rank)
3383 {
3384 throw InvalidArgumentException(descriptorName +
3385 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3386 }
3387 if (m_Parameters.m_Size.size() != rank)
3388 {
3389 throw InvalidArgumentException(descriptorName +
3390 ": Length of size descriptor must equal rank " + std::to_string(rank));
3391 }
3392
3393 // Check if the shape of the output tensor matches m_Size
3394 const TensorShape& outputShape = outputTensorInfo.GetShape();
3395 for (unsigned int i = 0u; i < rank; ++i)
3396 {
3397 if (m_Parameters.m_Size[i] != outputShape[i])
3398 {
3399 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3400 }
3401 }
3402
3403 // Check if the sum of begin offset and size in a given dimension
3404 // does not exceed the size of corresponding input
3405 const TensorShape& inputShape = inputTensorInfo.GetShape();
3406 for(unsigned int i = 0u; i < rank; ++i)
3407 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003408 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003409 {
3410 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3411 std::to_string(i) + " exceeds input size.");
3412 }
3413 }
3414}
3415
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003416void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3417{
3418 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3419
3420 ValidateNumInputs(workloadInfo, descriptorName, 1);
3421 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3422
3423 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3424 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3425
3426 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3427 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3428
3429 std::vector<DataType> supportedTypes =
3430 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003431 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003432 DataType::Float32,
3433 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003434 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003435 DataType::QAsymmU8,
3436 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003437 };
3438
3439 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3440 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3441
3442 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3443
3444 if (m_Parameters.m_BlockSize == 0)
3445 {
3446 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3447 }
3448
3449 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3450 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3451 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3452 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3453
3454 const TensorShape& outputShape = outputInfo.GetShape();
3455 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3456 {
3457 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3458 "must be divisible by block size.");
3459 }
3460
3461 const TensorShape& inputShape = inputInfo.GetShape();
3462 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3463 {
3464 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3465 "must be divisible by the square of block size." );
3466 }
3467}
3468
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003469void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3470{
3471 const std::string descriptorName{"ComparisonQueueDescriptor"};
3472
3473 ValidateNumInputs(workloadInfo, descriptorName, 2);
3474 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3475
3476 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3477 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3478 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3479
3480 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3481 inputTensorInfo1,
3482 outputTensorInfo,
3483 descriptorName,
3484 "input_0",
3485 "input_1");
3486
3487 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3488 {
3489 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3490 }
3491}
3492
josh minor4a3c6102020-01-06 16:40:46 -06003493void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3494{
3495 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3496
3497 ValidateNumInputs(workloadInfo, descriptorName, 1);
3498 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3499
3500 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3501 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3502
3503 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3504
3505 std::vector<DataType> supportedTypes =
3506 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003507 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003508 DataType::Float16,
3509 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003510 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003511 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003512 DataType::QSymmS16,
3513 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003514 };
3515
3516 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3517 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3518}
3519
Finn Williams2605b232020-06-10 15:53:46 +01003520void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3521{
3522 const std::string descriptorName{"RankQueueDescriptor"};
3523
3524 ValidateNumInputs(workloadInfo, descriptorName, 1);
3525 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3526
3527 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3528 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3529
3530 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3531 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3532
3533 std::vector<DataType> supportedTypes =
3534 {
3535 DataType::BFloat16,
3536 DataType::Float16,
3537 DataType::Float32,
3538 DataType::QAsymmS8,
3539 DataType::QAsymmU8,
3540 DataType::QSymmS8,
3541 DataType::QSymmS16,
3542 DataType::Signed32
3543 };
3544
3545 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3546 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3547}
3548
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003549} // namespace armnn