blob: b39d6b3c4c41220ef45cd6b1698746e74d896634 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000011
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000014#include <string>
15#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000016
James Ward47fce872020-09-10 11:57:28 +010017#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000031 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000034 case DataType::QAsymmS8:
35 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000037 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000038 case DataType::QSymmS8:
39 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// The Android NDK does not provide std::to_string, so values are formatted through a
// string stream instead. Works for any type with an operator<< overload.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream stream;
    stream << value;
    return stream.str();
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100108void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000109 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& tensorName)
112{
113 if (tensor.GetNumDimensions() != numDimensions)
114 {
115 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
116 to_string(tensor.GetNumDimensions()) + " dimensions for " +
117 tensorName + " tensor.");
118 }
119}
120
121//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100122void ValidateTensorNumElements(const TensorInfo& tensor,
123 std::string const& descName,
124 unsigned int numElements,
125 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126{
127 if (tensor.GetNumElements() != numElements)
128 {
129 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100130 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100131 tensorName + " tensor.");
132 }
133}
134
135//---------------------------------------------------------------
136void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 unsigned int numDimension,
138 unsigned int numElements,
139 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100141 const std::string functionName{"ValidateTensorNumDimNumElem"};
142 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
143 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100144}
145
146//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000147void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
148 const std::string& descName, std::string const& tensorName)
149{
150 if (tensor.GetDataType() != dataType)
151 {
152 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
153 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
154 }
155}
156
Derek Lambertid466a542020-01-22 15:37:29 +0000157void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
158{
159 ARMNN_NO_DEPRECATE_WARN_BEGIN
160 if (tensor.GetDataType() != DataType::QSymmS8 &&
161 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
167 ARMNN_NO_DEPRECATE_WARN_END
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100171void ValidateTensorQuantizationSpace(const TensorInfo& first,
172 const TensorInfo& second,
173 const std::string& descName,
174 std::string const& firstName,
175 std::string const& secondName)
176{
177 if (!first.IsQuantized() ||
178 !second.IsQuantized())
179 {
180 // Not a quantized type, ignore the validation
181 return;
182 }
183
184 DataType firstDataType = first.GetDataType();
185 DataType secondDataType = second.GetDataType();
186
187 if (firstDataType != secondDataType)
188 {
189 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
190 " must be of the same quantized type, " +
191 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
192 secondName + " is " + GetDataTypeName(secondDataType));
193 }
194
195 if (!first.IsTypeSpaceMatch(second))
196 {
197 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
198 " must have the same quantization space, " +
199 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
200 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
201 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
202 " and scale " + to_string(second.GetQuantizationScale()));
203 }
204}
205
206//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100207void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
208 const TensorInfo& inputTensorInfo,
209 const TensorInfo& weightsTensorInfo,
210 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000211{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000212 // Helper lambda function to validate a single bias quantization scale value
213 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
214 {
ricbur013f4d7102019-10-31 16:22:18 +0000215 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000216 if (std::abs(biasScale - expectedScale) > tolerance)
217 {
218 // Print the float values with extra precision to see very small differences
219 std::stringstream msg;
220 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
221 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
222 biasScale;
223 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
224 }
225 };
226
telsoa014fcda012018-03-09 14:13:49 +0000227 if (biasTensor.GetQuantizationOffset() != 0)
228 {
229 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
230 to_string(biasTensor.GetQuantizationOffset()));
231 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232
233 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000234 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000235 // Validate per-axis quantization scales
236 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
237 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
238
239 if (weightScales.size() != biasScales.size())
240 {
241 std::stringstream msg;
242 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
243 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100309 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100369 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
James Ward47fce872020-09-10 11:57:28 +0100390 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
391 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000392 }
393
394 if (quantizationDim.value() != 0)
395 {
James Ward47fce872020-09-10 11:57:28 +0100396 throw InvalidArgumentException(fmt::format(
397 "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, "
398 "but got: {2}", descName, tensorName, quantizationDim.value()));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000399 }
400}
401
402void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
403 const std::string& descName,
404 const std::string& tensorName)
405{
406 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
407 if (quantizationOffset != 0)
408 {
James Ward47fce872020-09-10 11:57:28 +0100409 throw InvalidArgumentException(fmt::format(
410 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
411 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000412 }
413}
414
415void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
416 const TensorInfo& outputInfo,
417 const TensorInfo& weightInfo,
418 const Optional<TensorInfo>& optionalBiasInfo,
419 const std::string& descName)
420{
421 if (weightInfo.HasPerAxisQuantization())
422 {
423 const DataType inputDataType = inputInfo.GetDataType();
424 const DataType outputDataType = outputInfo.GetDataType();
425
Keith Davis0c2eeac2020-02-11 16:51:50 +0000426 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000427
428 if (!canHavePerAxisQuantization)
429 {
James Ward47fce872020-09-10 11:57:28 +0100430 throw InvalidArgumentException(fmt::format(
431 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
432 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000433 }
434
Derek Lambertid466a542020-01-22 15:37:29 +0000435
436 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000437 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
438 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
439
440 if (optionalBiasInfo.has_value())
441 {
442 const TensorInfo& biasInfo = optionalBiasInfo.value();
443 if (!biasInfo.HasPerAxisQuantization())
444 {
James Ward47fce872020-09-10 11:57:28 +0100445 throw InvalidArgumentException(fmt::format(
446 "{}: Per-axis quantization parameters not set on bias tensor, "
447 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000448 }
449
450 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
451 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
452 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
453 }
454 }
455}
456
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100457} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000458
459void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
460 unsigned int numExpectedIn, unsigned int numExpectedOut) const
461{
462 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
463 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
464}
465
466//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100467void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
468{
469 const std::string descriptorName{"MapQueueDescriptor"};
470
471 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100472 ValidateNumOutputs(workloadInfo, descriptorName, 0);
473
474 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
475 {
476 if (!m_Inputs[i])
477 {
478 throw InvalidArgumentException(
479 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
480 }
481 }
482}
483
484//---------------------------------------------------------------
485void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
486{
487 const std::string descriptorName{"UnmapQueueDescriptor"};
488
489 ValidateNumInputs(workloadInfo, descriptorName, 1);
490 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100491
492 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
493 {
494 if (!m_Inputs[i])
495 {
496 throw InvalidArgumentException(
497 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
498 }
499 }
500}
501
502//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000503void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
504{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100505 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000506
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100507 ValidateNumInputs(workloadInfo, descriptorName, 1);
508 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000509
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100510 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
511 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
512
513 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
514 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000515
516 if (m_Inputs.size() != m_Outputs.size())
517 {
James Ward47fce872020-09-10 11:57:28 +0100518 throw InvalidArgumentException(fmt::format(
519 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
520 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000521 }
522
523 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
524 {
525 if (!m_Inputs[i])
526 {
James Ward47fce872020-09-10 11:57:28 +0100527 throw InvalidArgumentException(fmt::format(
528 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000529 }
530
531 if (!m_Outputs[i])
532 {
James Ward47fce872020-09-10 11:57:28 +0100533 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000534 }
535 }
536}
537
Derek Lambertif674aa02019-08-01 15:56:25 +0100538//---------------------------------------------------------------
539void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
540{
541 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
542 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
543
544 if (workloadInfo.m_InputTensorInfos.size() != 1)
545 {
James Ward47fce872020-09-10 11:57:28 +0100546 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
547 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100548
549 }
550
551 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
552 {
James Ward47fce872020-09-10 11:57:28 +0100553 throw InvalidArgumentException(fmt::format(
554 "Number of input infos ({0}) does not match the number of output infos ({1})",
555 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100556 }
557
558 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
559 {
560 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
561 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
562 {
James Ward47fce872020-09-10 11:57:28 +0100563 throw InvalidArgumentException(fmt::format(
564 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100565 }
566 }
567
568 if (m_Inputs.size() != 1)
569 {
James Ward47fce872020-09-10 11:57:28 +0100570 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 }
572
573 if (m_Inputs.size() != m_Outputs.size())
574 {
James Ward47fce872020-09-10 11:57:28 +0100575 throw InvalidArgumentException(fmt::format(
576 "Number of inputs ({0}) does not match the number of outputs ({1})",
577 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100578 }
579
580 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
581 {
582 if (!m_Inputs[i])
583 {
James Ward47fce872020-09-10 11:57:28 +0100584 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100585 }
586
587 if (!m_Outputs[i])
588 {
James Ward47fce872020-09-10 11:57:28 +0100589 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100590 }
591 }
592}
593
594//---------------------------------------------------------------
595void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
596{
597 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
598 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
599
Derek Lambertif674aa02019-08-01 15:56:25 +0100600 if (m_Inputs.size() != 1)
601 {
James Ward47fce872020-09-10 11:57:28 +0100602 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100603 }
604
605 if (m_Outputs.size() != 0)
606 {
James Ward47fce872020-09-10 11:57:28 +0100607 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100608 }
609
610 if (!m_Inputs[0])
611 {
James Ward47fce872020-09-10 11:57:28 +0100612 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100613 }
614}
615
616//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000617void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
618{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100619 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100620
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100621 ValidateNumInputs(workloadInfo, descriptorName, 1);
622 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100624 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
625 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100626
627 std::vector<DataType> supportedTypes =
628 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000629 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100630 DataType::Float16,
631 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000632 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000633 DataType::QAsymmU8,
634 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100635 };
636
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100637 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
638 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
639 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000640}
641
Nikhil Rajee391d52019-09-05 17:50:44 +0100642void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
643{
644 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
645
646 ValidateNumInputs(workloadInfo, descriptorName, 1);
647 ValidateNumOutputs(workloadInfo, descriptorName, 1);
648
649 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
650 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
651
Inki Daed4619e22020-09-10 15:33:54 +0900652 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
653 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100654 {
Inki Daed4619e22020-09-10 15:33:54 +0900655 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100656 }
657
James Conroyd47a0642019-09-17 14:22:06 +0100658 std::vector<DataType> supportedInputTypes =
659 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000660 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100661 DataType::Float16,
662 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100663 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000664 DataType::QAsymmU8,
665 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900666 DataType::Signed32,
667 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100668 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100669
James Conroyd47a0642019-09-17 14:22:06 +0100670 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100671
672 auto inputShape = inputTensorInfo.GetShape();
673 auto outputShape = outputTensorInfo.GetShape();
674
675 auto inputNumDimensions = inputShape.GetNumDimensions();
676 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
677
678 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
679
680 // 1D input shape results in scalar output shape
681 if (inputShape.GetNumDimensions() == 1)
682 {
683 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
684 {
685 throw InvalidArgumentException(descriptorName + outputShapeError);
686 }
687 }
688 else
689 {
690 for (unsigned int i = 0; i < unsignedAxis; ++i)
691 {
692 if (outputShape[i] != inputShape[i])
693 {
694 throw InvalidArgumentException(descriptorName + outputShapeError);
695 }
696 }
697
698 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
699 {
700 if (outputShape[i - 1] != inputShape[i])
701 {
702 throw InvalidArgumentException(descriptorName + outputShapeError);
703 }
704 }
705 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100706}
707
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100708void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
709{
710 const std::string descriptorName{"SoftmaxQueueDescriptor"};
711
712 ValidateNumInputs(workloadInfo, descriptorName, 1);
713 ValidateNumOutputs(workloadInfo, descriptorName, 1);
714
715 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
716 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
717
718 std::vector<DataType> supportedTypes =
719 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000720 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100721 DataType::Float16,
722 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000723 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000724 DataType::QAsymmU8,
725 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100726 };
727
728 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
729 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
730 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
731}
732
telsoa014fcda012018-03-09 14:13:49 +0000733void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
734{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100735 const std::string descriptorName{"SplitterQueueDescriptor"};
736
737 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000738
Ruomei Yan25339c32019-05-28 16:48:20 +0100739 // Check the supported data types
740 std::vector<DataType> supportedTypes =
741 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000742 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100743 DataType::Float32,
744 DataType::Float16,
745 DataType::Boolean,
746 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100747 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000748 DataType::QAsymmU8,
749 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100750 };
751
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100752 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
753 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100754 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
756 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
757
758 const std::string outputName = "output_" + std::to_string(i);
759 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100760 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100761
telsoa014fcda012018-03-09 14:13:49 +0000762 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
763 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100764 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000765 }
766
767 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
768 {
769 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100770 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000771 "has to match number of workloadInfo.m_OutputTensorInfos. "
772 "Number of windows: " +
773 to_string(m_ViewOrigins.size()) +
774 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
775 }
776
telsoa01c577f2c2018-08-31 09:22:23 +0100777 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000778 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
779 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
780 {
telsoa01c577f2c2018-08-31 09:22:23 +0100781 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000782 ViewOrigin const& e = m_ViewOrigins[w];
783 if (e.m_Origin.size() != inputDims)
784 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000786 "have the same dimensionality as the input tensor. "
787 "Window origin (index: " +
788 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
789 " dimensions, the input "
790 "tensor has " +
791 to_string(inputDims) + " dimensions.");
792 }
793 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
794 {
795 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
796 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
797 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100798 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000799 "be smaller or equal than the size of the input in that coord.");
800 }
801 }
802 }
803}
804
Jim Flynne242f2d2019-05-22 14:24:13 +0100805void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000806{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100807 const std::string descriptorName{"ConcatQueueDescriptor"};
808
809 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000810
811 if (m_Inputs.size() <= 0)
812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000814 }
815 if (m_Outputs.size() <= 0)
816 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100817 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000818 }
819
820 if (workloadInfo.m_InputTensorInfos.size() <= 0)
821 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100822 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000823 }
824 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100826 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000827 }
828
Nikhil Raj8599a412018-11-19 14:51:07 +0000829 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
830 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100831 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000832 }
833
834 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
835 {
836 return;
837 }
838
telsoa014fcda012018-03-09 14:13:49 +0000839 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
840 {
841 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100842 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000843 "has to match number of workloadInfo.m_InputTensorInfos. "
844 "Number of windows: " +
845 to_string(m_ViewOrigins.size()) +
846 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
847 }
848
telsoa01c577f2c2018-08-31 09:22:23 +0100849 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000850 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
851 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
852 {
telsoa01c577f2c2018-08-31 09:22:23 +0100853 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000854 ViewOrigin const& e = m_ViewOrigins[w];
855 if (e.m_Origin.size() != outputDims)
856 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100857 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000858 "have the same dimensionality as the output tensor. "
859 "Window origin (index: " +
860 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
861 " dimensions, the output "
862 "tensor has " +
863 to_string(outputDims) + " dimensions.");
864 }
telsoa01c577f2c2018-08-31 09:22:23 +0100865 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000866 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
867 {
868 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
869 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
870 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100871 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000872 "be smaller or equal than the size of the output in that coord.");
873 }
874 }
875 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100876
877 // Check the supported data types
878 std::vector<DataType> supportedTypes =
879 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000880 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100881 DataType::Float32,
882 DataType::Float16,
883 DataType::Boolean,
884 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100885 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000886 DataType::QAsymmU8,
887 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100888 };
889
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100890 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
891 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100892 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100893 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
894 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
895
896 const std::string inputName = "input_" + std::to_string(i);
897 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100898 }
telsoa014fcda012018-03-09 14:13:49 +0000899}
900
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100901void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
902{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100903 const std::string descriptorName{"StackQueueDescriptor"};
904
905 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100906
907 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
908 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100909 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100910 }
911
912 // All inputs must have the same shape, which is defined in parameters
913 const TensorShape& inputShape = m_Parameters.m_InputShape;
914 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
915 {
916 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
917 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100919 }
920 }
921
Matthew Jacksondba634f2019-08-15 15:14:18 +0100922 if (inputShape.GetNumDimensions() > 4)
923 {
924 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
925 }
926
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100927 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
928 // since the output tensor has an additional dimension.
929 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
930 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100931 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100932 "than the number of input dimensions.");
933 }
934
935 // Output shape must be as inferred from the input shape
936 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
937 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
938 {
939 if (outputShape[i] != inputShape[i])
940 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100941 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100942 "match shape inferred from input tensor.");
943 }
944 }
945
946 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
947 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100948 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100949 "match shape inferred from input tensor.");
950 }
951
952 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
953 {
954 if (outputShape[i] != inputShape[i-1])
955 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100956 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100957 "match shape inferred from input tensor.");
958 }
959 }
960
Matthew Jacksondba634f2019-08-15 15:14:18 +0100961 if (outputShape.GetNumDimensions() > 5)
962 {
963 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
964 }
965
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100966 // Check the supported data types
967 std::vector<DataType> supportedTypes =
968 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000969 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100970 DataType::Float32,
971 DataType::Float16,
972 DataType::Boolean,
973 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100974 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000975 DataType::QAsymmU8,
976 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100977 };
978
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100979 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100980
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100981 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100982 {
983 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
984 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100985 descriptorName,
986 "input_0",
987 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100988 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100989
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100990 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
991 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100992 descriptorName,
993 "input_0",
994 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100995}
996
Ryan OSheaec6c6802020-06-05 17:17:06 +0100997void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
998{
999 const std::string descriptorName{"FillQueueDescriptor"};
1000
1001 ValidateNumInputs(workloadInfo, descriptorName, 1);
1002 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1003
1004 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1005 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1006
1007 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1008
1009 std::vector<DataType> supportedTypes =
1010 {
1011 DataType::BFloat16,
1012 DataType::Float32,
1013 DataType::Float16,
1014 DataType::Signed32
1015 };
1016
1017 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1018}
1019
telsoa014fcda012018-03-09 14:13:49 +00001020void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1021{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001022 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001023
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001024 ValidateNumInputs(workloadInfo, descriptorName, 1);
1025 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1026
1027 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1028 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1029
1030 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1031
1032 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001033 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001035 }
1036
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001037 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001038
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001039 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1040 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001041
1042 if (m_Parameters.m_BiasEnabled)
1043 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001045
telsoa01c577f2c2018-08-31 09:22:23 +01001046 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001047 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1048 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001049
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001050 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1051 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001052 }
1053
Francis Murtagh46c09d02019-05-28 08:15:28 +01001054 // Check the supported data types
1055 std::vector<DataType> supportedTypes =
1056 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001057 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001058 DataType::Float32,
1059 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001060 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001061 DataType::QAsymmU8,
1062 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001063 };
1064
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001065 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001066
1067 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1068 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1069 {
1070 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1071 {
1072 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1073 "for BFloat16 input.");
1074 }
1075 }
1076 else
1077 {
1078 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1079 }
telsoa014fcda012018-03-09 14:13:49 +00001080}
1081
telsoa014fcda012018-03-09 14:13:49 +00001082void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1083{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001084 const std::string descriptorName{"NormalizationQueueDescriptor"};
1085
1086 ValidateNumInputs(workloadInfo, descriptorName, 1);
1087 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1088
1089 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1090 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001091
1092 // Check the supported data types
1093 std::vector<DataType> supportedTypes =
1094 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001095 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001096 DataType::Float16,
1097 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001098 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001099 DataType::QAsymmU8,
1100 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001101 };
1102
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001103 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001104
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001105 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001106
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001107 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001108}
1109
1110void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1111{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001112 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001113
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001114 ValidateNumInputs(workloadInfo, descriptorName, 2);
1115 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1116
1117 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1118 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1119 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1120
1121 std::vector<DataType> supportedTypes =
1122 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001123 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001124 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001125 DataType::Float16,
1126 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001127 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001128 DataType::QSymmS16,
1129 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001130 };
1131
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001132 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1133 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1134 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1137 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001138
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001139 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1140 inputTensorInfo1,
1141 outputTensorInfo,
1142 descriptorName,
1143 "input_0",
1144 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001145}
1146
telsoa014fcda012018-03-09 14:13:49 +00001147void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1148{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001150
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001151 ValidateNumInputs(workloadInfo, descriptorName, 2);
1152 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1153
1154 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1155 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1156 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1157
1158 std::vector<DataType> supportedTypes =
1159 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001160 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001161 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001162 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001163 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001164 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001165 DataType::QSymmS16,
1166 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001167 };
1168
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1170 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1171 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001172
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001173 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1174 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001175
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001176 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1177 inputTensorInfo1,
1178 outputTensorInfo,
1179 descriptorName,
1180 "input_0",
1181 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001182}
1183
1184void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1185{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001186 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001187
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001188 ValidateNumInputs(workloadInfo, descriptorName, 1);
1189 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1190
1191 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1192 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001193
1194 std::vector<DataType> supportedTypes =
1195 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001196 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001197 DataType::Float16,
1198 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001199 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001200 DataType::QAsymmU8,
1201 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001202 };
1203
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001204 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1205 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001206
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001207 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001208 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001209
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001210 ValidatePointer(m_Mean, descriptorName, "mean");
1211 ValidatePointer(m_Variance, descriptorName, "variance");
1212 ValidatePointer(m_Beta, descriptorName, "beta");
1213 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001214
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001215 const TensorInfo& mean = m_Mean->GetTensorInfo();
1216 const TensorInfo& variance = m_Variance->GetTensorInfo();
1217 const TensorInfo& beta = m_Beta->GetTensorInfo();
1218 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001219
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001220 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1221 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1222 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1223 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1226 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1227 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001228}
1229
1230void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1231{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001232 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001233
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001234 ValidateNumInputs(workloadInfo, descriptorName, 1);
1235 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001236
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001237 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1238 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1241 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001242
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001243 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001244
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001245 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1246 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001247
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001248 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001249
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001250 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001251 if (m_Parameters.m_BiasEnabled)
1252 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001254
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001255 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1256 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257
1258 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1259 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001260 }
1261
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001262 ValidatePerAxisQuantization(inputTensorInfo,
1263 outputTensorInfo,
1264 weightTensorInfo,
1265 optionalBiasTensorInfo,
1266 descriptorName);
1267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 std::vector<DataType> supportedTypes =
1269 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001270 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001271 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001272 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001273 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001274 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001275 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001276 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001277 };
1278
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001279 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001280
1281 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1282 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1283 {
1284 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1285 {
1286 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1287 "for BFloat16 input.");
1288 }
1289 }
1290 else
1291 {
1292 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1293 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001294}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001295
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001296void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1297{
1298 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1299
1300 ValidateNumInputs(workloadInfo, descriptorName, 1);
1301 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1302
1303 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1304 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1305
1306 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1307 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1308
1309 ValidatePointer(m_Weight, descriptorName, "weight");
1310
1311 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1312 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1313
1314 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1315 {
1316 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001317 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1318 "cannot be smaller than 1.",
1319 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001320 }
1321
1322 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1323
1324 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1325 // inputChannels * channelMultiplier should be equal to outputChannels.
1326 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1327 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1328 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1329 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1330 {
James Ward47fce872020-09-10 11:57:28 +01001331 throw InvalidArgumentException(fmt::format(
1332 "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) "
1333 "multiplied by channel_multiplier (provided {3}).",
1334 descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001335 }
1336
Teresa Charlind8df0262019-11-11 12:28:15 +00001337 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001338
Teresa Charlind8df0262019-11-11 12:28:15 +00001339 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001340 if (m_Parameters.m_BiasEnabled)
1341 {
1342 ValidatePointer(m_Bias, descriptorName, "bias");
1343
Teresa Charlind8df0262019-11-11 12:28:15 +00001344 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1345 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001346
1347 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1348 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1349 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001350 ValidatePerAxisQuantization(inputTensorInfo,
1351 outputTensorInfo,
1352 weightTensorInfo,
1353 optionalBiasTensorInfo,
1354 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001355
1356 std::vector<DataType> supportedTypes =
1357 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001358 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001359 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001360 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001361 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001362 DataType::QAsymmU8,
1363 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001364 };
1365
1366 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1367 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001368}
1369
1370void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1371{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001372 const std::string descriptorName{"PermuteQueueDescriptor"};
1373
1374 ValidateNumInputs(workloadInfo, descriptorName, 1);
1375 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001376
1377 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1378
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001379 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1380 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001381
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001382 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1383 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001384
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001385 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001386 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001387 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001388 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001389 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1390 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1391 "must match dst dimension " + to_string(mapping[i]) +
1392 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001393 }
1394 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001395
1396 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001397}
1398
1399void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1400{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001401 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001402
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 ValidateNumInputs(workloadInfo, descriptorName, 1);
1404 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1405
1406 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1407 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1408
1409 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1410 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001411
1412 std::vector<DataType> supportedTypes =
1413 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001414 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001415 DataType::Float32,
1416 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001417 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001418 DataType::QAsymmU8,
1419 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001420 };
1421
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001422 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1423 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001424}
1425
1426void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1427{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001428 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001429
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001430 ValidateNumInputs(workloadInfo, descriptorName, 1);
1431 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1432
1433 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1434 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1435
1436 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1437 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001438
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001439 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001440 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001441 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001442 DataType::Float16,
1443 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001444 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001445 DataType::QAsymmU8,
1446 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001447 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001448
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001449 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1450 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001451
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001452 // ResizeBilinear only changes width and height: batch and channel count must match.
1453 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1454 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001455 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001456 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001457 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001458 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1459 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001460 }
1461
Teresa Charlin970f43b2019-07-01 13:51:07 +01001462 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001463 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1464 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001465 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001466 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001467 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001468 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1469 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001470 }
1471}
1472
1473void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1474{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001475 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001476
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001477 ValidateNumInputs(workloadInfo, descriptorName, 1);
1478 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1479
1480 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1481 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1482
1483 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1484 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001485
1486 std::vector<DataType> supportedTypes =
1487 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001488 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001489 DataType::Float16,
1490 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001491 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001492 DataType::QAsymmU8,
1493 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001494 };
1495
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1497 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001498
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 // Resize only changes width and height: batch and channel count must match.
1500 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1501 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001502 if (inputBatchSize != outputBatchSize)
1503 {
1504 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001505 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1506 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001507 }
1508
1509 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1511 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001512 if (inputChannelCount != outputChannelCount)
1513 {
1514 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001515 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1516 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001517 }
1518}
1519
1520void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1521{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001522 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001523
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001524 ValidateNumInputs(workloadInfo, descriptorName, 1);
1525 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1526
1527 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1528 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1529
1530 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1531 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1532
1533 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1534
telsoa014fcda012018-03-09 14:13:49 +00001535 if (m_Parameters.m_Min > m_Parameters.m_Max)
1536 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001537 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001538 }
telsoa014fcda012018-03-09 14:13:49 +00001539}
1540
Kevin Mayce5045a2019-10-02 14:07:47 +01001541void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1542{
1543 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1544
1545 ValidateNumInputs(workloadInfo, descriptorName, 1);
1546 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1547
1548 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1549 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1550
1551 if (inputTensorInfo.GetNumDimensions() > 4)
1552 {
1553 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1554 }
1555
1556 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1557
1558 // Check the supported data types
1559 std::vector<DataType> supportedTypes =
1560 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001561 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001562 DataType::Float32,
1563 DataType::Float16
1564 };
1565
1566 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001567 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001568}
1569
telsoa014fcda012018-03-09 14:13:49 +00001570void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1571{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001572 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001574 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001575 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001577 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1578 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1579
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001580 if (inputTensorInfo.GetNumDimensions() > 4)
1581 {
1582 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1583 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001584
1585 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001586
1587 // Check the supported data types
1588 std::vector<DataType> supportedTypes =
1589 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001590 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001591 DataType::Float32,
1592 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001593 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001594 DataType::QAsymmU8,
1595 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001596 };
1597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001598 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001599 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1600}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001602void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1603{
1604 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1605
1606 ValidateNumInputs(workloadInfo, descriptorName, 1);
1607 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1608
1609 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1610 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1611
1612 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1613
1614 std::vector<DataType> supportedTypes =
1615 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001616 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001617 DataType::Float32,
1618 DataType::Float16,
1619 };
1620
1621 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001622 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001623}
1624
1625void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1626{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001627 const std::string descriptorName{"ConstantQueueDescriptor"};
1628
1629 ValidateNumInputs(workloadInfo, descriptorName, 0);
1630 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001631
1632 if (!m_LayerOutput)
1633 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001634 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001635 }
1636
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1638 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001639
1640 // Check the supported data types
1641 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001642 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001643 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001644 DataType::Float32,
1645 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001646 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001647 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001648 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001649 DataType::QSymmS16,
1650 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001651 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001652
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001653 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001654}
1655
1656void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1657{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001658 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001659
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001660 ValidateNumInputs(workloadInfo, descriptorName, 1);
1661 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1662
1663 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1664 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1665
1666 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001667
1668 // Check the supported data types
1669 std::vector<DataType> supportedTypes =
1670 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001671 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001672 DataType::Float32,
1673 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001674 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001675 DataType::QAsymmU8,
1676 DataType::QSymmS16,
1677 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001678 };
1679
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001680 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1681 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001682}
1683
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001684void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1685{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001686 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001687
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001688 ValidateNumInputs(workloadInfo, descriptorName, 1);
1689 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1690
1691 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1692 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1693
1694 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1695 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001696
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001697 if (m_Parameters.m_BlockShape.size() != 2)
1698 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001699 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001700 }
1701
1702 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1703 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001704 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1705 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001706 }
1707
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001708 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001709
1710 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001711 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001712
Matthew Bentham8800c002018-11-19 13:19:28 +00001713 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001714
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001715 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1716 widthPad.first + widthPad.second;
1717 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1718 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001720 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1721 inputShape[dimensionIndices.GetChannelsIndex()];
1722 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001723
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001724 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001725 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001727 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001728 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001729 }
1730
1731 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001732 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001733 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1734 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001735 }
nikraj01120522a2019-05-31 11:33:07 +01001736
1737 std::vector<DataType> supportedTypes =
1738 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001739 DataType::BFloat16,
1740 DataType::Float16,
1741 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001742 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001743 DataType::QAsymmU8,
1744 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001745 };
1746
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001747 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1748 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001749}
1750
Keith Davisa57eccb2019-06-14 17:33:22 +01001751void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1752{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001753 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001754
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001755 ValidateNumInputs(workloadInfo, descriptorName, 1);
1756 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001757
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001758 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1759 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1760
1761 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1762 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001763
1764 std::vector<DataType> supportedTypes =
1765 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001766 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001767 DataType::Float32,
1768 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001769 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001770 DataType::QAsymmU8,
1771 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001772 };
1773
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1775 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001776
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001777 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1778
1779 if (m_Parameters.m_BlockSize == 0)
1780 {
1781 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1782 }
1783
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1785 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1786 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1787 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001788
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001789 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001790 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001791 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001792 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1793 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001794 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001795
1796 const TensorShape& outputShape = outputTensorInfo.GetShape();
1797 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1798 {
1799 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1800 "must be divisible by the square of block size." );
1801 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001802}
1803
telsoa014fcda012018-03-09 14:13:49 +00001804void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1805{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001806 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001807
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001808 ValidateNumInputs(workloadInfo, descriptorName, 1);
1809 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1810
1811 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1812 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001813
1814 std::vector<DataType> supportedTypes =
1815 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001816 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001817 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001818 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001819 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001820 };
1821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001822 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001824 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001825 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001826 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001827 }
1828}
1829
telsoa01c577f2c2018-08-31 09:22:23 +01001830void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1831{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001832 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1833
1834 const std::string descriptorName{"LstmQueueDescriptor"};
1835
1836 // check dimensions of all inputs and outputs
1837 if (workloadInfo.m_InputTensorInfos.size() != 3)
1838 {
1839 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1840 }
1841 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1842 {
1843 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1844 }
1845
1846 std::vector<DataType> supportedTypes =
1847 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001848 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001849 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001850 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001851 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001852 };
1853
Jan Eilers38e05bd2019-06-26 13:10:09 +01001854 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001855 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1856
Jan Eilers38e05bd2019-06-26 13:10:09 +01001857 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001858 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001859 {
1860 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1861 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001862 descriptorName,
1863 "input_0",
1864 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001865 }
1866 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001867 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001868 {
1869 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1870 workloadInfo.m_OutputTensorInfos[i],
1871 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001872 "input_0",
1873 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001874 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001875
janeil0117d8d852019-11-15 15:00:16 +00001876 // Making sure clipping parameters have valid values.
1877 // == 0 means no clipping
1878 // > 0 means clipping
1879 if (m_Parameters.m_ClippingThresCell < 0.0f)
1880 {
1881 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1882 }
1883 if (m_Parameters.m_ClippingThresProj < 0.0f)
1884 {
1885 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1886 }
1887
Jan Eilers38e05bd2019-06-26 13:10:09 +01001888
1889 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001890 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1891 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1892 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1893 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1894 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1895 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1896
Jan Eilers38e05bd2019-06-26 13:10:09 +01001897 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1899 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001900 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001901 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1902 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001903 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001904 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1905 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001906 // scratchBufferTensor
1907 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001908 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1909 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001910 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001911 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1912 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001913 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001914 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1915 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001916 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001917 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1918 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001919
1920
1921 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1922 if ( m_InputToInputWeights )
1923 {
1924 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1925 (n_cell * n_input), "InputLayerNormWeights");
1926 }
1927
1928 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1929 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1930 (n_cell * n_input), "InputToForgetWeights");
1931
1932 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1933 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1934 (n_cell * n_input), "InputToCellWeights");
1935
1936 if ( m_RecurrentToInputWeights )
1937 {
1938 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1939 (n_cell * n_output), "RecurrentToInputWeights");
1940 }
1941
1942 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1943 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1944 (n_cell * n_output), "RecurrentToForgetWeights");
1945
1946 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1947 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1948 (n_cell * n_output), "RecurrentToCellWeights");
1949
1950 // Make sure the input-gate's parameters are either both present (regular
1951 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1952 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1953 !m_Parameters.m_CifgEnabled) ||
1954 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1955 m_Parameters.m_CifgEnabled));
1956 if (!cifg_weights_all_or_none)
1957 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001958 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1959 "RecurrentToInputWeights must either both be present (regular LSTM) "
1960 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1961 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001962 }
1963
1964 if ( m_CellToInputWeights )
1965 {
1966 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1967 n_cell, "CellToInputWeights");
1968 }
1969 if ( m_CellToForgetWeights )
1970 {
1971 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1972 n_cell, "CellToForgetWeights");
1973 }
1974 if ( m_CellToOutputWeights )
1975 {
1976 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1977 n_cell, "CellToOutputWeights");
1978 }
1979
1980 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1981 bool peephole_weights_all_or_none =
1982 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1983 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1984 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1985 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1986 if (!peephole_weights_all_or_none)
1987 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001989 }
1990
1991 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1992 if (m_Parameters.m_CifgEnabled)
1993 {
1994 if (m_InputGateBias)
1995 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001997 }
1998 }
1999 else
2000 {
2001 if (!m_InputGateBias)
2002 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002003 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2004 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002005 }
2006 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2007 n_cell, "InputGateBias");
2008 }
2009
2010 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2011 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2012
2013 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2014 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2015
2016 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2017 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2018
2019 if (m_ProjectionWeights)
2020 {
2021 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2022 (n_cell * n_output), "ProjectionWeights");
2023 }
2024 if (m_ProjectionBias)
2025 {
2026 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2027 }
2028
2029 // Making sure the projection tensors are consistent:
2030 // 1) If projection weight is not present, then projection bias should not be
2031 // present.
2032 // 2) If projection weight is present, then projection bias is optional.
2033 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2034 !m_Parameters.m_ProjectionEnabled)
2035 || (m_ProjectionWeights && !m_ProjectionBias &&
2036 m_Parameters.m_ProjectionEnabled)
2037 || (m_ProjectionWeights && m_ProjectionBias &&
2038 m_Parameters.m_ProjectionEnabled));
2039 if (!projecton_tensors_consistent)
2040 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002041 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002042 }
2043
2044 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2045 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2046 // either all have values or none of them have values. Layer normalization is used when the values of all the
2047 // layer normalization weights are present
2048 if (m_InputLayerNormWeights)
2049 {
2050 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2051 }
2052 if (m_ForgetLayerNormWeights)
2053 {
2054 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2055 }
2056 if (m_CellLayerNormWeights)
2057 {
2058 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2059 }
2060 if (m_OutputLayerNormWeights)
2061 {
2062 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2063 }
2064
Jan Eilers38e05bd2019-06-26 13:10:09 +01002065 if (m_Parameters.m_LayerNormEnabled)
2066 {
2067 if (!m_Parameters.m_CifgEnabled)
2068 {
2069 if (!m_InputLayerNormWeights)
2070 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002071 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2072 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002073 }
2074 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2075 1, n_cell, "InputLayerNormWeights");
2076 }
2077 else if (m_InputLayerNormWeights)
2078 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002079 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2080 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002081 }
2082
2083 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2084 "ForgetLayerNormWeights");
2085 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2086
2087 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2088 "OutputLayerNormWeights");
2089 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2090
2091 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2092 "CellLayerNormWeights");
2093 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2094 }
2095 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2096 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002097 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2098 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002099 }
telsoa01c577f2c2018-08-31 09:22:23 +01002100}
2101
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002102void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2103{
2104 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2105
2106 ValidateNumInputs(workloadInfo, descriptorName, 1);
2107 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2108
2109 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2110 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2111
2112 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2113 {
2114 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2115 }
2116
2117 if (outputTensorInfo.GetDataType() != DataType::Float32)
2118 {
2119 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2120 }
2121
2122 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2123}
2124
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002125void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2126{
2127 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2128
2129 ValidateNumInputs(workloadInfo, descriptorName, 1);
2130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2131
2132 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2133 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2134
2135 if (inputTensorInfo.GetDataType() != DataType::Float32)
2136 {
2137 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2138 }
2139
2140 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2141 {
2142 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2143 }
2144
2145 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2146}
2147
telsoa01c577f2c2018-08-31 09:22:23 +01002148void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2149{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002150 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002151
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002152 ValidateNumInputs(workloadInfo, descriptorName, 1);
2153 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2154
2155 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2156 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2157
2158 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002159 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002160 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002161 }
2162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002163 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002164 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002165 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002166 }
2167
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002168 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002169}
2170
2171void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2172{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002173 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002174
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002175 ValidateNumInputs(workloadInfo, descriptorName, 1);
2176 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2177
2178 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2179 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2180
2181 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002182 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002183 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002184 }
2185
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002186 if (outputTensorInfo.GetDataType() != DataType::Float32)
2187 {
2188 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2189 }
2190
2191 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002192}
2193
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002194void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2195{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002196 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002197
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002198 ValidateNumInputs(workloadInfo, descriptorName, 2);
2199 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2200
2201 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2202 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2203 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2204
2205 std::vector<DataType> supportedTypes =
2206 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002207 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002208 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002209 DataType::Float32,
2210 DataType::QAsymmS8,
2211 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002212 DataType::QSymmS16,
2213 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002214 };
2215
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002216 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2217 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2218 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002219
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002220 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2221 inputTensorInfo1,
2222 outputTensorInfo,
2223 descriptorName,
2224 "input_0",
2225 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002226}
2227
David Beckc2044fe2018-09-05 15:00:38 +01002228void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2229{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002230 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002231
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002232 ValidateNumInputs(workloadInfo, descriptorName, 2);
2233 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2234
2235 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2236 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2237 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2238
2239 std::vector<DataType> supportedTypes =
2240 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002241 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002242 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002243 DataType::Float32,
2244 DataType::QAsymmS8,
2245 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002246 DataType::QSymmS16,
2247 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002248 };
2249
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002250 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2251 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2252 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002253
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002254 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2255 inputTensorInfo1,
2256 outputTensorInfo,
2257 descriptorName,
2258 "input_0",
2259 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002260}
2261
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002262void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2263{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002264 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002265
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002266 ValidateNumInputs(workloadInfo, descriptorName, 2);
2267 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2268
2269 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2270 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2271 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2272
2273 std::vector<DataType> supportedTypes =
2274 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002275 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002276 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002277 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002278 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002279 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002280 DataType::QSymmS16,
2281 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002282 };
2283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2285 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2286 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002287
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002288 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2289 inputTensorInfo1,
2290 outputTensorInfo,
2291 descriptorName,
2292 "input_0",
2293 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002294}
2295
narpra01a6bf9122018-09-10 09:50:09 +01002296void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2297{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002299
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 ValidateNumInputs(workloadInfo, descriptorName, 1);
2301 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2302
2303 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2304 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002305
2306 std::vector<DataType> supportedTypes =
2307 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002308 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002309 DataType::Float32,
2310 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002311 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002312 DataType::QAsymmU8,
2313 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002314 };
narpra01eb061912018-09-10 17:35:27 +01002315
James Conroy4d1ff582019-06-10 17:06:39 +01002316 // First check if input tensor data type is supported, then
2317 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002318 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2319 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002320
narpra0132b90462018-09-13 11:07:48 +01002321 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002322 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002323 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002324 }
narpra0132b90462018-09-13 11:07:48 +01002325 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002326 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002327 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002328 }
2329 else
2330 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002331 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002332 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 ValidateTensorNumDimensions(outputTensorInfo,
2334 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002335 outputDim > 0 ? outputDim : 1,
2336 "output");
2337 }
narpra01a6bf9122018-09-10 09:50:09 +01002338}
2339
jimfly012c9322a2018-09-19 10:59:49 +01002340void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2341{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002342 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002344 ValidateNumInputs(workloadInfo, descriptorName, 1);
2345 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2346
2347 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2348 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002349
jimfly012c9322a2018-09-19 10:59:49 +01002350 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002351 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2352
jimfly012c9322a2018-09-19 10:59:49 +01002353 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002354 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2355 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2356 "as there are dimensions in the input tensor that is " +
2357 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2358 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002359 }
2360}
2361
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002362void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2363{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002365
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002366 ValidateNumInputs(workloadInfo, descriptorName, 1);
2367 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002368
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002369 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2370 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2371
Sadik Armagan2208b602019-07-31 16:36:27 +01002372 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002373 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002374 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002375 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002376 DataType::Float16,
2377 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002378 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002379 DataType::QAsymmU8,
2380 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002381 };
2382
2383 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002384
Keith Davis0c2eeac2020-02-11 16:51:50 +00002385 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002386 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002387 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002388 }
2389}
2390
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002391void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2392{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002394
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002395 ValidateNumInputs(workloadInfo, descriptorName, 1);
2396 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002397
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002398 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2399 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002400
2401 std::vector<DataType> supportedTypes =
2402 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002403 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002404 DataType::Float32,
2405 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002406 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002407 DataType::QAsymmU8,
2408 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002409 };
2410
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002411 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2412 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002413}
2414
Conor Kennedy430b5d82018-11-14 15:28:28 +00002415void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2416{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002418
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002419 ValidateNumInputs(workloadInfo, descriptorName, 1);
2420 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2421
2422 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2423 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002424
2425 std::vector<DataType> supportedTypes =
2426 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002427 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002428 DataType::Float16,
2429 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002430 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002431 DataType::QAsymmU8,
2432 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002433 };
2434
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002435 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2436 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002437
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002439
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002440 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002441 if (rank > 4)
2442 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002443 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002444 }
2445
Conor Kennedy430b5d82018-11-14 15:28:28 +00002446 // Begin, End & Stride length must be of rank(input0)
2447 if (m_Parameters.m_Begin.size() != rank)
2448 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002449 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002450 }
2451
2452 if (m_Parameters.m_End.size() != rank)
2453 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002454 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002455 }
2456
2457 if (m_Parameters.m_Stride.size() != rank)
2458 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002459 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002460 }
2461
2462 // Stride entries must be non-zero
2463 for (auto& stride : m_Parameters.m_Stride)
2464 {
2465 if (stride == 0)
2466 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002468 }
2469 }
2470}
2471
kevmay0190539692018-11-29 08:40:19 +00002472void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2473{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002474 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002476 ValidateNumInputs(workloadInfo, descriptorName, 2);
2477 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2478
2479 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2480 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2481 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2482
2483 std::vector<DataType> supportedTypes =
2484 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002485 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002486 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002487 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002488 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002489 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002490 DataType::QSymmS16,
2491 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002492 };
2493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002494 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2495 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2496 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2499 inputTensorInfo1,
2500 outputTensorInfo,
2501 descriptorName,
2502 "input_0",
2503 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002504}
2505
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002506void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2507{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002508 const std::string descriptorName{"DebugQueueDescriptor"};
2509
2510 ValidateNumInputs(workloadInfo, descriptorName, 1);
2511 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002512}
2513
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002514void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2515{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002517
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002518 ValidateNumInputs(workloadInfo, descriptorName, 2);
2519 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002520
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002521 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2522 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2523 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2524
2525 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2526 inputTensorInfo1,
2527 outputTensorInfo,
2528 descriptorName,
2529 "input_0",
2530 "input_1");
2531
2532 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002533 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002534 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002535 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002536}
2537
FrancisMurtagh878f0232018-12-19 10:56:15 +00002538void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2539{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002541
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002542 ValidateNumInputs(workloadInfo, descriptorName, 2);
2543 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002544
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002545 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2546 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2547 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2548
2549 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2550 inputTensorInfo1,
2551 outputTensorInfo,
2552 descriptorName,
2553 "input_0",
2554 "input_1");
2555
2556 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002557 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002558 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002559 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002560}
2561
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002562void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2563{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002564 const std::string descriptorName{"RsqrtQueueDescriptor"};
2565
2566 ValidateNumInputs(workloadInfo, descriptorName, 1);
2567 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2568
2569 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2570 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2571
2572 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002573
2574 std::vector<DataType> supportedTypes =
2575 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002576 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002577 DataType::Float16,
2578 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002579 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002580 DataType::QAsymmU8,
2581 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002582 };
2583
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002584 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2585 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002586}
2587
narpra01b89b05f2019-01-16 09:53:09 +00002588void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2589{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002590 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002591
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002592 ValidateNumInputs(workloadInfo, descriptorName, 2);
2593 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002595 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2596 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002597 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002598 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002599 }
2600
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002601 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2602 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2603
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002604 std::vector<DataType> supportedTypes =
2605 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002606 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002607 DataType::Float16,
2608 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002609 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002610 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002611 DataType::QSymmS16,
2612 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002613 };
2614
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002615 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002616
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002617 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002618
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002619 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2620 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002621}
2622
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002623void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2624{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002625 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2626
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002627 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002628
2629 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2630 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002631 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002632 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2633 }
2634
2635 if (m_Anchors == nullptr)
2636 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002637 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002638 }
2639
2640 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002641 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2642 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2643
2644 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002645 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002646 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2647 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002648
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002649 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2650 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2651 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002652
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002653 const std::vector<DataType> supportedInputTypes =
2654 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002655 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002656 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002657 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002658 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002659 DataType::QAsymmU8,
2660 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002661 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002662
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002663 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2664 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2665 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2666
2667 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2668 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2669 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2670 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2671
2672 // NOTE: Output is always Float32 regardless of input type
2673 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2674 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2675 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2676 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002677
2678 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2679 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002680 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002681 "must be positive and less than or equal to 1.");
2682 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002683
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002684 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2685 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002686 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002687 "should be equal to number of classes + 1.");
2688 }
2689}
2690
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002691void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2692{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002693 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002694
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002695 ValidateNumInputs(workloadInfo, descriptorName, 1);
2696 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2697
2698 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2699 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2700
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002701 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002702 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002703 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002704 }
2705
Sadik Armagan2208b602019-07-31 16:36:27 +01002706 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002707 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002708 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002709 DataType::Float32,
2710 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002711 };
2712
2713 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002714}
2715
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002716void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2717{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002718 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002720 ValidateNumInputs(workloadInfo, descriptorName, 2);
2721 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002722
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002723 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2724 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2725 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002726
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002727 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2728 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2729
2730 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2731 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002732}
2733
Sadik Armaganeff363d2019-04-05 15:25:46 +01002734void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2735{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002736 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002737
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002738 ValidateNumInputs(workloadInfo, descriptorName, 2);
2739 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2740
2741 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2742 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2743
2744 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2745 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2746
2747 std::vector<DataType> supportedTypes =
2748 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002749 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002750 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002751 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002752 DataType::QAsymmU8,
2753 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002754 };
2755
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002756 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2757 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002758
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002759 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2760 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002761
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002762 ValidateTensorShapesMatch(inputTensorInfo0,
2763 outputTensorInfo0,
2764 descriptorName,
2765 "input_0",
2766 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002767
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002768 ValidateTensorShapesMatch(inputTensorInfo0,
2769 outputTensorInfo1,
2770 descriptorName,
2771 "input_0",
2772 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002773}
2774
Derek Lamberti901ea112019-12-10 22:07:09 +00002775void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002776{
2777 // This is internally generated so it should not need validation.
2778}
2779
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002780void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2781{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002782 const std::string& descriptorName{"PreluQueueDescriptor"};
2783
2784 ValidateNumInputs(workloadInfo, descriptorName, 2);
2785 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2786
2787 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2788 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2789 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002790
2791 std::vector<DataType> supportedTypes
2792 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002793 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002794 DataType::Float16,
2795 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002796 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002797 DataType::QAsymmU8,
2798 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002799 };
2800
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002801 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2802 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002803
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002804 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002805
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002806 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2807 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002808
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002809 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2810 alphaTensorInfo,
2811 outputTensorInfo,
2812 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002813 "input",
2814 "alpha");
2815}
2816
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002817void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2818{
2819 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2820
2821 ValidateNumInputs(workloadInfo, descriptorName, 1);
2822 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002824 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2825 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2826
2827 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2828 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002829
2830 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002831
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002832 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2833 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002834
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002835 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2836
2837 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002838 if (m_Parameters.m_BiasEnabled)
2839 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002840 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002841
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002842 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2843 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002844
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002845 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002846 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002847 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002848
2849 ValidatePerAxisQuantization(inputTensorInfo,
2850 outputTensorInfo,
2851 weightTensorInfo,
2852 optionalBiasTensorInfo,
2853 descriptorName);
2854
2855 std::vector<DataType> supportedTypes =
2856 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002857 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002858 DataType::Float32,
2859 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002860 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002861 DataType::QAsymmU8,
2862 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002863 };
2864
2865 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2866 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002867}
2868
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002869void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2870{
2871 const std::string descriptorName{"TransposeQueueDescriptor"};
2872
2873 ValidateNumInputs(workloadInfo, descriptorName, 1);
2874 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2875
2876 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2877
2878 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2879 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2880
2881 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2882 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2883
2884 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2885 {
2886 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2887 {
2888 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2889 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2890 "must match dst dimension " + to_string(i) +
2891 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2892 }
2893 }
2894
2895 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2896}
2897
James Conroy4f1f8992020-04-29 20:01:10 +01002898void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2899{
2900 const std::string descriptorName{"QLstmQueueDescriptor"};
2901
2902 // Validate number of inputs/outputs
2903 ValidateNumInputs(workloadInfo, descriptorName, 3);
2904 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2905
2906 // Input/output tensor info
2907 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2908 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2909 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2910
2911 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2912 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2913 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2914
2915 // Supported types for various tensors in QLSTM
2916 std::vector<DataType> inputOutputSupportedTypes =
2917 {
2918 DataType::QAsymmS8
2919 };
2920
2921 std::vector<DataType> cellStateSupportedTypes =
2922 {
2923 DataType::QSymmS16
2924 };
2925
2926 std::vector<DataType> weightsSupportedTypes =
2927 {
2928 DataType::QSymmS8
2929 };
2930
2931 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2932 {
2933 DataType::QSymmS16
2934 };
2935
2936 std::vector<DataType> biasSupportedTypes =
2937 {
2938 DataType::Signed32
2939 };
2940
2941 // Validate types of input/output tensors
2942 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2943 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2944 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2945
2946 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2947 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2948 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
2949
2950 // Validate matching types of input/output tensors
2951 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2952 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2953 "outputStateIn", "outputStateOut");
2954 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2955
2956 // Infer number of batches, number of units, input size and output size from tensor dimensions
2957 const uint32_t numBatches = inputInfo.GetShape()[0];
2958 const uint32_t inputSize = inputInfo.GetShape()[1];
2959 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
2960 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
2961
2962 // Validate number of dimensions and number of elements for input/output tensors
2963 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2964 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2965 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
2966
2967 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2968 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
2969 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
2970
2971 // Validate number of dimensions and number of elements for MANDATORY weight tensors
2972 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2973 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2974 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
2975
2976 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2977 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2978 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
2979
2980 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2981 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2982 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
2983
2984 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2985 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2986 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
2987 " RecurrentToForgetWeights");
2988
2989 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2990 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2991 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2992
2993 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2994 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2995 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2996
2997 // Validate data types for MANDATORY weights tensors (all should match each other)
2998 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
2999
3000 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3001 "inputToForgetWeights", "inputToCellWeights");
3002 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3003 "inputToForgetWeights", "inputToOutputWeights");
3004
3005 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3006 "inputToForgetWeights", "recurrentToForgeteights");
3007 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3008 "inputToForgetWeights", "recurrentToCellWeights");
3009 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3010 "inputToForgetWeights", "recurrentToOutputWeights");
3011
3012 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3013 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3014 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3015 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3016
3017 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3018 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3019 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3020
3021 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3022 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3023 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3024
3025 // Validate data types for MANDATORY bias tensors
3026 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3027
3028 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3029 "forgetGateBias", "cellBias");
3030 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3031 "forgetGateBias", "outputGateBias");
3032
3033 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3034 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3035 !m_Parameters.m_CifgEnabled) ||
3036 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3037 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3038
3039 if (!allCifgParamsPresentOrNot)
3040 {
3041 throw InvalidArgumentException(descriptorName +
3042 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3043 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3044 "set appropriately.");
3045 }
3046
3047 if (!m_Parameters.m_CifgEnabled)
3048 {
3049 // Validate number of dimensions and number of elements
3050 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3051 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3052
3053 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3054 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3055 " RecurrentToInputWeights");
3056
3057 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3058 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3059
3060 // Validate data types
3061 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3062 "inputToForgetWeights", "inputToInputWeights");
3063 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3064 "inputToForgetWeights", "recurrentToInputWeights");
3065 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3066 "forgetGateBias", "inputGateBias");
3067 }
3068
3069 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3070 bool allPeepholeWeightsPresentOrNot =
3071 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3072 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3073 || (!m_CellToInputWeights && !m_CellToForgetWeights
3074 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3075
3076 if (!allPeepholeWeightsPresentOrNot)
3077 {
3078 throw InvalidArgumentException(descriptorName +
3079 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3080 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3081 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3082 "appropriately.");
3083 }
3084
3085 if (m_Parameters.m_PeepholeEnabled)
3086 {
3087 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3088 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3089 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3090
3091 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3092 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3093 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3094 "cellToForgetWeight", "cellToOutputWeights");
3095
3096 if (!m_Parameters.m_CifgEnabled)
3097 {
3098 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3099 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3100 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3101 "cellToForgetWeights", "cellToInputWeights");
3102 }
3103 }
3104
3105 // Validate OPTIONAL params: Layer Norm Weights
3106 bool allLayerNormWeightsPresentOrNot =
3107 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3108 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3109 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3110 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3111
3112 if (!allLayerNormWeightsPresentOrNot)
3113 {
3114 throw InvalidArgumentException(descriptorName +
3115 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3116 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3117 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3118 "only be present when Layer Norm is enabled and CIFG is disabled. "
3119 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3120 }
3121
3122 if (m_Parameters.m_LayerNormEnabled)
3123 {
3124 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3125 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3126 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3127
3128 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3129 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3130 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3131 "forgetLayerNormWeights", "cellLayerNormWeights");
3132
3133 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3134 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3135 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3136 "forgetLayerNormWeights", "outputLayerNormWeights");
3137
3138 if (!m_Parameters.m_CifgEnabled)
3139 {
3140 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3141 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3142 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3143 "forgetLayerNormWeights", "inputLayerNormWeights");
3144 }
3145 }
3146
3147 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3148 bool correctProjectionTensorsPresent =
3149 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3150 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3151 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3152
3153 if (!correctProjectionTensorsPresent)
3154 {
3155 throw InvalidArgumentException(descriptorName +
3156 ": If projection is enabled, ProjectionWeights should be present and "
3157 "ProjectionBias is optional. If projection is disabled, neither "
3158 "ProjectionWeights nor ProjectionBias should be present.");
3159 }
3160
3161 if (m_Parameters.m_ProjectionEnabled)
3162 {
3163 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3164 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3165 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3166
3167 if (m_ProjectionBias)
3168 {
3169 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003170 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003171 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3172 }
3173
3174 }
3175 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3176 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3177 throw InvalidArgumentException(descriptorName +
3178 ": If projection is disabled, output quantization info (scale, offset) "
3179 "should match HiddenStateScale and HiddenStateZeroPoint.");
3180 }
3181
3182}
3183
James Conroy9c3cae82019-08-01 16:01:48 +01003184void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3185{
3186 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3187
3188 // Validate number of inputs/outputs
3189 ValidateNumInputs(workloadInfo, descriptorName, 3);
3190 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3191
3192 // Input/output tensor infos
3193 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3194 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3195 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3196
3197 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3198 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3199
3200 std::vector<DataType> inputOutputSupportedTypes =
3201 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003202 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003203 };
3204
3205 std::vector<DataType> cellStateSupportedTypes =
3206 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003207 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003208 };
3209
3210 std::vector<DataType> weightsSupportedTypes =
3211 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003212 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003213 };
3214
3215 std::vector<DataType> biasSupportedTypes =
3216 {
3217 DataType::Signed32
3218 };
3219
3220 // Validate types of input/output tensors
3221 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3222 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3223 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3224
3225 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3226 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3227
3228 // Validate matching types of input/output tensors
3229 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3230 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3231 "outputStateIn", "outputStateOut");
3232 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3233
3234 // Validate matching quantization info for input/output tensors
3235 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3236 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3237 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003238
James Conroy9c3cae82019-08-01 16:01:48 +01003239 // Infer number of batches, input size and output size from tensor dimensions
3240 const uint32_t numBatches = inputInfo.GetShape()[0];
3241 const uint32_t inputSize = inputInfo.GetShape()[1];
3242 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3243
3244 // Validate number of dimensions and number of elements for input/output tensors
3245 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3246 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3247 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3248 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3249 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3250
3251 // Validate number of dimensions and number of elements for weights tensors
3252 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3253 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3254 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3255
3256 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3257 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3258 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3259
3260 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3261 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3262 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3263
3264 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3265 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3266 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3267
3268 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3269 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3270 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3271
3272 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3273 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3274 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3275 " RecurrentToForgetWeights");
3276
3277 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3278 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3279 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3280
3281 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3282 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3283 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3284
3285 // Validate data types for weights tensors (all should match each other)
3286 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3287
3288 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3289 "inputToInputWeights", "inputToForgetWeights");
3290 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3291 "inputToInputWeights", "inputToCellWeights");
3292 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3293 "inputToInputWeights", "inputToOutputWeights");
3294
3295 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3296 "inputToInputWeights", "recurrentToInputWeights");
3297 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3298 "inputToInputWeights", "recurrentToForgeteights");
3299 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3300 "inputToInputWeights", "recurrentToCellWeights");
3301 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3302 "inputToInputWeights", "recurrentToOutputWeights");
3303
3304 // Validate matching quantization info for weight tensors (all should match each other)
3305 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3306 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3307 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3308 descriptorName, "inputToInputWeights", "inputToCellWeights");
3309 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3310 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3311
3312 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3313 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3314 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3315 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3316 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3317 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3318 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3319 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3320
3321 // Validate number of dimensions and number of elements in bias tensors
3322 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3323 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3324 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3325
3326 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3327 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3328 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3329
3330 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3331 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3332 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3333
3334 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3335 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3336 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3337
3338 // Validate data types for bias tensors (all should match each other)
3339 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3340
3341 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3342 "inputGateBias", "forgetGateBias");
3343 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3344 "inputGateBias", "cellBias");
3345 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3346 "inputGateBias", "outputGateBias");
3347
3348 // Validate bias tensor quantization info
3349 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3350 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3351 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3352 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3353}
3354
Kevin May868eb142019-09-04 17:29:31 +01003355void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3356{
3357 const std::string descriptorName{"AbsQueueDescriptor"};
3358
3359 ValidateNumInputs(workloadInfo, descriptorName, 1);
3360 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3361
3362 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3363 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3364
3365 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3366
3367 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003368 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003369 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003370 DataType::Float16,
3371 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003372 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003373 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003374 DataType::QSymmS16,
3375 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003376 };
Kevin May868eb142019-09-04 17:29:31 +01003377
3378 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3379 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3380}
3381
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003382void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3383{
3384 const std::string descriptorName{"SliceQueueDescriptor"};
3385
3386 ValidateNumInputs(workloadInfo, descriptorName, 1);
3387 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3388
3389 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3390 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3391
3392 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3393
3394 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3395 if (rank > 4)
3396 {
3397 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3398 }
3399
3400 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3401
3402 // Check if m_Begin and m_Size have the expected length
3403 if (m_Parameters.m_Begin.size() != rank)
3404 {
3405 throw InvalidArgumentException(descriptorName +
3406 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3407 }
3408 if (m_Parameters.m_Size.size() != rank)
3409 {
3410 throw InvalidArgumentException(descriptorName +
3411 ": Length of size descriptor must equal rank " + std::to_string(rank));
3412 }
3413
3414 // Check if the shape of the output tensor matches m_Size
3415 const TensorShape& outputShape = outputTensorInfo.GetShape();
3416 for (unsigned int i = 0u; i < rank; ++i)
3417 {
3418 if (m_Parameters.m_Size[i] != outputShape[i])
3419 {
3420 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3421 }
3422 }
3423
3424 // Check if the sum of begin offset and size in a given dimension
3425 // does not exceed the size of corresponding input
3426 const TensorShape& inputShape = inputTensorInfo.GetShape();
3427 for(unsigned int i = 0u; i < rank; ++i)
3428 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003429 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003430 {
3431 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3432 std::to_string(i) + " exceeds input size.");
3433 }
3434 }
3435}
3436
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003437void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3438{
3439 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3440
3441 ValidateNumInputs(workloadInfo, descriptorName, 1);
3442 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3443
3444 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3445 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3446
3447 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3448 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3449
3450 std::vector<DataType> supportedTypes =
3451 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003452 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003453 DataType::Float32,
3454 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003455 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003456 DataType::QAsymmU8,
3457 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003458 };
3459
3460 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3461 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3462
3463 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3464
3465 if (m_Parameters.m_BlockSize == 0)
3466 {
3467 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3468 }
3469
3470 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3471 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3472 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3473 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3474
3475 const TensorShape& outputShape = outputInfo.GetShape();
3476 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3477 {
3478 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3479 "must be divisible by block size.");
3480 }
3481
3482 const TensorShape& inputShape = inputInfo.GetShape();
3483 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3484 {
3485 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3486 "must be divisible by the square of block size." );
3487 }
3488}
3489
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003490void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3491{
3492 const std::string descriptorName{"ComparisonQueueDescriptor"};
3493
3494 ValidateNumInputs(workloadInfo, descriptorName, 2);
3495 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3496
3497 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3498 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3499 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3500
3501 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3502 inputTensorInfo1,
3503 outputTensorInfo,
3504 descriptorName,
3505 "input_0",
3506 "input_1");
3507
3508 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3509 {
3510 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3511 }
3512}
3513
josh minor4a3c6102020-01-06 16:40:46 -06003514void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3515{
3516 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3517
3518 ValidateNumInputs(workloadInfo, descriptorName, 1);
3519 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3520
3521 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3522 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3523
3524 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3525
3526 std::vector<DataType> supportedTypes =
3527 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003528 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003529 DataType::Float16,
3530 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003531 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003532 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003533 DataType::QSymmS16,
3534 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003535 };
3536
James Conroyaba90cd2020-11-06 16:28:18 +00003537 std::vector<DataType> logicalSupportedTypes =
3538 {
3539 DataType::Boolean
3540 };
3541
3542 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3543 {
3544 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3545 }
3546 else
3547 {
3548 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3549 }
3550
3551
josh minor4a3c6102020-01-06 16:40:46 -06003552 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3553}
3554
Finn Williams2605b232020-06-10 15:53:46 +01003555void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3556{
3557 const std::string descriptorName{"RankQueueDescriptor"};
3558
3559 ValidateNumInputs(workloadInfo, descriptorName, 1);
3560 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3561
3562 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3563 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3564
3565 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3566 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3567
3568 std::vector<DataType> supportedTypes =
3569 {
3570 DataType::BFloat16,
3571 DataType::Float16,
3572 DataType::Float32,
3573 DataType::QAsymmS8,
3574 DataType::QAsymmU8,
3575 DataType::QSymmS8,
3576 DataType::QSymmS16,
3577 DataType::Signed32
3578 };
3579
3580 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3581 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3582}
3583
James Conroyaba90cd2020-11-06 16:28:18 +00003584void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3585{
3586 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3587
3588 ValidateNumInputs(workloadInfo, descriptorName, 2);
3589 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3590
3591 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3592 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3593 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3594
3595 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3596 inputTensorInfo1,
3597 outputTensorInfo,
3598 descriptorName,
3599 "input_0",
3600 "input_1");
3601
3602 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3603 {
3604 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3605 }
3606
3607 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3608 {
3609 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3610 }
3611
3612 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3613 {
3614 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3615 }
3616}
3617
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003618} // namespace armnn