blob: 3949fa945dd56c0d6dc52035c6428fcbd6de06eb [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000031 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000034 case DataType::QAsymmS8:
35 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000037 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000038 case DataType::QSymmS8:
39 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// Stream-based stand-in for std::to_string, which the Android NDK does not provide.
// Works for any T that has an operator<< for std::ostream.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream stream;
    stream << value;
    return stream.str();
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100108void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000109 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& tensorName)
112{
113 if (tensor.GetNumDimensions() != numDimensions)
114 {
115 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
116 to_string(tensor.GetNumDimensions()) + " dimensions for " +
117 tensorName + " tensor.");
118 }
119}
120
121//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100122void ValidateTensorNumElements(const TensorInfo& tensor,
123 std::string const& descName,
124 unsigned int numElements,
125 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126{
127 if (tensor.GetNumElements() != numElements)
128 {
129 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100130 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100131 tensorName + " tensor.");
132 }
133}
134
135//---------------------------------------------------------------
136void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 unsigned int numDimension,
138 unsigned int numElements,
139 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100141 const std::string functionName{"ValidateTensorNumDimNumElem"};
142 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
143 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100144}
145
146//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000147void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
148 const std::string& descName, std::string const& tensorName)
149{
150 if (tensor.GetDataType() != dataType)
151 {
152 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
153 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
154 }
155}
156
Derek Lambertid466a542020-01-22 15:37:29 +0000157void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
158{
159 ARMNN_NO_DEPRECATE_WARN_BEGIN
160 if (tensor.GetDataType() != DataType::QSymmS8 &&
161 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
167 ARMNN_NO_DEPRECATE_WARN_END
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100171void ValidateTensorQuantizationSpace(const TensorInfo& first,
172 const TensorInfo& second,
173 const std::string& descName,
174 std::string const& firstName,
175 std::string const& secondName)
176{
177 if (!first.IsQuantized() ||
178 !second.IsQuantized())
179 {
180 // Not a quantized type, ignore the validation
181 return;
182 }
183
184 DataType firstDataType = first.GetDataType();
185 DataType secondDataType = second.GetDataType();
186
187 if (firstDataType != secondDataType)
188 {
189 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
190 " must be of the same quantized type, " +
191 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
192 secondName + " is " + GetDataTypeName(secondDataType));
193 }
194
195 if (!first.IsTypeSpaceMatch(second))
196 {
197 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
198 " must have the same quantization space, " +
199 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
200 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
201 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
202 " and scale " + to_string(second.GetQuantizationScale()));
203 }
204}
205
206//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100207void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
208 const TensorInfo& inputTensorInfo,
209 const TensorInfo& weightsTensorInfo,
210 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000211{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000212 // Helper lambda function to validate a single bias quantization scale value
213 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
214 {
ricbur013f4d7102019-10-31 16:22:18 +0000215 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000216 if (std::abs(biasScale - expectedScale) > tolerance)
217 {
218 // Print the float values with extra precision to see very small differences
219 std::stringstream msg;
220 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
221 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
222 biasScale;
223 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
224 }
225 };
226
telsoa014fcda012018-03-09 14:13:49 +0000227 if (biasTensor.GetQuantizationOffset() != 0)
228 {
229 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
230 to_string(biasTensor.GetQuantizationOffset()));
231 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232
233 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000234 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000235 // Validate per-axis quantization scales
236 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
237 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
238
239 if (weightScales.size() != biasScales.size())
240 {
241 std::stringstream msg;
242 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
243 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100309 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100369 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
390 throw InvalidArgumentException(boost::str(
391 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
392 % descName % tensorName));
393 }
394
395 if (quantizationDim.value() != 0)
396 {
397 throw InvalidArgumentException(boost::str(
398 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
399 "but got: %3%") % descName % tensorName % quantizationDim.value()));
400 }
401}
402
403void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
404 const std::string& descName,
405 const std::string& tensorName)
406{
407 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
408 if (quantizationOffset != 0)
409 {
410 throw InvalidArgumentException(boost::str(
411 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
412 "but got: %3%") % descName % tensorName % quantizationOffset));
413 }
414}
415
416void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
417 const TensorInfo& outputInfo,
418 const TensorInfo& weightInfo,
419 const Optional<TensorInfo>& optionalBiasInfo,
420 const std::string& descName)
421{
422 if (weightInfo.HasPerAxisQuantization())
423 {
424 const DataType inputDataType = inputInfo.GetDataType();
425 const DataType outputDataType = outputInfo.GetDataType();
426
Keith Davis0c2eeac2020-02-11 16:51:50 +0000427 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000428
429 if (!canHavePerAxisQuantization)
430 {
431 throw InvalidArgumentException(boost::str(
432 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
433 "but data type does not support per-axis quantization.") % descName % "weight"));
434 }
435
Derek Lambertid466a542020-01-22 15:37:29 +0000436
437 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000438 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
439 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
440
441 if (optionalBiasInfo.has_value())
442 {
443 const TensorInfo& biasInfo = optionalBiasInfo.value();
444 if (!biasInfo.HasPerAxisQuantization())
445 {
446 throw InvalidArgumentException(boost::str(
447 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
448 "weight tensor.") % descName));
449 }
450
451 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
452 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
453 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
454 }
455 }
456}
457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100458} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000459
460void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
461 unsigned int numExpectedIn, unsigned int numExpectedOut) const
462{
463 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
464 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
465}
466
467//---------------------------------------------------------------
468void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
469{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100470 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000471
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100472 ValidateNumInputs(workloadInfo, descriptorName, 1);
473 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100475 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
476 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
477
478 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
479 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000480
481 if (m_Inputs.size() != m_Outputs.size())
482 {
483 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100484 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
485 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000486 }
487
488 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
489 {
490 if (!m_Inputs[i])
491 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
493 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000494 }
495
496 if (!m_Outputs[i])
497 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100498 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
499 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000500 }
501 }
502}
503
Derek Lambertif674aa02019-08-01 15:56:25 +0100504//---------------------------------------------------------------
505void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
506{
507 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
508 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
509
510 if (workloadInfo.m_InputTensorInfos.size() != 1)
511 {
512 throw InvalidArgumentException(boost::str(
513 boost::format("Number of input infos (%1%) is not 1.")
514 % workloadInfo.m_InputTensorInfos.size()));
515
516 }
517
518 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
519 {
520 throw InvalidArgumentException(boost::str(
521 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
522 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
523 }
524
525 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
526 {
527 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
528 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
529 {
530 throw InvalidArgumentException(boost::str(
531 boost::format("Number of elements for tensor input and output %1% does not match")
532 % i ));
533 }
534 }
535
536 if (m_Inputs.size() != 1)
537 {
538 throw InvalidArgumentException(boost::str(
539 boost::format("Number of inputs (%1%) is not 1.")
540 % m_Inputs.size()));
541 }
542
543 if (m_Inputs.size() != m_Outputs.size())
544 {
545 throw InvalidArgumentException(boost::str(
546 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
547 % m_Inputs.size() % m_Outputs.size()));
548 }
549
550 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
551 {
552 if (!m_Inputs[i])
553 {
554 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
555 }
556
557 if (!m_Outputs[i])
558 {
559 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
560 }
561 }
562}
563
564//---------------------------------------------------------------
565void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
566{
567 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
568 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
569
Derek Lambertif674aa02019-08-01 15:56:25 +0100570 if (m_Inputs.size() != 1)
571 {
572 throw InvalidArgumentException(boost::str(
573 boost::format("Number of inputs (%1%) is not 1.")
574 % m_Inputs.size()));
575 }
576
577 if (m_Outputs.size() != 0)
578 {
579 throw InvalidArgumentException(boost::str(
580 boost::format("Number of outputs (%1%) is not 0.")
581 % m_Inputs.size() % m_Outputs.size()));
582 }
583
584 if (!m_Inputs[0])
585 {
586 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
587 }
588}
589
590//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000591void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
592{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100593 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100595 ValidateNumInputs(workloadInfo, descriptorName, 1);
596 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100598 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
599 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100600
601 std::vector<DataType> supportedTypes =
602 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000603 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100604 DataType::Float16,
605 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000606 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000607 DataType::QAsymmU8,
608 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100609 };
610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100611 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
612 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
613 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000614}
615
Nikhil Rajee391d52019-09-05 17:50:44 +0100616void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
617{
618 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
619
620 ValidateNumInputs(workloadInfo, descriptorName, 1);
621 ValidateNumOutputs(workloadInfo, descriptorName, 1);
622
623 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
624 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
625
Nikhil Raj68c2c902019-09-19 11:21:11 +0100626 if (outputTensorInfo.GetDataType() != DataType::Signed32)
627 {
628 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
629 }
630
James Conroyd47a0642019-09-17 14:22:06 +0100631 std::vector<DataType> supportedInputTypes =
632 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000633 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100634 DataType::Float16,
635 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100636 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000637 DataType::QAsymmU8,
638 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000639 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100640 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100641
James Conroyd47a0642019-09-17 14:22:06 +0100642 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100643
644 auto inputShape = inputTensorInfo.GetShape();
645 auto outputShape = outputTensorInfo.GetShape();
646
647 auto inputNumDimensions = inputShape.GetNumDimensions();
648 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
649
650 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
651
652 // 1D input shape results in scalar output shape
653 if (inputShape.GetNumDimensions() == 1)
654 {
655 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
656 {
657 throw InvalidArgumentException(descriptorName + outputShapeError);
658 }
659 }
660 else
661 {
662 for (unsigned int i = 0; i < unsignedAxis; ++i)
663 {
664 if (outputShape[i] != inputShape[i])
665 {
666 throw InvalidArgumentException(descriptorName + outputShapeError);
667 }
668 }
669
670 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
671 {
672 if (outputShape[i - 1] != inputShape[i])
673 {
674 throw InvalidArgumentException(descriptorName + outputShapeError);
675 }
676 }
677 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100678}
679
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100680void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
681{
682 const std::string descriptorName{"SoftmaxQueueDescriptor"};
683
684 ValidateNumInputs(workloadInfo, descriptorName, 1);
685 ValidateNumOutputs(workloadInfo, descriptorName, 1);
686
687 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
688 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
689
690 std::vector<DataType> supportedTypes =
691 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000692 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100693 DataType::Float16,
694 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000695 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000696 DataType::QAsymmU8,
697 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100698 };
699
700 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
701 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
702 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
703}
704
telsoa014fcda012018-03-09 14:13:49 +0000705void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
706{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100707 const std::string descriptorName{"SplitterQueueDescriptor"};
708
709 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000710
Ruomei Yan25339c32019-05-28 16:48:20 +0100711 // Check the supported data types
712 std::vector<DataType> supportedTypes =
713 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000714 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100715 DataType::Float32,
716 DataType::Float16,
717 DataType::Boolean,
718 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100719 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000720 DataType::QAsymmU8,
721 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100722 };
723
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100724 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
725 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100726 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100727 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
728 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
729
730 const std::string outputName = "output_" + std::to_string(i);
731 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100732 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100733
telsoa014fcda012018-03-09 14:13:49 +0000734 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
735 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100736 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000737 }
738
739 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
740 {
741 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100742 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000743 "has to match number of workloadInfo.m_OutputTensorInfos. "
744 "Number of windows: " +
745 to_string(m_ViewOrigins.size()) +
746 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
747 }
748
telsoa01c577f2c2018-08-31 09:22:23 +0100749 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000750 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
751 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
752 {
telsoa01c577f2c2018-08-31 09:22:23 +0100753 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000754 ViewOrigin const& e = m_ViewOrigins[w];
755 if (e.m_Origin.size() != inputDims)
756 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100757 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000758 "have the same dimensionality as the input tensor. "
759 "Window origin (index: " +
760 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
761 " dimensions, the input "
762 "tensor has " +
763 to_string(inputDims) + " dimensions.");
764 }
765 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
766 {
767 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
768 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
769 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100770 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000771 "be smaller or equal than the size of the input in that coord.");
772 }
773 }
774 }
775}
776
Jim Flynne242f2d2019-05-22 14:24:13 +0100777void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000778{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100779 const std::string descriptorName{"ConcatQueueDescriptor"};
780
781 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000782
783 if (m_Inputs.size() <= 0)
784 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000786 }
787 if (m_Outputs.size() <= 0)
788 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100789 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000790 }
791
792 if (workloadInfo.m_InputTensorInfos.size() <= 0)
793 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100794 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000795 }
796 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
797 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100798 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000799 }
800
Nikhil Raj8599a412018-11-19 14:51:07 +0000801 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
802 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100803 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000804 }
805
806 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
807 {
808 return;
809 }
810
telsoa014fcda012018-03-09 14:13:49 +0000811 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
812 {
813 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100814 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000815 "has to match number of workloadInfo.m_InputTensorInfos. "
816 "Number of windows: " +
817 to_string(m_ViewOrigins.size()) +
818 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
819 }
820
telsoa01c577f2c2018-08-31 09:22:23 +0100821 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000822 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
823 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
824 {
telsoa01c577f2c2018-08-31 09:22:23 +0100825 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000826 ViewOrigin const& e = m_ViewOrigins[w];
827 if (e.m_Origin.size() != outputDims)
828 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100829 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000830 "have the same dimensionality as the output tensor. "
831 "Window origin (index: " +
832 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
833 " dimensions, the output "
834 "tensor has " +
835 to_string(outputDims) + " dimensions.");
836 }
telsoa01c577f2c2018-08-31 09:22:23 +0100837 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000838 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
839 {
840 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
841 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
842 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000844 "be smaller or equal than the size of the output in that coord.");
845 }
846 }
847 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100848
849 // Check the supported data types
850 std::vector<DataType> supportedTypes =
851 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000852 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100853 DataType::Float32,
854 DataType::Float16,
855 DataType::Boolean,
856 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100857 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000858 DataType::QAsymmU8,
859 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100860 };
861
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100862 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
863 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100864 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100865 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
866 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
867
868 const std::string inputName = "input_" + std::to_string(i);
869 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100870 }
telsoa014fcda012018-03-09 14:13:49 +0000871}
872
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100873void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
874{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100875 const std::string descriptorName{"StackQueueDescriptor"};
876
877 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100878
879 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
880 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100881 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100882 }
883
884 // All inputs must have the same shape, which is defined in parameters
885 const TensorShape& inputShape = m_Parameters.m_InputShape;
886 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
887 {
888 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
889 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100890 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100891 }
892 }
893
Matthew Jacksondba634f2019-08-15 15:14:18 +0100894 if (inputShape.GetNumDimensions() > 4)
895 {
896 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
897 }
898
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100899 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
900 // since the output tensor has an additional dimension.
901 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
902 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100903 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100904 "than the number of input dimensions.");
905 }
906
907 // Output shape must be as inferred from the input shape
908 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
909 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
910 {
911 if (outputShape[i] != inputShape[i])
912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100914 "match shape inferred from input tensor.");
915 }
916 }
917
918 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
919 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100920 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100921 "match shape inferred from input tensor.");
922 }
923
924 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
925 {
926 if (outputShape[i] != inputShape[i-1])
927 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100928 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100929 "match shape inferred from input tensor.");
930 }
931 }
932
Matthew Jacksondba634f2019-08-15 15:14:18 +0100933 if (outputShape.GetNumDimensions() > 5)
934 {
935 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
936 }
937
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100938 // Check the supported data types
939 std::vector<DataType> supportedTypes =
940 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000941 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100942 DataType::Float32,
943 DataType::Float16,
944 DataType::Boolean,
945 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100946 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000947 DataType::QAsymmU8,
948 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100949 };
950
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100951 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100952
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100954 {
955 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
956 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100957 descriptorName,
958 "input_0",
959 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100960 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100962 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
963 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100964 descriptorName,
965 "input_0",
966 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100967}
968
Ryan OSheaec6c6802020-06-05 17:17:06 +0100969void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
970{
971 const std::string descriptorName{"FillQueueDescriptor"};
972
973 ValidateNumInputs(workloadInfo, descriptorName, 1);
974 ValidateNumOutputs(workloadInfo, descriptorName, 1);
975
976 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
977 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
978
979 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
980
981 std::vector<DataType> supportedTypes =
982 {
983 DataType::BFloat16,
984 DataType::Float32,
985 DataType::Float16,
986 DataType::Signed32
987 };
988
989 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
990}
991
telsoa014fcda012018-03-09 14:13:49 +0000992void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
993{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100994 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100996 ValidateNumInputs(workloadInfo, descriptorName, 1);
997 ValidateNumOutputs(workloadInfo, descriptorName, 1);
998
999 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1000 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1001
1002 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1003
1004 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001005 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001006 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001007 }
1008
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001009 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001011 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1012 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001013
1014 if (m_Parameters.m_BiasEnabled)
1015 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001016 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001017
telsoa01c577f2c2018-08-31 09:22:23 +01001018 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001019 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1020 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001021
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001022 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1023 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001024 }
1025
Francis Murtagh46c09d02019-05-28 08:15:28 +01001026 // Check the supported data types
1027 std::vector<DataType> supportedTypes =
1028 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001029 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001030 DataType::Float32,
1031 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001032 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001033 DataType::QAsymmU8,
1034 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001035 };
1036
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001037 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001038
1039 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1040 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1041 {
1042 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1043 {
1044 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1045 "for BFloat16 input.");
1046 }
1047 }
1048 else
1049 {
1050 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1051 }
telsoa014fcda012018-03-09 14:13:49 +00001052}
1053
telsoa014fcda012018-03-09 14:13:49 +00001054void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1055{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001056 const std::string descriptorName{"NormalizationQueueDescriptor"};
1057
1058 ValidateNumInputs(workloadInfo, descriptorName, 1);
1059 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1060
1061 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1062 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001063
1064 // Check the supported data types
1065 std::vector<DataType> supportedTypes =
1066 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001067 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001068 DataType::Float16,
1069 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001070 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001071 DataType::QAsymmU8,
1072 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001073 };
1074
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001075 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001076
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001077 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001078
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001080}
1081
1082void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1083{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001084 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001085
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 ValidateNumInputs(workloadInfo, descriptorName, 2);
1087 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1088
1089 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1090 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1091 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1092
1093 std::vector<DataType> supportedTypes =
1094 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001095 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001096 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001097 DataType::Float16,
1098 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001099 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001100 DataType::QSymmS16,
1101 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001102 };
1103
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001104 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1105 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1106 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001107
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001108 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1109 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001110
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001111 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1112 inputTensorInfo1,
1113 outputTensorInfo,
1114 descriptorName,
1115 "input_0",
1116 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001117}
1118
telsoa014fcda012018-03-09 14:13:49 +00001119void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1120{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001121 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001122
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 ValidateNumInputs(workloadInfo, descriptorName, 2);
1124 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1125
1126 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1127 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1128 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1129
1130 std::vector<DataType> supportedTypes =
1131 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001132 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001133 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001134 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001135 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001136 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001137 DataType::QSymmS16,
1138 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001139 };
1140
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001141 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1142 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1143 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001145 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1146 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001147
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1149 inputTensorInfo1,
1150 outputTensorInfo,
1151 descriptorName,
1152 "input_0",
1153 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001154}
1155
1156void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1157{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001158 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001159
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001160 ValidateNumInputs(workloadInfo, descriptorName, 1);
1161 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1162
1163 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1164 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001165
1166 std::vector<DataType> supportedTypes =
1167 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001168 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001169 DataType::Float16,
1170 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001171 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001172 DataType::QAsymmU8,
1173 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001174 };
1175
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001176 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1177 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001178
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001179 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001180 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001181
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001182 ValidatePointer(m_Mean, descriptorName, "mean");
1183 ValidatePointer(m_Variance, descriptorName, "variance");
1184 ValidatePointer(m_Beta, descriptorName, "beta");
1185 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001186
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001187 const TensorInfo& mean = m_Mean->GetTensorInfo();
1188 const TensorInfo& variance = m_Variance->GetTensorInfo();
1189 const TensorInfo& beta = m_Beta->GetTensorInfo();
1190 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1193 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1194 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1195 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001197 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1198 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1199 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001200}
1201
1202void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1203{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001204 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001205
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001206 ValidateNumInputs(workloadInfo, descriptorName, 1);
1207 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001208
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001209 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1210 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001212 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1213 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001214
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001215 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001216
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001217 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1218 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001219
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001220 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001221
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001222 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001223 if (m_Parameters.m_BiasEnabled)
1224 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001226
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001227 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1228 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229
1230 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1231 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001232 }
1233
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001234 ValidatePerAxisQuantization(inputTensorInfo,
1235 outputTensorInfo,
1236 weightTensorInfo,
1237 optionalBiasTensorInfo,
1238 descriptorName);
1239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 std::vector<DataType> supportedTypes =
1241 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001242 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001243 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001244 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001245 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001246 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001247 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001248 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001249 };
1250
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001251 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001252
1253 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1254 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1255 {
1256 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1257 {
1258 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1259 "for BFloat16 input.");
1260 }
1261 }
1262 else
1263 {
1264 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1265 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001266}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1269{
1270 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1271
1272 ValidateNumInputs(workloadInfo, descriptorName, 1);
1273 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1274
1275 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1276 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1277
1278 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1279 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1280
1281 ValidatePointer(m_Weight, descriptorName, "weight");
1282
1283 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1284 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1285
1286 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1287 {
1288 throw InvalidArgumentException(
1289 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1290 "cannot be smaller than 1.") % descriptorName %
1291 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1292 }
1293
1294 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1295
1296 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1297 // inputChannels * channelMultiplier should be equal to outputChannels.
1298 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1299 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1300 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1301 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1302 {
1303 throw InvalidArgumentException(
1304 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1305 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1306 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1307 numWeightInputChannels % numWeightChannelMultiplier));
1308 }
1309
Teresa Charlind8df0262019-11-11 12:28:15 +00001310 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001311
Teresa Charlind8df0262019-11-11 12:28:15 +00001312 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001313 if (m_Parameters.m_BiasEnabled)
1314 {
1315 ValidatePointer(m_Bias, descriptorName, "bias");
1316
Teresa Charlind8df0262019-11-11 12:28:15 +00001317 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1318 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001319
1320 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1321 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1322 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001323 ValidatePerAxisQuantization(inputTensorInfo,
1324 outputTensorInfo,
1325 weightTensorInfo,
1326 optionalBiasTensorInfo,
1327 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001328
1329 std::vector<DataType> supportedTypes =
1330 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001331 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001332 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001333 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001334 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001335 DataType::QAsymmU8,
1336 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001337 };
1338
1339 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1340 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001341}
1342
1343void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1344{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001345 const std::string descriptorName{"PermuteQueueDescriptor"};
1346
1347 ValidateNumInputs(workloadInfo, descriptorName, 1);
1348 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001349
1350 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1351
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001352 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1353 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001355 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1356 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001357
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001358 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001359 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001360 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001361 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001362 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1363 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1364 "must match dst dimension " + to_string(mapping[i]) +
1365 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001366 }
1367 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001368
1369 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001370}
1371
1372void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1373{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001374 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001375
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001376 ValidateNumInputs(workloadInfo, descriptorName, 1);
1377 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1378
1379 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1380 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1381
1382 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1383 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001384
1385 std::vector<DataType> supportedTypes =
1386 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001387 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001388 DataType::Float32,
1389 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001390 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001391 DataType::QAsymmU8,
1392 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001393 };
1394
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001395 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1396 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001397}
1398
1399void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1400{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001401 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001402
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 ValidateNumInputs(workloadInfo, descriptorName, 1);
1404 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1405
1406 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1407 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1408
1409 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1410 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001411
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001412 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001413 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001414 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001415 DataType::Float16,
1416 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001417 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001418 DataType::QAsymmU8,
1419 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001420 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001421
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001422 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1423 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001424
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001425 // ResizeBilinear only changes width and height: batch and channel count must match.
1426 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1427 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001428 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001429 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001430 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001431 boost::str(boost::format("%1%: Input batch size (%2%) "
1432 "does not match output batch size (%3%)") %
1433 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001434 }
1435
Teresa Charlin970f43b2019-07-01 13:51:07 +01001436 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1438 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001439 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001440 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001441 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 boost::str(boost::format("%1%: Input channel count (%2%) "
1443 "does not match output channel count (%3%)") %
1444 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001445 }
1446}
1447
1448void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1449{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001450 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001451
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001452 ValidateNumInputs(workloadInfo, descriptorName, 1);
1453 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1454
1455 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1456 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1457
1458 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1459 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001460
1461 std::vector<DataType> supportedTypes =
1462 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001463 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001464 DataType::Float16,
1465 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001466 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001467 DataType::QAsymmU8,
1468 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001469 };
1470
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001471 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1472 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001473
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001474 // Resize only changes width and height: batch and channel count must match.
1475 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1476 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001477 if (inputBatchSize != outputBatchSize)
1478 {
1479 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001480 boost::str(boost::format("%1%: Input batch size (%2%) "
1481 "does not match output batch size (%3%)") %
1482 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001483 }
1484
1485 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001486 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1487 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001488 if (inputChannelCount != outputChannelCount)
1489 {
1490 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001491 boost::str(boost::format("%1%: Input channel count (%2%) "
1492 "does not match output channel count (%3%)") %
1493 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001494 }
1495}
1496
1497void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1498{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001500
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001501 ValidateNumInputs(workloadInfo, descriptorName, 1);
1502 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1503
1504 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1505 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1506
1507 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1508 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1509
1510 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1511
telsoa014fcda012018-03-09 14:13:49 +00001512 if (m_Parameters.m_Min > m_Parameters.m_Max)
1513 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001514 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001515 }
telsoa014fcda012018-03-09 14:13:49 +00001516}
1517
Kevin Mayce5045a2019-10-02 14:07:47 +01001518void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1519{
1520 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1521
1522 ValidateNumInputs(workloadInfo, descriptorName, 1);
1523 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1524
1525 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1526 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1527
1528 if (inputTensorInfo.GetNumDimensions() > 4)
1529 {
1530 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1531 }
1532
1533 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1534
1535 // Check the supported data types
1536 std::vector<DataType> supportedTypes =
1537 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001538 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001539 DataType::Float32,
1540 DataType::Float16
1541 };
1542
1543 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001544 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001545}
1546
telsoa014fcda012018-03-09 14:13:49 +00001547void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1548{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001549 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001550
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001551 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001552 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1553
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001554 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1555 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1556
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001557 if (inputTensorInfo.GetNumDimensions() > 4)
1558 {
1559 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1560 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001561
1562 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001563
1564 // Check the supported data types
1565 std::vector<DataType> supportedTypes =
1566 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001567 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001568 DataType::Float32,
1569 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001570 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001571 DataType::QAsymmU8,
1572 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001573 };
1574
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001575 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001576 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1577}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001578
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001579void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1580{
1581 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1582
1583 ValidateNumInputs(workloadInfo, descriptorName, 1);
1584 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1585
1586 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1587 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1588
1589 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1590
1591 std::vector<DataType> supportedTypes =
1592 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001593 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001594 DataType::Float32,
1595 DataType::Float16,
1596 };
1597
1598 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001599 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001600}
1601
1602void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1603{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001604 const std::string descriptorName{"ConstantQueueDescriptor"};
1605
1606 ValidateNumInputs(workloadInfo, descriptorName, 0);
1607 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001608
1609 if (!m_LayerOutput)
1610 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001611 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001612 }
1613
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001614 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1615 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001616
1617 // Check the supported data types
1618 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001619 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001620 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001621 DataType::Float32,
1622 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001623 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001624 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001625 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001626 DataType::QSymmS16,
1627 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001628 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001629
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001630 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001631}
1632
1633void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1634{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001635 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001636
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001637 ValidateNumInputs(workloadInfo, descriptorName, 1);
1638 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1639
1640 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1641 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1642
1643 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001644
1645 // Check the supported data types
1646 std::vector<DataType> supportedTypes =
1647 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001648 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001649 DataType::Float32,
1650 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001651 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001652 DataType::QAsymmU8,
1653 DataType::QSymmS16,
1654 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001655 };
1656
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001657 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1658 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001659}
1660
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001661void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1662{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001663 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001664
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001665 ValidateNumInputs(workloadInfo, descriptorName, 1);
1666 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1667
1668 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1669 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1670
1671 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1672 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001673
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001674 if (m_Parameters.m_BlockShape.size() != 2)
1675 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001677 }
1678
1679 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1680 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001681 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1682 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001683 }
1684
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001685 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001686
1687 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001688 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001689
Matthew Bentham8800c002018-11-19 13:19:28 +00001690 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001691
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001692 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1693 widthPad.first + widthPad.second;
1694 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1695 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001697 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1698 inputShape[dimensionIndices.GetChannelsIndex()];
1699 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001700
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001701 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001702 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001703 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001704 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001705 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001706 }
1707
1708 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001709 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001710 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1711 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001712 }
nikraj01120522a2019-05-31 11:33:07 +01001713
1714 std::vector<DataType> supportedTypes =
1715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001716 DataType::BFloat16,
1717 DataType::Float16,
1718 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001719 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001720 DataType::QAsymmU8,
1721 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001722 };
1723
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001724 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1725 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001726}
1727
Keith Davisa57eccb2019-06-14 17:33:22 +01001728void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1729{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001730 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001731
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001732 ValidateNumInputs(workloadInfo, descriptorName, 1);
1733 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001735 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1736 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1737
1738 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1739 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001740
1741 std::vector<DataType> supportedTypes =
1742 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001743 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001744 DataType::Float32,
1745 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001746 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001747 DataType::QAsymmU8,
1748 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001749 };
1750
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001751 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1752 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001753
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001754 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1755
1756 if (m_Parameters.m_BlockSize == 0)
1757 {
1758 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1759 }
1760
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001761 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1762 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1763 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1764 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001765
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001766 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001767 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001769 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1770 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001771 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001772
1773 const TensorShape& outputShape = outputTensorInfo.GetShape();
1774 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1775 {
1776 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1777 "must be divisible by the square of block size." );
1778 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001779}
1780
telsoa014fcda012018-03-09 14:13:49 +00001781void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1782{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001783 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001784
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001785 ValidateNumInputs(workloadInfo, descriptorName, 1);
1786 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1787
1788 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1789 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001790
1791 std::vector<DataType> supportedTypes =
1792 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001793 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001794 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001795 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001796 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001797 };
1798
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001799 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001800
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001801 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001802 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001803 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001804 }
1805}
1806
telsoa01c577f2c2018-08-31 09:22:23 +01001807void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1808{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001809 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1810
1811 const std::string descriptorName{"LstmQueueDescriptor"};
1812
1813 // check dimensions of all inputs and outputs
1814 if (workloadInfo.m_InputTensorInfos.size() != 3)
1815 {
1816 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1817 }
1818 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1819 {
1820 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1821 }
1822
1823 std::vector<DataType> supportedTypes =
1824 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001825 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001826 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001827 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001828 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001829 };
1830
Jan Eilers38e05bd2019-06-26 13:10:09 +01001831 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001832 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1833
Jan Eilers38e05bd2019-06-26 13:10:09 +01001834 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001835 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001836 {
1837 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1838 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001839 descriptorName,
1840 "input_0",
1841 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001842 }
1843 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001845 {
1846 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1847 workloadInfo.m_OutputTensorInfos[i],
1848 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001849 "input_0",
1850 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001851 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001852
janeil0117d8d852019-11-15 15:00:16 +00001853 // Making sure clipping parameters have valid values.
1854 // == 0 means no clipping
1855 // > 0 means clipping
1856 if (m_Parameters.m_ClippingThresCell < 0.0f)
1857 {
1858 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1859 }
1860 if (m_Parameters.m_ClippingThresProj < 0.0f)
1861 {
1862 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1863 }
1864
Jan Eilers38e05bd2019-06-26 13:10:09 +01001865
1866 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001867 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1868 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1869 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1870 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1871 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1872 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1873
Jan Eilers38e05bd2019-06-26 13:10:09 +01001874 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001875 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1876 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001877 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001878 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1879 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001880 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001881 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1882 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001883 // scratchBufferTensor
1884 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001885 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1886 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001887 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001888 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1889 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001890 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001891 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1892 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001893 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001894 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1895 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001896
1897
1898 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1899 if ( m_InputToInputWeights )
1900 {
1901 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1902 (n_cell * n_input), "InputLayerNormWeights");
1903 }
1904
1905 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1906 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1907 (n_cell * n_input), "InputToForgetWeights");
1908
1909 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1910 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1911 (n_cell * n_input), "InputToCellWeights");
1912
1913 if ( m_RecurrentToInputWeights )
1914 {
1915 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1916 (n_cell * n_output), "RecurrentToInputWeights");
1917 }
1918
1919 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1920 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1921 (n_cell * n_output), "RecurrentToForgetWeights");
1922
1923 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1924 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1925 (n_cell * n_output), "RecurrentToCellWeights");
1926
1927 // Make sure the input-gate's parameters are either both present (regular
1928 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1929 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1930 !m_Parameters.m_CifgEnabled) ||
1931 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1932 m_Parameters.m_CifgEnabled));
1933 if (!cifg_weights_all_or_none)
1934 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001935 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1936 "RecurrentToInputWeights must either both be present (regular LSTM) "
1937 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1938 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001939 }
1940
1941 if ( m_CellToInputWeights )
1942 {
1943 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1944 n_cell, "CellToInputWeights");
1945 }
1946 if ( m_CellToForgetWeights )
1947 {
1948 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1949 n_cell, "CellToForgetWeights");
1950 }
1951 if ( m_CellToOutputWeights )
1952 {
1953 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1954 n_cell, "CellToOutputWeights");
1955 }
1956
1957 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1958 bool peephole_weights_all_or_none =
1959 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1960 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1961 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1962 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1963 if (!peephole_weights_all_or_none)
1964 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001965 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001966 }
1967
1968 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1969 if (m_Parameters.m_CifgEnabled)
1970 {
1971 if (m_InputGateBias)
1972 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001973 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001974 }
1975 }
1976 else
1977 {
1978 if (!m_InputGateBias)
1979 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001980 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1981 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001982 }
1983 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1984 n_cell, "InputGateBias");
1985 }
1986
1987 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1988 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1989
1990 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1991 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1992
1993 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1994 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1995
1996 if (m_ProjectionWeights)
1997 {
1998 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1999 (n_cell * n_output), "ProjectionWeights");
2000 }
2001 if (m_ProjectionBias)
2002 {
2003 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2004 }
2005
2006 // Making sure the projection tensors are consistent:
2007 // 1) If projection weight is not present, then projection bias should not be
2008 // present.
2009 // 2) If projection weight is present, then projection bias is optional.
2010 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2011 !m_Parameters.m_ProjectionEnabled)
2012 || (m_ProjectionWeights && !m_ProjectionBias &&
2013 m_Parameters.m_ProjectionEnabled)
2014 || (m_ProjectionWeights && m_ProjectionBias &&
2015 m_Parameters.m_ProjectionEnabled));
2016 if (!projecton_tensors_consistent)
2017 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002018 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002019 }
2020
2021 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2022 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2023 // either all have values or none of them have values. Layer normalization is used when the values of all the
2024 // layer normalization weights are present
2025 if (m_InputLayerNormWeights)
2026 {
2027 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2028 }
2029 if (m_ForgetLayerNormWeights)
2030 {
2031 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2032 }
2033 if (m_CellLayerNormWeights)
2034 {
2035 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2036 }
2037 if (m_OutputLayerNormWeights)
2038 {
2039 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2040 }
2041
Jan Eilers38e05bd2019-06-26 13:10:09 +01002042 if (m_Parameters.m_LayerNormEnabled)
2043 {
2044 if (!m_Parameters.m_CifgEnabled)
2045 {
2046 if (!m_InputLayerNormWeights)
2047 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002048 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2049 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002050 }
2051 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2052 1, n_cell, "InputLayerNormWeights");
2053 }
2054 else if (m_InputLayerNormWeights)
2055 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002056 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2057 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002058 }
2059
2060 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2061 "ForgetLayerNormWeights");
2062 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2063
2064 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2065 "OutputLayerNormWeights");
2066 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2067
2068 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2069 "CellLayerNormWeights");
2070 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2071 }
2072 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2073 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002074 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2075 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002076 }
telsoa01c577f2c2018-08-31 09:22:23 +01002077}
2078
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002079void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2080{
2081 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2082
2083 ValidateNumInputs(workloadInfo, descriptorName, 1);
2084 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2085
2086 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2087 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2088
2089 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2090 {
2091 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2092 }
2093
2094 if (outputTensorInfo.GetDataType() != DataType::Float32)
2095 {
2096 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2097 }
2098
2099 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2100}
2101
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002102void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2103{
2104 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2105
2106 ValidateNumInputs(workloadInfo, descriptorName, 1);
2107 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2108
2109 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2110 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2111
2112 if (inputTensorInfo.GetDataType() != DataType::Float32)
2113 {
2114 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2115 }
2116
2117 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2118 {
2119 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2120 }
2121
2122 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2123}
2124
telsoa01c577f2c2018-08-31 09:22:23 +01002125void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2126{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002127 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002128
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002129 ValidateNumInputs(workloadInfo, descriptorName, 1);
2130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2131
2132 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2133 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2134
2135 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002136 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002137 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002138 }
2139
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002140 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002141 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002142 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002143 }
2144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002145 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002146}
2147
2148void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2149{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002150 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002151
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002152 ValidateNumInputs(workloadInfo, descriptorName, 1);
2153 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2154
2155 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2156 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2157
2158 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002159 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002160 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002161 }
2162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002163 if (outputTensorInfo.GetDataType() != DataType::Float32)
2164 {
2165 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2166 }
2167
2168 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002169}
2170
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002171void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2172{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002173 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002174
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002175 ValidateNumInputs(workloadInfo, descriptorName, 2);
2176 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2177
2178 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2179 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2180 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2181
2182 std::vector<DataType> supportedTypes =
2183 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002184 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002185 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002186 DataType::Float32,
2187 DataType::QAsymmS8,
2188 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002189 DataType::QSymmS16,
2190 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002191 };
2192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002193 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2194 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2195 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002197 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2198 inputTensorInfo1,
2199 outputTensorInfo,
2200 descriptorName,
2201 "input_0",
2202 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002203}
2204
David Beckc2044fe2018-09-05 15:00:38 +01002205void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2206{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002207 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002208
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002209 ValidateNumInputs(workloadInfo, descriptorName, 2);
2210 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2211
2212 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2213 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2214 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2215
2216 std::vector<DataType> supportedTypes =
2217 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002218 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002219 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002220 DataType::Float32,
2221 DataType::QAsymmS8,
2222 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002223 DataType::QSymmS16,
2224 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002225 };
2226
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002227 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2228 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2229 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002230
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002231 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2232 inputTensorInfo1,
2233 outputTensorInfo,
2234 descriptorName,
2235 "input_0",
2236 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002237}
2238
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002239void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2240{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002241 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002242
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 ValidateNumInputs(workloadInfo, descriptorName, 2);
2244 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2245
2246 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2247 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2248 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2249
2250 std::vector<DataType> supportedTypes =
2251 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002252 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002253 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002254 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002255 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002256 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002257 DataType::QSymmS16,
2258 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002259 };
2260
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002261 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2262 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2263 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002265 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2266 inputTensorInfo1,
2267 outputTensorInfo,
2268 descriptorName,
2269 "input_0",
2270 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002271}
2272
narpra01a6bf9122018-09-10 09:50:09 +01002273void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2274{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002275 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002276
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002277 ValidateNumInputs(workloadInfo, descriptorName, 1);
2278 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2279
2280 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2281 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002282
2283 std::vector<DataType> supportedTypes =
2284 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002285 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002286 DataType::Float32,
2287 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002288 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002289 DataType::QAsymmU8,
2290 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002291 };
narpra01eb061912018-09-10 17:35:27 +01002292
James Conroy4d1ff582019-06-10 17:06:39 +01002293 // First check if input tensor data type is supported, then
2294 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002295 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2296 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002297
narpra0132b90462018-09-13 11:07:48 +01002298 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002299 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002301 }
narpra0132b90462018-09-13 11:07:48 +01002302 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002303 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002304 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002305 }
2306 else
2307 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002308 unsigned int outputDim =
2309 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2310 ValidateTensorNumDimensions(outputTensorInfo,
2311 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002312 outputDim > 0 ? outputDim : 1,
2313 "output");
2314 }
narpra01a6bf9122018-09-10 09:50:09 +01002315}
2316
jimfly012c9322a2018-09-19 10:59:49 +01002317void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2318{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002319 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002321 ValidateNumInputs(workloadInfo, descriptorName, 1);
2322 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2323
2324 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2325 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002326
jimfly012c9322a2018-09-19 10:59:49 +01002327 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002328 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2329
jimfly012c9322a2018-09-19 10:59:49 +01002330 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002331 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2332 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2333 "as there are dimensions in the input tensor that is " +
2334 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2335 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002336 }
2337}
2338
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002339void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2340{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002341 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002343 ValidateNumInputs(workloadInfo, descriptorName, 1);
2344 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002345
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2347 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2348
Sadik Armagan2208b602019-07-31 16:36:27 +01002349 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002350 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002351 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002352 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002353 DataType::Float16,
2354 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002355 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002356 DataType::QAsymmU8,
2357 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002358 };
2359
2360 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002361
Keith Davis0c2eeac2020-02-11 16:51:50 +00002362 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002363 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002365 }
2366}
2367
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002368void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2369{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002370 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002371
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002372 ValidateNumInputs(workloadInfo, descriptorName, 1);
2373 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002374
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002375 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2376 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002377
2378 std::vector<DataType> supportedTypes =
2379 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002380 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002381 DataType::Float32,
2382 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002383 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002384 DataType::QAsymmU8,
2385 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002386 };
2387
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002388 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2389 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002390}
2391
Conor Kennedy430b5d82018-11-14 15:28:28 +00002392void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2393{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002394 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002395
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002396 ValidateNumInputs(workloadInfo, descriptorName, 1);
2397 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2398
2399 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2400 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002401
2402 std::vector<DataType> supportedTypes =
2403 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002404 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002405 DataType::Float16,
2406 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002407 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002408 DataType::QAsymmU8,
2409 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002410 };
2411
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002412 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2413 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002414
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002415 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002416
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002418 if (rank > 4)
2419 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002421 }
2422
Conor Kennedy430b5d82018-11-14 15:28:28 +00002423 // Begin, End & Stride length must be of rank(input0)
2424 if (m_Parameters.m_Begin.size() != rank)
2425 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002426 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002427 }
2428
2429 if (m_Parameters.m_End.size() != rank)
2430 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002432 }
2433
2434 if (m_Parameters.m_Stride.size() != rank)
2435 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002436 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002437 }
2438
2439 // Stride entries must be non-zero
2440 for (auto& stride : m_Parameters.m_Stride)
2441 {
2442 if (stride == 0)
2443 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002444 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002445 }
2446 }
2447}
2448
kevmay0190539692018-11-29 08:40:19 +00002449void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2450{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002451 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002452
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002453 ValidateNumInputs(workloadInfo, descriptorName, 2);
2454 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2455
2456 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2457 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2458 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2459
2460 std::vector<DataType> supportedTypes =
2461 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002462 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002463 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002464 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002465 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002466 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002467 DataType::QSymmS16,
2468 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002469 };
2470
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002471 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2472 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2473 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002475 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2476 inputTensorInfo1,
2477 outputTensorInfo,
2478 descriptorName,
2479 "input_0",
2480 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002481}
2482
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002483void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2484{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002485 const std::string descriptorName{"DebugQueueDescriptor"};
2486
2487 ValidateNumInputs(workloadInfo, descriptorName, 1);
2488 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002489}
2490
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002491void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2492{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002493 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002495 ValidateNumInputs(workloadInfo, descriptorName, 2);
2496 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2499 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2500 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2501
2502 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2503 inputTensorInfo1,
2504 outputTensorInfo,
2505 descriptorName,
2506 "input_0",
2507 "input_1");
2508
2509 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002510 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002511 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002512 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002513}
2514
FrancisMurtagh878f0232018-12-19 10:56:15 +00002515void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2516{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002517 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002518
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002519 ValidateNumInputs(workloadInfo, descriptorName, 2);
2520 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002522 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2523 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2524 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2525
2526 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2527 inputTensorInfo1,
2528 outputTensorInfo,
2529 descriptorName,
2530 "input_0",
2531 "input_1");
2532
2533 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002534 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002535 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002536 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002537}
2538
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002539void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2540{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002541 const std::string descriptorName{"RsqrtQueueDescriptor"};
2542
2543 ValidateNumInputs(workloadInfo, descriptorName, 1);
2544 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2545
2546 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2547 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2548
2549 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002550
2551 std::vector<DataType> supportedTypes =
2552 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002553 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002554 DataType::Float16,
2555 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002556 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002557 DataType::QAsymmU8,
2558 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002559 };
2560
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002561 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2562 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002563}
2564
narpra01b89b05f2019-01-16 09:53:09 +00002565void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2566{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002567 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002568
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002569 ValidateNumInputs(workloadInfo, descriptorName, 2);
2570 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2573 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002574 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002575 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002576 }
2577
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002578 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2579 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2580
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002581 std::vector<DataType> supportedTypes =
2582 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002583 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002584 DataType::Float16,
2585 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002586 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002587 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002588 DataType::QSymmS16,
2589 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002590 };
2591
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002592 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002593
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002594 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002595
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002596 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2597 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002598}
2599
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002600void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2601{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002602 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2603
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002604 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002605
2606 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2607 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002608 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002609 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2610 }
2611
2612 if (m_Anchors == nullptr)
2613 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002614 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002615 }
2616
2617 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002618 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2619 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2620
2621 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002622 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002623 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2624 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002625
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002626 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2627 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2628 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002629
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002630 const std::vector<DataType> supportedInputTypes =
2631 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002632 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002633 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002634 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002635 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002636 DataType::QAsymmU8,
2637 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002638 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002639
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002640 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2641 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2642 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2643
2644 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2645 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2646 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2647 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2648
2649 // NOTE: Output is always Float32 regardless of input type
2650 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2651 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2652 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2653 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002654
2655 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2656 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002657 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002658 "must be positive and less than or equal to 1.");
2659 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002660
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002661 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2662 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002663 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002664 "should be equal to number of classes + 1.");
2665 }
2666}
2667
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002668void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2669{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002670 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002671
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002672 ValidateNumInputs(workloadInfo, descriptorName, 1);
2673 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2674
2675 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2676 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2677
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002678 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002679 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002680 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002681 }
2682
Sadik Armagan2208b602019-07-31 16:36:27 +01002683 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002684 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002685 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002686 DataType::Float32,
2687 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002688 };
2689
2690 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002691}
2692
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002693void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2694{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002695 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002697 ValidateNumInputs(workloadInfo, descriptorName, 2);
2698 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002699
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002700 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2701 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002703
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002704 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2705 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2706
2707 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2708 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002709}
2710
Sadik Armaganeff363d2019-04-05 15:25:46 +01002711void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2712{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002713 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002714
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002715 ValidateNumInputs(workloadInfo, descriptorName, 2);
2716 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2717
2718 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2719 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2720
2721 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2722 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2723
2724 std::vector<DataType> supportedTypes =
2725 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002726 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002727 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002728 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002729 DataType::QAsymmU8,
2730 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002731 };
2732
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002733 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2734 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002735
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002736 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2737 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002739 ValidateTensorShapesMatch(inputTensorInfo0,
2740 outputTensorInfo0,
2741 descriptorName,
2742 "input_0",
2743 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002744
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002745 ValidateTensorShapesMatch(inputTensorInfo0,
2746 outputTensorInfo1,
2747 descriptorName,
2748 "input_0",
2749 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002750}
2751
Derek Lamberti901ea112019-12-10 22:07:09 +00002752void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002753{
2754 // This is internally generated so it should not need validation.
2755}
2756
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002757void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2758{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002759 const std::string& descriptorName{"PreluQueueDescriptor"};
2760
2761 ValidateNumInputs(workloadInfo, descriptorName, 2);
2762 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2763
2764 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2765 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2766 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002767
2768 std::vector<DataType> supportedTypes
2769 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002770 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002771 DataType::Float16,
2772 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002773 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002774 DataType::QAsymmU8,
2775 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002776 };
2777
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002778 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2779 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002780
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002781 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002782
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002783 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2784 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002785
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002786 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2787 alphaTensorInfo,
2788 outputTensorInfo,
2789 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002790 "input",
2791 "alpha");
2792}
2793
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002794void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2795{
2796 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2797
2798 ValidateNumInputs(workloadInfo, descriptorName, 1);
2799 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2800
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002801 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2802 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2803
2804 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2805 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002806
2807 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002808
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002809 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2810 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002811
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002812 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2813
2814 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002815 if (m_Parameters.m_BiasEnabled)
2816 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002817 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002818
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002819 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2820 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002821
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002822 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002823 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002824 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002825
2826 ValidatePerAxisQuantization(inputTensorInfo,
2827 outputTensorInfo,
2828 weightTensorInfo,
2829 optionalBiasTensorInfo,
2830 descriptorName);
2831
2832 std::vector<DataType> supportedTypes =
2833 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002834 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002835 DataType::Float32,
2836 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002837 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002838 DataType::QAsymmU8,
2839 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002840 };
2841
2842 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2843 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002844}
2845
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002846void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2847{
2848 const std::string descriptorName{"TransposeQueueDescriptor"};
2849
2850 ValidateNumInputs(workloadInfo, descriptorName, 1);
2851 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2852
2853 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2854
2855 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2856 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2857
2858 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2859 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2860
2861 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2862 {
2863 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2864 {
2865 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2866 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2867 "must match dst dimension " + to_string(i) +
2868 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2869 }
2870 }
2871
2872 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2873}
2874
James Conroy4f1f8992020-04-29 20:01:10 +01002875void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2876{
2877 const std::string descriptorName{"QLstmQueueDescriptor"};
2878
2879 // Validate number of inputs/outputs
2880 ValidateNumInputs(workloadInfo, descriptorName, 3);
2881 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2882
2883 // Input/output tensor info
2884 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2885 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2886 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2887
2888 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2889 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2890 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2891
2892 // Supported types for various tensors in QLSTM
2893 std::vector<DataType> inputOutputSupportedTypes =
2894 {
2895 DataType::QAsymmS8
2896 };
2897
2898 std::vector<DataType> cellStateSupportedTypes =
2899 {
2900 DataType::QSymmS16
2901 };
2902
2903 std::vector<DataType> weightsSupportedTypes =
2904 {
2905 DataType::QSymmS8
2906 };
2907
2908 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2909 {
2910 DataType::QSymmS16
2911 };
2912
2913 std::vector<DataType> biasSupportedTypes =
2914 {
2915 DataType::Signed32
2916 };
2917
2918 // Validate types of input/output tensors
2919 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2920 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2921 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2922
2923 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2924 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2925 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
2926
2927 // Validate matching types of input/output tensors
2928 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2929 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2930 "outputStateIn", "outputStateOut");
2931 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2932
2933 // Infer number of batches, number of units, input size and output size from tensor dimensions
2934 const uint32_t numBatches = inputInfo.GetShape()[0];
2935 const uint32_t inputSize = inputInfo.GetShape()[1];
2936 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
2937 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
2938
2939 // Validate number of dimensions and number of elements for input/output tensors
2940 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2941 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2942 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
2943
2944 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2945 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
2946 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
2947
2948 // Validate number of dimensions and number of elements for MANDATORY weight tensors
2949 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2950 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2951 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
2952
2953 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2954 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2955 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
2956
2957 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2958 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2959 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
2960
2961 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2962 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2963 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
2964 " RecurrentToForgetWeights");
2965
2966 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2967 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2968 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2969
2970 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2971 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2972 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2973
2974 // Validate data types for MANDATORY weights tensors (all should match each other)
2975 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
2976
2977 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
2978 "inputToForgetWeights", "inputToCellWeights");
2979 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2980 "inputToForgetWeights", "inputToOutputWeights");
2981
2982 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2983 "inputToForgetWeights", "recurrentToForgeteights");
2984 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2985 "inputToForgetWeights", "recurrentToCellWeights");
2986 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2987 "inputToForgetWeights", "recurrentToOutputWeights");
2988
2989 // Validate number of dimensions and number of elements for MANDATORY bias tensors
2990 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2991 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2992 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
2993
2994 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2995 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2996 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
2997
2998 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2999 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3000 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3001
3002 // Validate data types for MANDATORY bias tensors
3003 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3004
3005 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3006 "forgetGateBias", "cellBias");
3007 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3008 "forgetGateBias", "outputGateBias");
3009
3010 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3011 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3012 !m_Parameters.m_CifgEnabled) ||
3013 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3014 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3015
3016 if (!allCifgParamsPresentOrNot)
3017 {
3018 throw InvalidArgumentException(descriptorName +
3019 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3020 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3021 "set appropriately.");
3022 }
3023
3024 if (!m_Parameters.m_CifgEnabled)
3025 {
3026 // Validate number of dimensions and number of elements
3027 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3028 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3029
3030 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3031 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3032 " RecurrentToInputWeights");
3033
3034 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3035 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3036
3037 // Validate data types
3038 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3039 "inputToForgetWeights", "inputToInputWeights");
3040 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3041 "inputToForgetWeights", "recurrentToInputWeights");
3042 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3043 "forgetGateBias", "inputGateBias");
3044 }
3045
3046 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3047 bool allPeepholeWeightsPresentOrNot =
3048 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3049 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3050 || (!m_CellToInputWeights && !m_CellToForgetWeights
3051 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3052
3053 if (!allPeepholeWeightsPresentOrNot)
3054 {
3055 throw InvalidArgumentException(descriptorName +
3056 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3057 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3058 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3059 "appropriately.");
3060 }
3061
3062 if (m_Parameters.m_PeepholeEnabled)
3063 {
3064 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3065 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3066 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3067
3068 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3069 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3070 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3071 "cellToForgetWeight", "cellToOutputWeights");
3072
3073 if (!m_Parameters.m_CifgEnabled)
3074 {
3075 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3076 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3077 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3078 "cellToForgetWeights", "cellToInputWeights");
3079 }
3080 }
3081
3082 // Validate OPTIONAL params: Layer Norm Weights
3083 bool allLayerNormWeightsPresentOrNot =
3084 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3085 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3086 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3087 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3088
3089 if (!allLayerNormWeightsPresentOrNot)
3090 {
3091 throw InvalidArgumentException(descriptorName +
3092 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3093 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3094 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3095 "only be present when Layer Norm is enabled and CIFG is disabled. "
3096 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3097 }
3098
3099 if (m_Parameters.m_LayerNormEnabled)
3100 {
3101 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3102 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3103 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3104
3105 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3106 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3107 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3108 "forgetLayerNormWeights", "cellLayerNormWeights");
3109
3110 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3111 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3112 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3113 "forgetLayerNormWeights", "outputLayerNormWeights");
3114
3115 if (!m_Parameters.m_CifgEnabled)
3116 {
3117 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3118 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3119 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3120 "forgetLayerNormWeights", "inputLayerNormWeights");
3121 }
3122 }
3123
3124 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3125 bool correctProjectionTensorsPresent =
3126 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3127 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3128 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3129
3130 if (!correctProjectionTensorsPresent)
3131 {
3132 throw InvalidArgumentException(descriptorName +
3133 ": If projection is enabled, ProjectionWeights should be present and "
3134 "ProjectionBias is optional. If projection is disabled, neither "
3135 "ProjectionWeights nor ProjectionBias should be present.");
3136 }
3137
3138 if (m_Parameters.m_ProjectionEnabled)
3139 {
3140 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3141 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3142 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3143
3144 if (m_ProjectionBias)
3145 {
3146 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003147 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003148 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3149 }
3150
3151 }
3152 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3153 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3154 throw InvalidArgumentException(descriptorName +
3155 ": If projection is disabled, output quantization info (scale, offset) "
3156 "should match HiddenStateScale and HiddenStateZeroPoint.");
3157 }
3158
3159}
3160
James Conroy9c3cae82019-08-01 16:01:48 +01003161void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3162{
3163 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3164
3165 // Validate number of inputs/outputs
3166 ValidateNumInputs(workloadInfo, descriptorName, 3);
3167 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3168
3169 // Input/output tensor infos
3170 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3171 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3172 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3173
3174 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3175 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3176
3177 std::vector<DataType> inputOutputSupportedTypes =
3178 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003179 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003180 };
3181
3182 std::vector<DataType> cellStateSupportedTypes =
3183 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003184 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003185 };
3186
3187 std::vector<DataType> weightsSupportedTypes =
3188 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003189 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003190 };
3191
3192 std::vector<DataType> biasSupportedTypes =
3193 {
3194 DataType::Signed32
3195 };
3196
3197 // Validate types of input/output tensors
3198 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3199 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3200 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3201
3202 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3203 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3204
3205 // Validate matching types of input/output tensors
3206 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3207 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3208 "outputStateIn", "outputStateOut");
3209 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3210
3211 // Validate matching quantization info for input/output tensors
3212 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3213 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3214 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003215
James Conroy9c3cae82019-08-01 16:01:48 +01003216 // Infer number of batches, input size and output size from tensor dimensions
3217 const uint32_t numBatches = inputInfo.GetShape()[0];
3218 const uint32_t inputSize = inputInfo.GetShape()[1];
3219 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3220
3221 // Validate number of dimensions and number of elements for input/output tensors
3222 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3223 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3224 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3225 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3226 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3227
3228 // Validate number of dimensions and number of elements for weights tensors
3229 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3230 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3231 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3232
3233 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3234 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3235 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3236
3237 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3238 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3239 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3240
3241 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3242 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3243 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3244
3245 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3246 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3247 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3248
3249 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3250 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3251 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3252 " RecurrentToForgetWeights");
3253
3254 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3255 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3256 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3257
3258 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3259 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3260 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3261
3262 // Validate data types for weights tensors (all should match each other)
3263 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3264
3265 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3266 "inputToInputWeights", "inputToForgetWeights");
3267 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3268 "inputToInputWeights", "inputToCellWeights");
3269 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3270 "inputToInputWeights", "inputToOutputWeights");
3271
3272 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3273 "inputToInputWeights", "recurrentToInputWeights");
3274 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3275 "inputToInputWeights", "recurrentToForgeteights");
3276 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3277 "inputToInputWeights", "recurrentToCellWeights");
3278 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3279 "inputToInputWeights", "recurrentToOutputWeights");
3280
3281 // Validate matching quantization info for weight tensors (all should match each other)
3282 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3283 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3284 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3285 descriptorName, "inputToInputWeights", "inputToCellWeights");
3286 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3287 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3288
3289 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3290 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3291 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3292 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3293 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3294 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3295 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3296 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3297
3298 // Validate number of dimensions and number of elements in bias tensors
3299 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3300 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3301 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3302
3303 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3304 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3305 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3306
3307 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3308 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3309 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3310
3311 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3312 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3313 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3314
3315 // Validate data types for bias tensors (all should match each other)
3316 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3317
3318 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3319 "inputGateBias", "forgetGateBias");
3320 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3321 "inputGateBias", "cellBias");
3322 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3323 "inputGateBias", "outputGateBias");
3324
3325 // Validate bias tensor quantization info
3326 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3327 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3328 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3329 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3330}
3331
Kevin May868eb142019-09-04 17:29:31 +01003332void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3333{
3334 const std::string descriptorName{"AbsQueueDescriptor"};
3335
3336 ValidateNumInputs(workloadInfo, descriptorName, 1);
3337 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3338
3339 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3340 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3341
3342 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3343
3344 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003345 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003346 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003347 DataType::Float16,
3348 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003349 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003350 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003351 DataType::QSymmS16,
3352 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003353 };
Kevin May868eb142019-09-04 17:29:31 +01003354
3355 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3356 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3357}
3358
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003359void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3360{
3361 const std::string descriptorName{"SliceQueueDescriptor"};
3362
3363 ValidateNumInputs(workloadInfo, descriptorName, 1);
3364 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3365
3366 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3367 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3368
3369 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3370
3371 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3372 if (rank > 4)
3373 {
3374 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3375 }
3376
3377 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3378
3379 // Check if m_Begin and m_Size have the expected length
3380 if (m_Parameters.m_Begin.size() != rank)
3381 {
3382 throw InvalidArgumentException(descriptorName +
3383 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3384 }
3385 if (m_Parameters.m_Size.size() != rank)
3386 {
3387 throw InvalidArgumentException(descriptorName +
3388 ": Length of size descriptor must equal rank " + std::to_string(rank));
3389 }
3390
3391 // Check if the shape of the output tensor matches m_Size
3392 const TensorShape& outputShape = outputTensorInfo.GetShape();
3393 for (unsigned int i = 0u; i < rank; ++i)
3394 {
3395 if (m_Parameters.m_Size[i] != outputShape[i])
3396 {
3397 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3398 }
3399 }
3400
3401 // Check if the sum of begin offset and size in a given dimension
3402 // does not exceed the size of corresponding input
3403 const TensorShape& inputShape = inputTensorInfo.GetShape();
3404 for(unsigned int i = 0u; i < rank; ++i)
3405 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003406 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003407 {
3408 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3409 std::to_string(i) + " exceeds input size.");
3410 }
3411 }
3412}
3413
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003414void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3415{
3416 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3417
3418 ValidateNumInputs(workloadInfo, descriptorName, 1);
3419 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3420
3421 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3422 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3423
3424 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3425 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3426
3427 std::vector<DataType> supportedTypes =
3428 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003429 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003430 DataType::Float32,
3431 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003432 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003433 DataType::QAsymmU8,
3434 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003435 };
3436
3437 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3438 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3439
3440 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3441
3442 if (m_Parameters.m_BlockSize == 0)
3443 {
3444 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3445 }
3446
3447 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3448 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3449 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3450 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3451
3452 const TensorShape& outputShape = outputInfo.GetShape();
3453 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3454 {
3455 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3456 "must be divisible by block size.");
3457 }
3458
3459 const TensorShape& inputShape = inputInfo.GetShape();
3460 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3461 {
3462 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3463 "must be divisible by the square of block size." );
3464 }
3465}
3466
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003467void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3468{
3469 const std::string descriptorName{"ComparisonQueueDescriptor"};
3470
3471 ValidateNumInputs(workloadInfo, descriptorName, 2);
3472 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3473
3474 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3475 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3476 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3477
3478 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3479 inputTensorInfo1,
3480 outputTensorInfo,
3481 descriptorName,
3482 "input_0",
3483 "input_1");
3484
3485 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3486 {
3487 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3488 }
3489}
3490
josh minor4a3c6102020-01-06 16:40:46 -06003491void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3492{
3493 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3494
3495 ValidateNumInputs(workloadInfo, descriptorName, 1);
3496 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3497
3498 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3499 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3500
3501 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3502
3503 std::vector<DataType> supportedTypes =
3504 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003505 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003506 DataType::Float16,
3507 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003508 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003509 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003510 DataType::QSymmS16,
3511 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003512 };
3513
3514 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3515 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3516}
3517
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003518} // namespace armnn