blob: f968ad78f7acebac24e132ee1e8de4ef5c295ff7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000031 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000034 case DataType::QAsymmS8:
35 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000037 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000038 case DataType::QSymmS8:
39 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
43 BOOST_ASSERT_MSG(false, "Invalid input data type");
44 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
52//android ndk does not support std::to_string function.
// Formats any value that supports operator<< into a std::string via an output string stream.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    std::string result{formatted.str()};
    return result;
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100108void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000109 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& tensorName)
112{
113 if (tensor.GetNumDimensions() != numDimensions)
114 {
115 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
116 to_string(tensor.GetNumDimensions()) + " dimensions for " +
117 tensorName + " tensor.");
118 }
119}
120
121//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100122void ValidateTensorNumElements(const TensorInfo& tensor,
123 std::string const& descName,
124 unsigned int numElements,
125 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126{
127 if (tensor.GetNumElements() != numElements)
128 {
129 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100130 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100131 tensorName + " tensor.");
132 }
133}
134
135//---------------------------------------------------------------
136void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 unsigned int numDimension,
138 unsigned int numElements,
139 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100141 const std::string functionName{"ValidateTensorNumDimNumElem"};
142 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
143 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100144}
145
146//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000147void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
148 const std::string& descName, std::string const& tensorName)
149{
150 if (tensor.GetDataType() != dataType)
151 {
152 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
153 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
154 }
155}
156
Derek Lambertid466a542020-01-22 15:37:29 +0000157void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
158{
159 ARMNN_NO_DEPRECATE_WARN_BEGIN
160 if (tensor.GetDataType() != DataType::QSymmS8 &&
161 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
167 ARMNN_NO_DEPRECATE_WARN_END
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100171void ValidateTensorQuantizationSpace(const TensorInfo& first,
172 const TensorInfo& second,
173 const std::string& descName,
174 std::string const& firstName,
175 std::string const& secondName)
176{
177 if (!first.IsQuantized() ||
178 !second.IsQuantized())
179 {
180 // Not a quantized type, ignore the validation
181 return;
182 }
183
184 DataType firstDataType = first.GetDataType();
185 DataType secondDataType = second.GetDataType();
186
187 if (firstDataType != secondDataType)
188 {
189 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
190 " must be of the same quantized type, " +
191 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
192 secondName + " is " + GetDataTypeName(secondDataType));
193 }
194
195 if (!first.IsTypeSpaceMatch(second))
196 {
197 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
198 " must have the same quantization space, " +
199 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
200 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
201 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
202 " and scale " + to_string(second.GetQuantizationScale()));
203 }
204}
205
206//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100207void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
208 const TensorInfo& inputTensorInfo,
209 const TensorInfo& weightsTensorInfo,
210 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000211{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000212 // Helper lambda function to validate a single bias quantization scale value
213 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
214 {
ricbur013f4d7102019-10-31 16:22:18 +0000215 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000216 if (std::abs(biasScale - expectedScale) > tolerance)
217 {
218 // Print the float values with extra precision to see very small differences
219 std::stringstream msg;
220 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
221 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
222 biasScale;
223 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
224 }
225 };
226
telsoa014fcda012018-03-09 14:13:49 +0000227 if (biasTensor.GetQuantizationOffset() != 0)
228 {
229 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
230 to_string(biasTensor.GetQuantizationOffset()));
231 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232
233 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000234 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000235 // Validate per-axis quantization scales
236 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
237 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
238
239 if (weightScales.size() != biasScales.size())
240 {
241 std::stringstream msg;
242 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
243 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100309 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000368 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000369 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
390 throw InvalidArgumentException(boost::str(
391 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
392 % descName % tensorName));
393 }
394
395 if (quantizationDim.value() != 0)
396 {
397 throw InvalidArgumentException(boost::str(
398 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
399 "but got: %3%") % descName % tensorName % quantizationDim.value()));
400 }
401}
402
403void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
404 const std::string& descName,
405 const std::string& tensorName)
406{
407 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
408 if (quantizationOffset != 0)
409 {
410 throw InvalidArgumentException(boost::str(
411 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
412 "but got: %3%") % descName % tensorName % quantizationOffset));
413 }
414}
415
416void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
417 const TensorInfo& outputInfo,
418 const TensorInfo& weightInfo,
419 const Optional<TensorInfo>& optionalBiasInfo,
420 const std::string& descName)
421{
422 if (weightInfo.HasPerAxisQuantization())
423 {
424 const DataType inputDataType = inputInfo.GetDataType();
425 const DataType outputDataType = outputInfo.GetDataType();
426
Keith Davis0c2eeac2020-02-11 16:51:50 +0000427 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000428
429 if (!canHavePerAxisQuantization)
430 {
431 throw InvalidArgumentException(boost::str(
432 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
433 "but data type does not support per-axis quantization.") % descName % "weight"));
434 }
435
Derek Lambertid466a542020-01-22 15:37:29 +0000436
437 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000438 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
439 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
440
441 if (optionalBiasInfo.has_value())
442 {
443 const TensorInfo& biasInfo = optionalBiasInfo.value();
444 if (!biasInfo.HasPerAxisQuantization())
445 {
446 throw InvalidArgumentException(boost::str(
447 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
448 "weight tensor.") % descName));
449 }
450
451 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
452 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
453 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
454 }
455 }
456}
457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100458} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000459
460void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
461 unsigned int numExpectedIn, unsigned int numExpectedOut) const
462{
463 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
464 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
465}
466
467//---------------------------------------------------------------
468void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
469{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100470 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000471
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100472 ValidateNumInputs(workloadInfo, descriptorName, 1);
473 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100475 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
476 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
477
478 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
479 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000480
481 if (m_Inputs.size() != m_Outputs.size())
482 {
483 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100484 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
485 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000486 }
487
488 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
489 {
490 if (!m_Inputs[i])
491 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
493 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000494 }
495
496 if (!m_Outputs[i])
497 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
499 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000500 }
501 }
502}
503
Derek Lambertif674aa02019-08-01 15:56:25 +0100504//---------------------------------------------------------------
505void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
506{
507 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
508 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
509
510 if (workloadInfo.m_InputTensorInfos.size() != 1)
511 {
512 throw InvalidArgumentException(boost::str(
513 boost::format("Number of input infos (%1%) is not 1.")
514 % workloadInfo.m_InputTensorInfos.size()));
515
516 }
517
518 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
519 {
520 throw InvalidArgumentException(boost::str(
521 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
522 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
523 }
524
525 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
526 {
527 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
528 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
529 {
530 throw InvalidArgumentException(boost::str(
531 boost::format("Number of elements for tensor input and output %1% does not match")
532 % i ));
533 }
534 }
535
536 if (m_Inputs.size() != 1)
537 {
538 throw InvalidArgumentException(boost::str(
539 boost::format("Number of inputs (%1%) is not 1.")
540 % m_Inputs.size()));
541 }
542
543 if (m_Inputs.size() != m_Outputs.size())
544 {
545 throw InvalidArgumentException(boost::str(
546 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
547 % m_Inputs.size() % m_Outputs.size()));
548 }
549
550 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
551 {
552 if (!m_Inputs[i])
553 {
554 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
555 }
556
557 if (!m_Outputs[i])
558 {
559 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
560 }
561 }
562}
563
564//---------------------------------------------------------------
565void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
566{
567 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
568 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
569
Derek Lambertif674aa02019-08-01 15:56:25 +0100570 if (m_Inputs.size() != 1)
571 {
572 throw InvalidArgumentException(boost::str(
573 boost::format("Number of inputs (%1%) is not 1.")
574 % m_Inputs.size()));
575 }
576
577 if (m_Outputs.size() != 0)
578 {
579 throw InvalidArgumentException(boost::str(
580 boost::format("Number of outputs (%1%) is not 0.")
581 % m_Inputs.size() % m_Outputs.size()));
582 }
583
584 if (!m_Inputs[0])
585 {
586 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
587 }
588}
589
590//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000591void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
592{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100593 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100595 ValidateNumInputs(workloadInfo, descriptorName, 1);
596 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100598 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
599 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100600
601 std::vector<DataType> supportedTypes =
602 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000603 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100604 DataType::Float16,
605 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000606 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000607 DataType::QAsymmU8,
608 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100609 };
610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100611 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
612 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
613 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000614}
615
Nikhil Rajee391d52019-09-05 17:50:44 +0100616void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
617{
618 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
619
620 ValidateNumInputs(workloadInfo, descriptorName, 1);
621 ValidateNumOutputs(workloadInfo, descriptorName, 1);
622
623 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
624 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
625
Nikhil Raj68c2c902019-09-19 11:21:11 +0100626 if (outputTensorInfo.GetDataType() != DataType::Signed32)
627 {
628 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
629 }
630
James Conroyd47a0642019-09-17 14:22:06 +0100631 std::vector<DataType> supportedInputTypes =
632 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000633 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100634 DataType::Float16,
635 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000636 DataType::QAsymmU8,
637 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000638 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100639 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100640
James Conroyd47a0642019-09-17 14:22:06 +0100641 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100642
643 auto inputShape = inputTensorInfo.GetShape();
644 auto outputShape = outputTensorInfo.GetShape();
645
646 auto inputNumDimensions = inputShape.GetNumDimensions();
647 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
648
649 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
650
651 // 1D input shape results in scalar output shape
652 if (inputShape.GetNumDimensions() == 1)
653 {
654 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
655 {
656 throw InvalidArgumentException(descriptorName + outputShapeError);
657 }
658 }
659 else
660 {
661 for (unsigned int i = 0; i < unsignedAxis; ++i)
662 {
663 if (outputShape[i] != inputShape[i])
664 {
665 throw InvalidArgumentException(descriptorName + outputShapeError);
666 }
667 }
668
669 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
670 {
671 if (outputShape[i - 1] != inputShape[i])
672 {
673 throw InvalidArgumentException(descriptorName + outputShapeError);
674 }
675 }
676 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100677}
678
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100679void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
680{
681 const std::string descriptorName{"SoftmaxQueueDescriptor"};
682
683 ValidateNumInputs(workloadInfo, descriptorName, 1);
684 ValidateNumOutputs(workloadInfo, descriptorName, 1);
685
686 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
687 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
688
689 std::vector<DataType> supportedTypes =
690 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000691 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100692 DataType::Float16,
693 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000694 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000695 DataType::QAsymmU8,
696 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100697 };
698
699 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
700 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
701 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
702}
703
telsoa014fcda012018-03-09 14:13:49 +0000704void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
705{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100706 const std::string descriptorName{"SplitterQueueDescriptor"};
707
708 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000709
Ruomei Yan25339c32019-05-28 16:48:20 +0100710 // Check the supported data types
711 std::vector<DataType> supportedTypes =
712 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000713 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100714 DataType::Float32,
715 DataType::Float16,
716 DataType::Boolean,
717 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000718 DataType::QAsymmU8,
719 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100720 };
721
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100722 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
723 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100724 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100725 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
726 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
727
728 const std::string outputName = "output_" + std::to_string(i);
729 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100730 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100731
telsoa014fcda012018-03-09 14:13:49 +0000732 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
733 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100734 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000735 }
736
737 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
738 {
739 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100740 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000741 "has to match number of workloadInfo.m_OutputTensorInfos. "
742 "Number of windows: " +
743 to_string(m_ViewOrigins.size()) +
744 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
745 }
746
telsoa01c577f2c2018-08-31 09:22:23 +0100747 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000748 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
749 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
750 {
telsoa01c577f2c2018-08-31 09:22:23 +0100751 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000752 ViewOrigin const& e = m_ViewOrigins[w];
753 if (e.m_Origin.size() != inputDims)
754 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000756 "have the same dimensionality as the input tensor. "
757 "Window origin (index: " +
758 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
759 " dimensions, the input "
760 "tensor has " +
761 to_string(inputDims) + " dimensions.");
762 }
763 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
764 {
765 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
766 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000769 "be smaller or equal than the size of the input in that coord.");
770 }
771 }
772 }
773}
774
Jim Flynne242f2d2019-05-22 14:24:13 +0100775void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000776{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100777 const std::string descriptorName{"ConcatQueueDescriptor"};
778
779 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000780
781 if (m_Inputs.size() <= 0)
782 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100783 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000784 }
785 if (m_Outputs.size() <= 0)
786 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100787 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000788 }
789
790 if (workloadInfo.m_InputTensorInfos.size() <= 0)
791 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100792 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000793 }
794 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
795 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100796 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000797 }
798
Nikhil Raj8599a412018-11-19 14:51:07 +0000799 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
800 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100801 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000802 }
803
804 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
805 {
806 return;
807 }
808
telsoa014fcda012018-03-09 14:13:49 +0000809 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
810 {
811 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100812 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000813 "has to match number of workloadInfo.m_InputTensorInfos. "
814 "Number of windows: " +
815 to_string(m_ViewOrigins.size()) +
816 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
817 }
818
telsoa01c577f2c2018-08-31 09:22:23 +0100819 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000820 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
821 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
822 {
telsoa01c577f2c2018-08-31 09:22:23 +0100823 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000824 ViewOrigin const& e = m_ViewOrigins[w];
825 if (e.m_Origin.size() != outputDims)
826 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000828 "have the same dimensionality as the output tensor. "
829 "Window origin (index: " +
830 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
831 " dimensions, the output "
832 "tensor has " +
833 to_string(outputDims) + " dimensions.");
834 }
telsoa01c577f2c2018-08-31 09:22:23 +0100835 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000836 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
837 {
838 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
839 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
840 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000842 "be smaller or equal than the size of the output in that coord.");
843 }
844 }
845 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100846
847 // Check the supported data types
848 std::vector<DataType> supportedTypes =
849 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000850 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100851 DataType::Float32,
852 DataType::Float16,
853 DataType::Boolean,
854 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000855 DataType::QAsymmU8,
856 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100857 };
858
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100859 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
860 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100861 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100862 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
863 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
864
865 const std::string inputName = "input_" + std::to_string(i);
866 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100867 }
telsoa014fcda012018-03-09 14:13:49 +0000868}
869
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100870void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
871{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 const std::string descriptorName{"StackQueueDescriptor"};
873
874 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100875
876 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
877 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100878 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100879 }
880
881 // All inputs must have the same shape, which is defined in parameters
882 const TensorShape& inputShape = m_Parameters.m_InputShape;
883 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
884 {
885 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
886 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100887 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100888 }
889 }
890
Matthew Jacksondba634f2019-08-15 15:14:18 +0100891 if (inputShape.GetNumDimensions() > 4)
892 {
893 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
894 }
895
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100896 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
897 // since the output tensor has an additional dimension.
898 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
899 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100900 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100901 "than the number of input dimensions.");
902 }
903
904 // Output shape must be as inferred from the input shape
905 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
906 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
907 {
908 if (outputShape[i] != inputShape[i])
909 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100910 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100911 "match shape inferred from input tensor.");
912 }
913 }
914
915 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
916 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100917 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100918 "match shape inferred from input tensor.");
919 }
920
921 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
922 {
923 if (outputShape[i] != inputShape[i-1])
924 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100925 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100926 "match shape inferred from input tensor.");
927 }
928 }
929
Matthew Jacksondba634f2019-08-15 15:14:18 +0100930 if (outputShape.GetNumDimensions() > 5)
931 {
932 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
933 }
934
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100935 // Check the supported data types
936 std::vector<DataType> supportedTypes =
937 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000938 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100939 DataType::Float32,
940 DataType::Float16,
941 DataType::Boolean,
942 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000943 DataType::QAsymmU8,
944 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100945 };
946
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100947 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100948
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100949 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100950 {
951 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
952 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 descriptorName,
954 "input_0",
955 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100956 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100957
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100958 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
959 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100960 descriptorName,
961 "input_0",
962 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100963}
964
telsoa014fcda012018-03-09 14:13:49 +0000965void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
966{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100967 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000968
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969 ValidateNumInputs(workloadInfo, descriptorName, 1);
970 ValidateNumOutputs(workloadInfo, descriptorName, 1);
971
972 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
973 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
974
975 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
976
977 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000978 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100979 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000980 }
981
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100982 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000983
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100984 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
985 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000986
987 if (m_Parameters.m_BiasEnabled)
988 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100989 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000990
telsoa01c577f2c2018-08-31 09:22:23 +0100991 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100992 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
993 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000994
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100995 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
996 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000997 }
998
Francis Murtagh46c09d02019-05-28 08:15:28 +0100999 // Check the supported data types
1000 std::vector<DataType> supportedTypes =
1001 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001002 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001003 DataType::Float32,
1004 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001005 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001006 DataType::QAsymmU8,
1007 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001008 };
1009
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001010 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001011
1012 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1013 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1014 {
1015 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1016 {
1017 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1018 "for BFloat16 input.");
1019 }
1020 }
1021 else
1022 {
1023 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1024 }
telsoa014fcda012018-03-09 14:13:49 +00001025}
1026
telsoa014fcda012018-03-09 14:13:49 +00001027void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1028{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001029 const std::string descriptorName{"NormalizationQueueDescriptor"};
1030
1031 ValidateNumInputs(workloadInfo, descriptorName, 1);
1032 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1033
1034 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1035 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001036
1037 // Check the supported data types
1038 std::vector<DataType> supportedTypes =
1039 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001040 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001041 DataType::Float16,
1042 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001043 DataType::QAsymmU8,
1044 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001045 };
1046
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001047 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001048
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001049 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001050
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001051 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001052}
1053
1054void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1055{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001056 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001057
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001058 ValidateNumInputs(workloadInfo, descriptorName, 2);
1059 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1060
1061 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1062 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1063 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1064
1065 std::vector<DataType> supportedTypes =
1066 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001067 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001068 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001069 DataType::Float16,
1070 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001071 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001072 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001073 };
1074
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001075 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1076 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1077 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001078
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1080 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001081
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001082 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1083 inputTensorInfo1,
1084 outputTensorInfo,
1085 descriptorName,
1086 "input_0",
1087 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001088}
1089
telsoa014fcda012018-03-09 14:13:49 +00001090void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1091{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001092 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001093
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001094 ValidateNumInputs(workloadInfo, descriptorName, 2);
1095 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1096
1097 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1098 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1099 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1100
1101 std::vector<DataType> supportedTypes =
1102 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001103 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001104 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001105 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001106 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001107 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001108 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001109 };
1110
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001111 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1112 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1113 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001114
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001115 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1116 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1119 inputTensorInfo1,
1120 outputTensorInfo,
1121 descriptorName,
1122 "input_0",
1123 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001124}
1125
1126void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1127{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001129
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001130 ValidateNumInputs(workloadInfo, descriptorName, 1);
1131 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1132
1133 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1134 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001135
1136 std::vector<DataType> supportedTypes =
1137 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001138 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001139 DataType::Float16,
1140 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001141 DataType::QAsymmU8,
1142 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001143 };
1144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001145 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1146 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001147
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001150
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001151 ValidatePointer(m_Mean, descriptorName, "mean");
1152 ValidatePointer(m_Variance, descriptorName, "variance");
1153 ValidatePointer(m_Beta, descriptorName, "beta");
1154 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001155
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001156 const TensorInfo& mean = m_Mean->GetTensorInfo();
1157 const TensorInfo& variance = m_Variance->GetTensorInfo();
1158 const TensorInfo& beta = m_Beta->GetTensorInfo();
1159 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001160
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001161 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1162 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1163 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1164 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001166 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1167 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1168 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001169}
1170
1171void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1172{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001173 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001174
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001175 ValidateNumInputs(workloadInfo, descriptorName, 1);
1176 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001177
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001178 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1179 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001180
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001181 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1182 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001183
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001184 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001185
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001186 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1187 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001188
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001189 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001190
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001191 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001192 if (m_Parameters.m_BiasEnabled)
1193 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001195
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001196 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1197 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001198
1199 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1200 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001201 }
1202
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001203 ValidatePerAxisQuantization(inputTensorInfo,
1204 outputTensorInfo,
1205 weightTensorInfo,
1206 optionalBiasTensorInfo,
1207 descriptorName);
1208
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001209 std::vector<DataType> supportedTypes =
1210 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001211 DataType::BFloat16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001212 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001213 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001214 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001215 DataType::QSymmS16,
Keith Davis5204aa82020-01-27 15:24:59 +00001216 DataType::QSymmS8,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001217 DataType::Float16
1218 };
1219
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001220 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001221
1222 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1223 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1224 {
1225 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1226 {
1227 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1228 "for BFloat16 input.");
1229 }
1230 }
1231 else
1232 {
1233 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1234 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001235}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001236
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001237void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1238{
1239 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1240
1241 ValidateNumInputs(workloadInfo, descriptorName, 1);
1242 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1243
1244 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1245 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1246
1247 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1248 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1249
1250 ValidatePointer(m_Weight, descriptorName, "weight");
1251
1252 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1253 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1254
1255 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1256 {
1257 throw InvalidArgumentException(
1258 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1259 "cannot be smaller than 1.") % descriptorName %
1260 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1261 }
1262
1263 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1264
1265 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1266 // inputChannels * channelMultiplier should be equal to outputChannels.
1267 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1268 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1269 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1270 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1271 {
1272 throw InvalidArgumentException(
1273 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1274 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1275 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1276 numWeightInputChannels % numWeightChannelMultiplier));
1277 }
1278
Teresa Charlind8df0262019-11-11 12:28:15 +00001279 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001280
Teresa Charlind8df0262019-11-11 12:28:15 +00001281 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001282 if (m_Parameters.m_BiasEnabled)
1283 {
1284 ValidatePointer(m_Bias, descriptorName, "bias");
1285
Teresa Charlind8df0262019-11-11 12:28:15 +00001286 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1287 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001288
1289 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1290 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1291 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001292 ValidatePerAxisQuantization(inputTensorInfo,
1293 outputTensorInfo,
1294 weightTensorInfo,
1295 optionalBiasTensorInfo,
1296 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297
1298 std::vector<DataType> supportedTypes =
1299 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001300 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001301 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001302 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001303 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001304 DataType::QSymmS16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001305 DataType::Float16
1306 };
1307
1308 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1309 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001310}
1311
1312void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1313{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001314 const std::string descriptorName{"PermuteQueueDescriptor"};
1315
1316 ValidateNumInputs(workloadInfo, descriptorName, 1);
1317 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001318
1319 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001321 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1322 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001323
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001324 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1325 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001327 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001328 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001329 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001330 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001331 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1332 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1333 "must match dst dimension " + to_string(mapping[i]) +
1334 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001335 }
1336 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001337
1338 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001339}
1340
1341void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1342{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001343 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001345 ValidateNumInputs(workloadInfo, descriptorName, 1);
1346 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1347
1348 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1349 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1350
1351 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1352 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001353
1354 std::vector<DataType> supportedTypes =
1355 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001356 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001357 DataType::Float32,
1358 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001359 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001360 DataType::QAsymmU8,
1361 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001362 };
1363
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001364 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1365 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001366}
1367
1368void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1369{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001370 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001371
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001372 ValidateNumInputs(workloadInfo, descriptorName, 1);
1373 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1374
1375 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1376 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1377
1378 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1379 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001380
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001381 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001382 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001383 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001384 DataType::Float16,
1385 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001386 DataType::QAsymmU8,
1387 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001388 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001390 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1391 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001393 // ResizeBilinear only changes width and height: batch and channel count must match.
1394 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1395 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001396 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001397 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001398 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001399 boost::str(boost::format("%1%: Input batch size (%2%) "
1400 "does not match output batch size (%3%)") %
1401 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001402 }
1403
Teresa Charlin970f43b2019-07-01 13:51:07 +01001404 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001405 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1406 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001407 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001408 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001409 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001410 boost::str(boost::format("%1%: Input channel count (%2%) "
1411 "does not match output channel count (%3%)") %
1412 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001413 }
1414}
1415
1416void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1417{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001418 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001419
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001420 ValidateNumInputs(workloadInfo, descriptorName, 1);
1421 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1422
1423 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1424 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1425
1426 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1427 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001428
1429 std::vector<DataType> supportedTypes =
1430 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001431 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001432 DataType::Float16,
1433 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001434 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001435 DataType::QAsymmU8,
1436 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001437 };
1438
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001439 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1440 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001441
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 // Resize only changes width and height: batch and channel count must match.
1443 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1444 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001445 if (inputBatchSize != outputBatchSize)
1446 {
1447 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001448 boost::str(boost::format("%1%: Input batch size (%2%) "
1449 "does not match output batch size (%3%)") %
1450 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001451 }
1452
1453 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001454 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1455 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001456 if (inputChannelCount != outputChannelCount)
1457 {
1458 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001459 boost::str(boost::format("%1%: Input channel count (%2%) "
1460 "does not match output channel count (%3%)") %
1461 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001462 }
1463}
1464
1465void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1466{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001467 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001469 ValidateNumInputs(workloadInfo, descriptorName, 1);
1470 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1471
1472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1474
1475 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1476 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1477
1478 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1479
telsoa014fcda012018-03-09 14:13:49 +00001480 if (m_Parameters.m_Min > m_Parameters.m_Max)
1481 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001482 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001483 }
telsoa014fcda012018-03-09 14:13:49 +00001484}
1485
Kevin Mayce5045a2019-10-02 14:07:47 +01001486void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1487{
1488 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1489
1490 ValidateNumInputs(workloadInfo, descriptorName, 1);
1491 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1492
1493 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1494 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1495
1496 if (inputTensorInfo.GetNumDimensions() > 4)
1497 {
1498 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1499 }
1500
1501 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1502
1503 // Check the supported data types
1504 std::vector<DataType> supportedTypes =
1505 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001506 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001507 DataType::Float32,
1508 DataType::Float16
1509 };
1510
1511 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001512 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001513}
1514
telsoa014fcda012018-03-09 14:13:49 +00001515void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1516{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001517 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001518
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001519 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001520 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001522 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1523 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1524
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001525 if (inputTensorInfo.GetNumDimensions() > 4)
1526 {
1527 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1528 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001529
1530 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001531
1532 // Check the supported data types
1533 std::vector<DataType> supportedTypes =
1534 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001536 DataType::Float32,
1537 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001538 DataType::QAsymmU8,
1539 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001540 };
1541
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001542 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001543 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1544}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001545
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001546void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1547{
1548 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1549
1550 ValidateNumInputs(workloadInfo, descriptorName, 1);
1551 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1552
1553 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1554 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1555
1556 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1557
1558 std::vector<DataType> supportedTypes =
1559 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001560 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001561 DataType::Float32,
1562 DataType::Float16,
1563 };
1564
1565 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001566 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001567}
1568
1569void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1570{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001571 const std::string descriptorName{"ConstantQueueDescriptor"};
1572
1573 ValidateNumInputs(workloadInfo, descriptorName, 0);
1574 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001575
1576 if (!m_LayerOutput)
1577 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001578 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001579 }
1580
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001581 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1582 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001583
1584 // Check the supported data types
1585 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001586 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001587 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001588 DataType::Float32,
1589 DataType::Float16,
1590 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001591 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001592 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +00001593 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001594 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001595 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001596
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001597 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001598}
1599
1600void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1601{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001602 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001603
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001604 ValidateNumInputs(workloadInfo, descriptorName, 1);
1605 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1606
1607 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1608 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1609
1610 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001611
1612 // Check the supported data types
1613 std::vector<DataType> supportedTypes =
1614 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001615 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001616 DataType::Float32,
1617 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001618 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001619 DataType::QSymmS16,
1620 DataType::QAsymmS8,
Keith Davis67e6c542020-02-19 10:08:33 +00001621 DataType::QAsymmU8
Nina Drozd2f2778f2019-05-27 10:37:05 +01001622 };
1623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001624 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1625 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001626}
1627
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001628void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1629{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001630 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001631
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001632 ValidateNumInputs(workloadInfo, descriptorName, 1);
1633 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1634
1635 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1636 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1637
1638 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1639 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001640
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001641 if (m_Parameters.m_BlockShape.size() != 2)
1642 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001643 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001644 }
1645
1646 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1647 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001648 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1649 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001650 }
1651
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001652 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001653
1654 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001655 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001656
Matthew Bentham8800c002018-11-19 13:19:28 +00001657 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001658
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001659 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1660 widthPad.first + widthPad.second;
1661 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1662 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001663
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001664 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1665 inputShape[dimensionIndices.GetChannelsIndex()];
1666 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001668 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001669 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001670 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001671 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001672 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001673 }
1674
1675 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001676 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001677 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1678 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001679 }
nikraj01120522a2019-05-31 11:33:07 +01001680
1681 std::vector<DataType> supportedTypes =
1682 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001683 DataType::BFloat16,
1684 DataType::Float16,
1685 DataType::Float32,
1686 DataType::QAsymmU8,
1687 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001688 };
1689
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001690 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1691 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001692}
1693
Keith Davisa57eccb2019-06-14 17:33:22 +01001694void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1695{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001696 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001697
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001698 ValidateNumInputs(workloadInfo, descriptorName, 1);
1699 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001700
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001701 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1703
1704 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1705 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001706
1707 std::vector<DataType> supportedTypes =
1708 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001709 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001710 DataType::Float32,
1711 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001712 DataType::QAsymmU8,
1713 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001714 };
1715
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001716 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1717 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001718
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001719 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1720
1721 if (m_Parameters.m_BlockSize == 0)
1722 {
1723 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1724 }
1725
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1727 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1728 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1729 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001730
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001731 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001732 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001733 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1735 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001736 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001737
1738 const TensorShape& outputShape = outputTensorInfo.GetShape();
1739 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1740 {
1741 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1742 "must be divisible by the square of block size." );
1743 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001744}
1745
telsoa014fcda012018-03-09 14:13:49 +00001746void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1747{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001748 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001750 ValidateNumInputs(workloadInfo, descriptorName, 1);
1751 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1752
1753 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1754 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001755
1756 std::vector<DataType> supportedTypes =
1757 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001758 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001759 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001760 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001761 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001762 };
1763
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001764 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001765
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001766 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001768 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001769 }
1770}
1771
telsoa01c577f2c2018-08-31 09:22:23 +01001772void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1773{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1775
1776 const std::string descriptorName{"LstmQueueDescriptor"};
1777
1778 // check dimensions of all inputs and outputs
1779 if (workloadInfo.m_InputTensorInfos.size() != 3)
1780 {
1781 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1782 }
1783 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1784 {
1785 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1786 }
1787
1788 std::vector<DataType> supportedTypes =
1789 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001790 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001791 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001792 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001793 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001794 };
1795
Jan Eilers38e05bd2019-06-26 13:10:09 +01001796 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001797 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1798
Jan Eilers38e05bd2019-06-26 13:10:09 +01001799 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001801 {
1802 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1803 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001804 descriptorName,
1805 "input_0",
1806 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001807 }
1808 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001809 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001810 {
1811 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1812 workloadInfo.m_OutputTensorInfos[i],
1813 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001814 "input_0",
1815 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001816 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001817
janeil0117d8d852019-11-15 15:00:16 +00001818 // Making sure clipping parameters have valid values.
1819 // == 0 means no clipping
1820 // > 0 means clipping
1821 if (m_Parameters.m_ClippingThresCell < 0.0f)
1822 {
1823 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1824 }
1825 if (m_Parameters.m_ClippingThresProj < 0.0f)
1826 {
1827 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1828 }
1829
Jan Eilers38e05bd2019-06-26 13:10:09 +01001830
1831 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001832 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1833 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1834 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1835 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1836 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1837 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1838
Jan Eilers38e05bd2019-06-26 13:10:09 +01001839 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001840 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1841 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001842 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001843 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1844 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001845 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1847 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001848 // scratchBufferTensor
1849 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001850 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1851 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001852 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001853 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1854 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001855 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001856 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1857 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001858 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001859 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1860 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001861
1862
1863 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1864 if ( m_InputToInputWeights )
1865 {
1866 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1867 (n_cell * n_input), "InputLayerNormWeights");
1868 }
1869
1870 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1871 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1872 (n_cell * n_input), "InputToForgetWeights");
1873
1874 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1875 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1876 (n_cell * n_input), "InputToCellWeights");
1877
1878 if ( m_RecurrentToInputWeights )
1879 {
1880 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1881 (n_cell * n_output), "RecurrentToInputWeights");
1882 }
1883
1884 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1885 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1886 (n_cell * n_output), "RecurrentToForgetWeights");
1887
1888 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1889 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1890 (n_cell * n_output), "RecurrentToCellWeights");
1891
1892 // Make sure the input-gate's parameters are either both present (regular
1893 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1894 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1895 !m_Parameters.m_CifgEnabled) ||
1896 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1897 m_Parameters.m_CifgEnabled));
1898 if (!cifg_weights_all_or_none)
1899 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001900 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1901 "RecurrentToInputWeights must either both be present (regular LSTM) "
1902 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1903 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001904 }
1905
1906 if ( m_CellToInputWeights )
1907 {
1908 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1909 n_cell, "CellToInputWeights");
1910 }
1911 if ( m_CellToForgetWeights )
1912 {
1913 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1914 n_cell, "CellToForgetWeights");
1915 }
1916 if ( m_CellToOutputWeights )
1917 {
1918 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1919 n_cell, "CellToOutputWeights");
1920 }
1921
1922 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1923 bool peephole_weights_all_or_none =
1924 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1925 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1926 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1927 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1928 if (!peephole_weights_all_or_none)
1929 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001930 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001931 }
1932
1933 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1934 if (m_Parameters.m_CifgEnabled)
1935 {
1936 if (m_InputGateBias)
1937 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001938 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001939 }
1940 }
1941 else
1942 {
1943 if (!m_InputGateBias)
1944 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001945 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1946 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001947 }
1948 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1949 n_cell, "InputGateBias");
1950 }
1951
1952 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1953 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1954
1955 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1956 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1957
1958 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1959 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1960
1961 if (m_ProjectionWeights)
1962 {
1963 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1964 (n_cell * n_output), "ProjectionWeights");
1965 }
1966 if (m_ProjectionBias)
1967 {
1968 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1969 }
1970
1971 // Making sure the projection tensors are consistent:
1972 // 1) If projection weight is not present, then projection bias should not be
1973 // present.
1974 // 2) If projection weight is present, then projection bias is optional.
1975 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1976 !m_Parameters.m_ProjectionEnabled)
1977 || (m_ProjectionWeights && !m_ProjectionBias &&
1978 m_Parameters.m_ProjectionEnabled)
1979 || (m_ProjectionWeights && m_ProjectionBias &&
1980 m_Parameters.m_ProjectionEnabled));
1981 if (!projecton_tensors_consistent)
1982 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001983 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001984 }
1985
1986 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1987 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1988 // either all have values or none of them have values. Layer normalization is used when the values of all the
1989 // layer normalization weights are present
1990 if (m_InputLayerNormWeights)
1991 {
1992 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1993 }
1994 if (m_ForgetLayerNormWeights)
1995 {
1996 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1997 }
1998 if (m_CellLayerNormWeights)
1999 {
2000 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2001 }
2002 if (m_OutputLayerNormWeights)
2003 {
2004 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2005 }
2006
Jan Eilers38e05bd2019-06-26 13:10:09 +01002007 if (m_Parameters.m_LayerNormEnabled)
2008 {
2009 if (!m_Parameters.m_CifgEnabled)
2010 {
2011 if (!m_InputLayerNormWeights)
2012 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002013 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2014 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002015 }
2016 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2017 1, n_cell, "InputLayerNormWeights");
2018 }
2019 else if (m_InputLayerNormWeights)
2020 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002021 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2022 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002023 }
2024
2025 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2026 "ForgetLayerNormWeights");
2027 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2028
2029 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2030 "OutputLayerNormWeights");
2031 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2032
2033 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2034 "CellLayerNormWeights");
2035 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2036 }
2037 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2038 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002039 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2040 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002041 }
telsoa01c577f2c2018-08-31 09:22:23 +01002042}
2043
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002044void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2045{
2046 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2047
2048 ValidateNumInputs(workloadInfo, descriptorName, 1);
2049 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2050
2051 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2052 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2053
2054 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2055 {
2056 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2057 }
2058
2059 if (outputTensorInfo.GetDataType() != DataType::Float32)
2060 {
2061 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2062 }
2063
2064 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2065}
2066
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002067void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2068{
2069 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2070
2071 ValidateNumInputs(workloadInfo, descriptorName, 1);
2072 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2073
2074 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2075 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2076
2077 if (inputTensorInfo.GetDataType() != DataType::Float32)
2078 {
2079 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2080 }
2081
2082 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2083 {
2084 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2085 }
2086
2087 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2088}
2089
telsoa01c577f2c2018-08-31 09:22:23 +01002090void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2091{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002092 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002093
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002094 ValidateNumInputs(workloadInfo, descriptorName, 1);
2095 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2096
2097 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2098 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2099
2100 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002101 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002102 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002103 }
2104
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002105 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002106 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002107 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002108 }
2109
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002110 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002111}
2112
2113void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2114{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002115 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002116
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002117 ValidateNumInputs(workloadInfo, descriptorName, 1);
2118 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2119
2120 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2121 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2122
2123 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002124 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002125 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002126 }
2127
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 if (outputTensorInfo.GetDataType() != DataType::Float32)
2129 {
2130 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2131 }
2132
2133 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002134}
2135
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002136void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2137{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002138 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002139
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002140 ValidateNumInputs(workloadInfo, descriptorName, 2);
2141 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2142
2143 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2144 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2145 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2146
2147 std::vector<DataType> supportedTypes =
2148 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002149 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002150 DataType::QAsymmU8,
2151 DataType::QSymmS16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002152 DataType::Float16,
2153 DataType::BFloat16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002154 };
2155
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002156 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2157 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2158 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002159
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002160 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2161 inputTensorInfo1,
2162 outputTensorInfo,
2163 descriptorName,
2164 "input_0",
2165 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002166}
2167
David Beckc2044fe2018-09-05 15:00:38 +01002168void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2169{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002170 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002172 ValidateNumInputs(workloadInfo, descriptorName, 2);
2173 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2174
2175 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2176 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2177 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2178
2179 std::vector<DataType> supportedTypes =
2180 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002181 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002182 DataType::QAsymmU8,
2183 DataType::QSymmS16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002184 DataType::Float16,
2185 DataType::BFloat16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002186 };
2187
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002188 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2189 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2190 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002192 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2193 inputTensorInfo1,
2194 outputTensorInfo,
2195 descriptorName,
2196 "input_0",
2197 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002198}
2199
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002200void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2201{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002202 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002203
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002204 ValidateNumInputs(workloadInfo, descriptorName, 2);
2205 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2206
2207 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2208 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2209 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2210
2211 std::vector<DataType> supportedTypes =
2212 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002213 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002214 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002215 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002216 DataType::Signed32,
Keith Davis67e6c542020-02-19 10:08:33 +00002217 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002218 DataType::QAsymmU8,
2219 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002220 };
2221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002222 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2223 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2224 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002225
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002226 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2227 inputTensorInfo1,
2228 outputTensorInfo,
2229 descriptorName,
2230 "input_0",
2231 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002232}
2233
narpra01a6bf9122018-09-10 09:50:09 +01002234void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2235{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002237
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002238 ValidateNumInputs(workloadInfo, descriptorName, 1);
2239 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2240
2241 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2242 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002243
2244 std::vector<DataType> supportedTypes =
2245 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002246 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002247 DataType::Float32,
2248 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002249 DataType::QAsymmU8,
2250 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002251 };
narpra01eb061912018-09-10 17:35:27 +01002252
James Conroy4d1ff582019-06-10 17:06:39 +01002253 // First check if input tensor data type is supported, then
2254 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002255 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2256 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002257
narpra0132b90462018-09-13 11:07:48 +01002258 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002259 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002261 }
narpra0132b90462018-09-13 11:07:48 +01002262 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002263 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002264 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002265 }
2266 else
2267 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002268 unsigned int outputDim =
2269 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2270 ValidateTensorNumDimensions(outputTensorInfo,
2271 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002272 outputDim > 0 ? outputDim : 1,
2273 "output");
2274 }
narpra01a6bf9122018-09-10 09:50:09 +01002275}
2276
jimfly012c9322a2018-09-19 10:59:49 +01002277void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2278{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002279 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002280
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002281 ValidateNumInputs(workloadInfo, descriptorName, 1);
2282 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2283
2284 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2285 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002286
jimfly012c9322a2018-09-19 10:59:49 +01002287 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002288 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2289
jimfly012c9322a2018-09-19 10:59:49 +01002290 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002291 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2292 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2293 "as there are dimensions in the input tensor that is " +
2294 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2295 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002296 }
2297}
2298
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002299void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2300{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002301 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002303 ValidateNumInputs(workloadInfo, descriptorName, 1);
2304 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002305
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2307 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2308
Sadik Armagan2208b602019-07-31 16:36:27 +01002309 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002310 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002311 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002312 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002313 DataType::Float16,
2314 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002315 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002316 DataType::QAsymmU8,
2317 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002318 };
2319
2320 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002321
Keith Davis0c2eeac2020-02-11 16:51:50 +00002322 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002323 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002324 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002325 }
2326}
2327
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002328void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2329{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002331
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002332 ValidateNumInputs(workloadInfo, descriptorName, 1);
2333 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002334
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002335 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2336 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002337
2338 std::vector<DataType> supportedTypes =
2339 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002340 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002341 DataType::Float32,
2342 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002343 DataType::QAsymmU8,
2344 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002345 };
2346
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002347 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2348 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002349}
2350
Conor Kennedy430b5d82018-11-14 15:28:28 +00002351void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2352{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002353 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002355 ValidateNumInputs(workloadInfo, descriptorName, 1);
2356 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2357
2358 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2359 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002360
2361 std::vector<DataType> supportedTypes =
2362 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002363 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002364 DataType::Float16,
2365 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002366 DataType::QAsymmU8,
2367 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002368 };
2369
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002370 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2371 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002372
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002373 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002374
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002375 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002376 if (rank > 4)
2377 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002378 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002379 }
2380
Conor Kennedy430b5d82018-11-14 15:28:28 +00002381 // Begin, End & Stride length must be of rank(input0)
2382 if (m_Parameters.m_Begin.size() != rank)
2383 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002384 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002385 }
2386
2387 if (m_Parameters.m_End.size() != rank)
2388 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002389 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002390 }
2391
2392 if (m_Parameters.m_Stride.size() != rank)
2393 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002394 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002395 }
2396
2397 // Stride entries must be non-zero
2398 for (auto& stride : m_Parameters.m_Stride)
2399 {
2400 if (stride == 0)
2401 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002402 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002403 }
2404 }
2405}
2406
kevmay0190539692018-11-29 08:40:19 +00002407void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2408{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002409 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002410
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002411 ValidateNumInputs(workloadInfo, descriptorName, 2);
2412 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2413
2414 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2415 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2416 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2417
2418 std::vector<DataType> supportedTypes =
2419 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002420 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002421 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002422 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002423 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002424 DataType::QAsymmU8,
2425 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002426 };
2427
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002428 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2429 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2430 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002431
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002432 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2433 inputTensorInfo1,
2434 outputTensorInfo,
2435 descriptorName,
2436 "input_0",
2437 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002438}
2439
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002440void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2441{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002442 const std::string descriptorName{"DebugQueueDescriptor"};
2443
2444 ValidateNumInputs(workloadInfo, descriptorName, 1);
2445 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002446}
2447
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002448void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2449{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002450 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002451
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002452 ValidateNumInputs(workloadInfo, descriptorName, 2);
2453 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002454
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002455 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2456 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2457 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2458
2459 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2460 inputTensorInfo1,
2461 outputTensorInfo,
2462 descriptorName,
2463 "input_0",
2464 "input_1");
2465
2466 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002467 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002468 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002469 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002470}
2471
FrancisMurtagh878f0232018-12-19 10:56:15 +00002472void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2473{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002474 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002476 ValidateNumInputs(workloadInfo, descriptorName, 2);
2477 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002478
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002479 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2480 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2481 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2482
2483 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2484 inputTensorInfo1,
2485 outputTensorInfo,
2486 descriptorName,
2487 "input_0",
2488 "input_1");
2489
2490 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002491 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002492 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002493 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002494}
2495
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002496void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2497{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 const std::string descriptorName{"RsqrtQueueDescriptor"};
2499
2500 ValidateNumInputs(workloadInfo, descriptorName, 1);
2501 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2502
2503 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2504 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2505
2506 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002507
2508 std::vector<DataType> supportedTypes =
2509 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002510 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002511 DataType::Float16,
2512 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002513 DataType::QAsymmU8,
2514 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002515 };
2516
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002517 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2518 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002519}
2520
narpra01b89b05f2019-01-16 09:53:09 +00002521void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2522{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002523 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002525 ValidateNumInputs(workloadInfo, descriptorName, 2);
2526 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002527
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002528 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2529 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002530 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002531 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002532 }
2533
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002534 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2535 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2536
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002537 std::vector<DataType> supportedTypes =
2538 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002539 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002540 DataType::Float16,
2541 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002542 DataType::QAsymmU8,
2543 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002544 };
2545
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002546 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002547
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002548 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002549
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002550 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2551 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002552}
2553
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002554void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2555{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002556 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2557
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002558 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002559
2560 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2561 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002562 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002563 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2564 }
2565
2566 if (m_Anchors == nullptr)
2567 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002568 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002569 }
2570
2571 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002572 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2573 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2574
2575 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002576 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002577 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2578 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002579
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002580 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2581 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2582 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002583
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002584 const std::vector<DataType> supportedInputTypes =
2585 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002586 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002587 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002588 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002589 DataType::QAsymmU8,
2590 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002591 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002592
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002593 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2594 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2595 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2596
2597 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2598 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2599 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2600 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2601
2602 // NOTE: Output is always Float32 regardless of input type
2603 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2604 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2605 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2606 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002607
2608 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2609 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002610 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002611 "must be positive and less than or equal to 1.");
2612 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002613
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002614 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2615 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002616 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002617 "should be equal to number of classes + 1.");
2618 }
2619}
2620
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002621void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2622{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002623 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002625 ValidateNumInputs(workloadInfo, descriptorName, 1);
2626 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2627
2628 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2629 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2630
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002631 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002632 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002633 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002634 }
2635
Sadik Armagan2208b602019-07-31 16:36:27 +01002636 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002637 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002638 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002639 DataType::Float32,
2640 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002641 };
2642
2643 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002644}
2645
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002646void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2647{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002648 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002650 ValidateNumInputs(workloadInfo, descriptorName, 2);
2651 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002652
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002653 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2654 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2655 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002656
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002657 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2658 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2659
2660 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2661 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002662}
2663
Sadik Armaganeff363d2019-04-05 15:25:46 +01002664void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2665{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002666 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002668 ValidateNumInputs(workloadInfo, descriptorName, 2);
2669 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2670
2671 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2672 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2673
2674 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2675 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2676
2677 std::vector<DataType> supportedTypes =
2678 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002679 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002680 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002681 DataType::QAsymmU8,
2682 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002683 };
2684
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002685 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2686 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002687
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002688 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2689 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002690
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002691 ValidateTensorShapesMatch(inputTensorInfo0,
2692 outputTensorInfo0,
2693 descriptorName,
2694 "input_0",
2695 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002697 ValidateTensorShapesMatch(inputTensorInfo0,
2698 outputTensorInfo1,
2699 descriptorName,
2700 "input_0",
2701 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002702}
2703
Derek Lamberti901ea112019-12-10 22:07:09 +00002704void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002705{
2706 // This is internally generated so it should not need validation.
2707}
2708
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002709void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2710{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002711 const std::string& descriptorName{"PreluQueueDescriptor"};
2712
2713 ValidateNumInputs(workloadInfo, descriptorName, 2);
2714 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2715
2716 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2717 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2718 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002719
2720 std::vector<DataType> supportedTypes
2721 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002722 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002723 DataType::Float16,
2724 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002725 DataType::QAsymmU8,
2726 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002727 };
2728
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002729 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2730 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002731
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002732 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002733
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002734 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2735 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002736
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002737 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2738 alphaTensorInfo,
2739 outputTensorInfo,
2740 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002741 "input",
2742 "alpha");
2743}
2744
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002745void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2746{
2747 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2748
2749 ValidateNumInputs(workloadInfo, descriptorName, 1);
2750 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2751
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002752 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2753 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2754
2755 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2756 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002757
2758 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002759
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002760 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2761 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002762
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002763 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2764
2765 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002766 if (m_Parameters.m_BiasEnabled)
2767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002768 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002769
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002770 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2771 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002772
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002773 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002774 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002775 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002776
2777 ValidatePerAxisQuantization(inputTensorInfo,
2778 outputTensorInfo,
2779 weightTensorInfo,
2780 optionalBiasTensorInfo,
2781 descriptorName);
2782
2783 std::vector<DataType> supportedTypes =
2784 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002785 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002786 DataType::Float32,
2787 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002788 DataType::QAsymmU8,
2789 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002790 };
2791
2792 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2793 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002794}
2795
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002796void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2797{
2798 const std::string descriptorName{"TransposeQueueDescriptor"};
2799
2800 ValidateNumInputs(workloadInfo, descriptorName, 1);
2801 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2802
2803 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2804
2805 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2806 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2807
2808 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2809 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2810
2811 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2812 {
2813 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2814 {
2815 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2816 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2817 "must match dst dimension " + to_string(i) +
2818 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2819 }
2820 }
2821
2822 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2823}
2824
James Conroy9c3cae82019-08-01 16:01:48 +01002825void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2826{
2827 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2828
2829 // Validate number of inputs/outputs
2830 ValidateNumInputs(workloadInfo, descriptorName, 3);
2831 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2832
2833 // Input/output tensor infos
2834 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2835 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2836 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2837
2838 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2839 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2840
2841 std::vector<DataType> inputOutputSupportedTypes =
2842 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002843 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002844 };
2845
2846 std::vector<DataType> cellStateSupportedTypes =
2847 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002848 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01002849 };
2850
2851 std::vector<DataType> weightsSupportedTypes =
2852 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002853 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002854 };
2855
2856 std::vector<DataType> biasSupportedTypes =
2857 {
2858 DataType::Signed32
2859 };
2860
2861 // Validate types of input/output tensors
2862 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2863 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2864 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2865
2866 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2867 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2868
2869 // Validate matching types of input/output tensors
2870 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2871 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2872 "outputStateIn", "outputStateOut");
2873 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2874
2875 // Validate matching quantization info for input/output tensors
2876 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2877 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2878 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002879
James Conroy9c3cae82019-08-01 16:01:48 +01002880 // Infer number of batches, input size and output size from tensor dimensions
2881 const uint32_t numBatches = inputInfo.GetShape()[0];
2882 const uint32_t inputSize = inputInfo.GetShape()[1];
2883 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2884
2885 // Validate number of dimensions and number of elements for input/output tensors
2886 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2887 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2888 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2889 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2890 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2891
2892 // Validate number of dimensions and number of elements for weights tensors
2893 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2894 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2895 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2896
2897 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2898 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2899 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2900
2901 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2902 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2903 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2904
2905 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2906 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2907 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2908
2909 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2910 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2911 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2912
2913 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2914 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2915 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2916 " RecurrentToForgetWeights");
2917
2918 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2919 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2920 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2921
2922 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2923 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2924 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2925
2926 // Validate data types for weights tensors (all should match each other)
2927 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2928
2929 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2930 "inputToInputWeights", "inputToForgetWeights");
2931 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2932 "inputToInputWeights", "inputToCellWeights");
2933 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2934 "inputToInputWeights", "inputToOutputWeights");
2935
2936 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2937 "inputToInputWeights", "recurrentToInputWeights");
2938 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2939 "inputToInputWeights", "recurrentToForgeteights");
2940 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2941 "inputToInputWeights", "recurrentToCellWeights");
2942 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2943 "inputToInputWeights", "recurrentToOutputWeights");
2944
2945 // Validate matching quantization info for weight tensors (all should match each other)
2946 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2947 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2948 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2949 descriptorName, "inputToInputWeights", "inputToCellWeights");
2950 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2951 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2952
2953 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2954 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2955 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2956 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2957 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2958 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2959 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2960 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2961
2962 // Validate number of dimensions and number of elements in bias tensors
2963 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2964 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2965 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2966
2967 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2968 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2969 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2970
2971 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2972 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2973 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2974
2975 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2976 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2977 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2978
2979 // Validate data types for bias tensors (all should match each other)
2980 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2981
2982 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2983 "inputGateBias", "forgetGateBias");
2984 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2985 "inputGateBias", "cellBias");
2986 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2987 "inputGateBias", "outputGateBias");
2988
2989 // Validate bias tensor quantization info
2990 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2991 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2992 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2993 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2994}
2995
Kevin May868eb142019-09-04 17:29:31 +01002996void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2997{
2998 const std::string descriptorName{"AbsQueueDescriptor"};
2999
3000 ValidateNumInputs(workloadInfo, descriptorName, 1);
3001 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3002
3003 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3004 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3005
3006 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3007
3008 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003009 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003010 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003011 DataType::Float16,
3012 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003013 DataType::QAsymmU8,
3014 DataType::QSymmS16
James Conroyd47a0642019-09-17 14:22:06 +01003015 };
Kevin May868eb142019-09-04 17:29:31 +01003016
3017 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3018 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3019}
3020
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003021void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3022{
3023 const std::string descriptorName{"SliceQueueDescriptor"};
3024
3025 ValidateNumInputs(workloadInfo, descriptorName, 1);
3026 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3027
3028 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3029 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3030
3031 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3032
3033 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3034 if (rank > 4)
3035 {
3036 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3037 }
3038
3039 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3040
3041 // Check if m_Begin and m_Size have the expected length
3042 if (m_Parameters.m_Begin.size() != rank)
3043 {
3044 throw InvalidArgumentException(descriptorName +
3045 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3046 }
3047 if (m_Parameters.m_Size.size() != rank)
3048 {
3049 throw InvalidArgumentException(descriptorName +
3050 ": Length of size descriptor must equal rank " + std::to_string(rank));
3051 }
3052
3053 // Check if the shape of the output tensor matches m_Size
3054 const TensorShape& outputShape = outputTensorInfo.GetShape();
3055 for (unsigned int i = 0u; i < rank; ++i)
3056 {
3057 if (m_Parameters.m_Size[i] != outputShape[i])
3058 {
3059 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3060 }
3061 }
3062
3063 // Check if the sum of begin offset and size in a given dimension
3064 // does not exceed the size of corresponding input
3065 const TensorShape& inputShape = inputTensorInfo.GetShape();
3066 for(unsigned int i = 0u; i < rank; ++i)
3067 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003068 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003069 {
3070 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3071 std::to_string(i) + " exceeds input size.");
3072 }
3073 }
3074}
3075
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003076void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3077{
3078 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3079
3080 ValidateNumInputs(workloadInfo, descriptorName, 1);
3081 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3082
3083 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3084 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3085
3086 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3087 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3088
3089 std::vector<DataType> supportedTypes =
3090 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003091 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003092 DataType::Float32,
3093 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003094 DataType::QAsymmU8,
3095 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003096 };
3097
3098 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3099 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3100
3101 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3102
3103 if (m_Parameters.m_BlockSize == 0)
3104 {
3105 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3106 }
3107
3108 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3109 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3110 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3111 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3112
3113 const TensorShape& outputShape = outputInfo.GetShape();
3114 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3115 {
3116 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3117 "must be divisible by block size.");
3118 }
3119
3120 const TensorShape& inputShape = inputInfo.GetShape();
3121 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3122 {
3123 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3124 "must be divisible by the square of block size." );
3125 }
3126}
3127
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003128void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3129{
3130 const std::string descriptorName{"ComparisonQueueDescriptor"};
3131
3132 ValidateNumInputs(workloadInfo, descriptorName, 2);
3133 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3134
3135 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3136 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3137 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3138
3139 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3140 inputTensorInfo1,
3141 outputTensorInfo,
3142 descriptorName,
3143 "input_0",
3144 "input_1");
3145
3146 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3147 {
3148 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3149 }
3150}
3151
josh minor4a3c6102020-01-06 16:40:46 -06003152void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3153{
3154 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3155
3156 ValidateNumInputs(workloadInfo, descriptorName, 1);
3157 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3158
3159 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3160 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3161
3162 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3163
3164 std::vector<DataType> supportedTypes =
3165 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003166 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003167 DataType::Float16,
3168 DataType::Float32,
3169 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003170 DataType::QSymmS16,
3171 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003172 };
3173
3174 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3175 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3176}
3177
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003178} // namespace armnn