blob: 134495991e421aa86c419a8aa726547dba8c576a [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000011
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000014#include <string>
15#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000016
James Ward47fce872020-09-10 11:57:28 +010017#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000031 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000034 case DataType::QAsymmS8:
35 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000036 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000037 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000038 case DataType::QSymmS8:
39 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000040 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010041 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000042 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000044 return DataType::Float32;
45 }
46}
47
48namespace
49{
50
51//---------------------------------------------------------------
// The Android NDK does not support the std::to_string function, so this file
// uses its own replacement: the value is written to a string stream using the
// stream's default formatting and the accumulated text is returned.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream buffer;
    buffer << value;
    std::string text = buffer.str();
    return text;
}
60
61//---------------------------------------------------------------
62void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63{
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69}
70
71//---------------------------------------------------------------
72void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77{
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83}
84
85//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010086void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000087{
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000089 {
90 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010091 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000092 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94}
95
96//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010097void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000098{
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000100 {
101 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105}
106
107//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100108void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000109 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100110 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000111 std::string const& tensorName)
112{
113 if (tensor.GetNumDimensions() != numDimensions)
114 {
115 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
116 to_string(tensor.GetNumDimensions()) + " dimensions for " +
117 tensorName + " tensor.");
118 }
119}
120
121//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100122void ValidateTensorNumElements(const TensorInfo& tensor,
123 std::string const& descName,
124 unsigned int numElements,
125 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126{
127 if (tensor.GetNumElements() != numElements)
128 {
129 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100130 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100131 tensorName + " tensor.");
132 }
133}
134
135//---------------------------------------------------------------
136void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 unsigned int numDimension,
138 unsigned int numElements,
139 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100141 const std::string functionName{"ValidateTensorNumDimNumElem"};
142 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
143 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100144}
145
146//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000147void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
148 const std::string& descName, std::string const& tensorName)
149{
150 if (tensor.GetDataType() != dataType)
151 {
152 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
153 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
154 }
155}
156
Derek Lambertid466a542020-01-22 15:37:29 +0000157void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
158{
159 ARMNN_NO_DEPRECATE_WARN_BEGIN
160 if (tensor.GetDataType() != DataType::QSymmS8 &&
161 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
162 {
163 throw InvalidArgumentException(descName +
164 ": Expected data type which supports per-axis quantization scheme but got " +
165 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
166 }
167 ARMNN_NO_DEPRECATE_WARN_END
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100171void ValidateTensorQuantizationSpace(const TensorInfo& first,
172 const TensorInfo& second,
173 const std::string& descName,
174 std::string const& firstName,
175 std::string const& secondName)
176{
177 if (!first.IsQuantized() ||
178 !second.IsQuantized())
179 {
180 // Not a quantized type, ignore the validation
181 return;
182 }
183
184 DataType firstDataType = first.GetDataType();
185 DataType secondDataType = second.GetDataType();
186
187 if (firstDataType != secondDataType)
188 {
189 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
190 " must be of the same quantized type, " +
191 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
192 secondName + " is " + GetDataTypeName(secondDataType));
193 }
194
195 if (!first.IsTypeSpaceMatch(second))
196 {
197 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
198 " must have the same quantization space, " +
199 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
200 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
201 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
202 " and scale " + to_string(second.GetQuantizationScale()));
203 }
204}
205
206//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100207void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
208 const TensorInfo& inputTensorInfo,
209 const TensorInfo& weightsTensorInfo,
210 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000211{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000212 // Helper lambda function to validate a single bias quantization scale value
213 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
214 {
ricbur013f4d7102019-10-31 16:22:18 +0000215 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000216 if (std::abs(biasScale - expectedScale) > tolerance)
217 {
218 // Print the float values with extra precision to see very small differences
219 std::stringstream msg;
220 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
221 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
222 biasScale;
223 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
224 }
225 };
226
telsoa014fcda012018-03-09 14:13:49 +0000227 if (biasTensor.GetQuantizationOffset() != 0)
228 {
229 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
230 to_string(biasTensor.GetQuantizationOffset()));
231 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232
233 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000234 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000235 // Validate per-axis quantization scales
236 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
237 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
238
239 if (weightScales.size() != biasScales.size())
240 {
241 std::stringstream msg;
242 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
243 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100309 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100369 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
James Ward47fce872020-09-10 11:57:28 +0100390 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
391 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000392 }
393
394 if (quantizationDim.value() != 0)
395 {
James Ward47fce872020-09-10 11:57:28 +0100396 throw InvalidArgumentException(fmt::format(
397 "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, "
398 "but got: {2}", descName, tensorName, quantizationDim.value()));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000399 }
400}
401
402void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
403 const std::string& descName,
404 const std::string& tensorName)
405{
406 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
407 if (quantizationOffset != 0)
408 {
James Ward47fce872020-09-10 11:57:28 +0100409 throw InvalidArgumentException(fmt::format(
410 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
411 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000412 }
413}
414
415void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
416 const TensorInfo& outputInfo,
417 const TensorInfo& weightInfo,
418 const Optional<TensorInfo>& optionalBiasInfo,
419 const std::string& descName)
420{
421 if (weightInfo.HasPerAxisQuantization())
422 {
423 const DataType inputDataType = inputInfo.GetDataType();
424 const DataType outputDataType = outputInfo.GetDataType();
425
Keith Davis0c2eeac2020-02-11 16:51:50 +0000426 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000427
428 if (!canHavePerAxisQuantization)
429 {
James Ward47fce872020-09-10 11:57:28 +0100430 throw InvalidArgumentException(fmt::format(
431 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
432 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000433 }
434
Derek Lambertid466a542020-01-22 15:37:29 +0000435
436 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000437 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
438 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
439
440 if (optionalBiasInfo.has_value())
441 {
442 const TensorInfo& biasInfo = optionalBiasInfo.value();
443 if (!biasInfo.HasPerAxisQuantization())
444 {
James Ward47fce872020-09-10 11:57:28 +0100445 throw InvalidArgumentException(fmt::format(
446 "{}: Per-axis quantization parameters not set on bias tensor, "
447 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000448 }
449
450 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
451 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
452 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
453 }
454 }
455}
456
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100457} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000458
459void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
460 unsigned int numExpectedIn, unsigned int numExpectedOut) const
461{
462 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
463 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
464}
465
466//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100467void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
468{
469 const std::string descriptorName{"MapQueueDescriptor"};
470
471 ValidateNumInputs(workloadInfo, descriptorName, 1);
472 ValidateNumOutputs(workloadInfo, descriptorName , 0);
473
474 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
475 {
476 if (!m_Inputs[i])
477 {
478 throw InvalidArgumentException(
479 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
480 }
481 }
482}
483
484//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000485void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
486{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100487 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000488
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100489 ValidateNumInputs(workloadInfo, descriptorName, 1);
490 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000491
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
493 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
494
495 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
496 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000497
498 if (m_Inputs.size() != m_Outputs.size())
499 {
James Ward47fce872020-09-10 11:57:28 +0100500 throw InvalidArgumentException(fmt::format(
501 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
502 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000503 }
504
505 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
506 {
507 if (!m_Inputs[i])
508 {
James Ward47fce872020-09-10 11:57:28 +0100509 throw InvalidArgumentException(fmt::format(
510 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000511 }
512
513 if (!m_Outputs[i])
514 {
James Ward47fce872020-09-10 11:57:28 +0100515 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000516 }
517 }
518}
519
Derek Lambertif674aa02019-08-01 15:56:25 +0100520//---------------------------------------------------------------
521void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
522{
523 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
524 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
525
526 if (workloadInfo.m_InputTensorInfos.size() != 1)
527 {
James Ward47fce872020-09-10 11:57:28 +0100528 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
529 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100530
531 }
532
533 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
534 {
James Ward47fce872020-09-10 11:57:28 +0100535 throw InvalidArgumentException(fmt::format(
536 "Number of input infos ({0}) does not match the number of output infos ({1})",
537 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100538 }
539
540 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
541 {
542 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
543 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
544 {
James Ward47fce872020-09-10 11:57:28 +0100545 throw InvalidArgumentException(fmt::format(
546 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100547 }
548 }
549
550 if (m_Inputs.size() != 1)
551 {
James Ward47fce872020-09-10 11:57:28 +0100552 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100553 }
554
555 if (m_Inputs.size() != m_Outputs.size())
556 {
James Ward47fce872020-09-10 11:57:28 +0100557 throw InvalidArgumentException(fmt::format(
558 "Number of inputs ({0}) does not match the number of outputs ({1})",
559 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100560 }
561
562 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
563 {
564 if (!m_Inputs[i])
565 {
James Ward47fce872020-09-10 11:57:28 +0100566 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100567 }
568
569 if (!m_Outputs[i])
570 {
James Ward47fce872020-09-10 11:57:28 +0100571 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100572 }
573 }
574}
575
576//---------------------------------------------------------------
577void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
578{
579 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
580 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
581
Derek Lambertif674aa02019-08-01 15:56:25 +0100582 if (m_Inputs.size() != 1)
583 {
James Ward47fce872020-09-10 11:57:28 +0100584 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100585 }
586
587 if (m_Outputs.size() != 0)
588 {
James Ward47fce872020-09-10 11:57:28 +0100589 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100590 }
591
592 if (!m_Inputs[0])
593 {
James Ward47fce872020-09-10 11:57:28 +0100594 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100595 }
596}
597
598//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000599void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
600{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100601 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100603 ValidateNumInputs(workloadInfo, descriptorName, 1);
604 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100605
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
607 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100608
609 std::vector<DataType> supportedTypes =
610 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000611 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100612 DataType::Float16,
613 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000614 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000615 DataType::QAsymmU8,
616 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100617 };
618
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100619 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
620 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
621 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000622}
623
Nikhil Rajee391d52019-09-05 17:50:44 +0100624void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
625{
626 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
627
628 ValidateNumInputs(workloadInfo, descriptorName, 1);
629 ValidateNumOutputs(workloadInfo, descriptorName, 1);
630
631 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
632 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
633
Inki Daed4619e22020-09-10 15:33:54 +0900634 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
635 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100636 {
Inki Daed4619e22020-09-10 15:33:54 +0900637 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100638 }
639
James Conroyd47a0642019-09-17 14:22:06 +0100640 std::vector<DataType> supportedInputTypes =
641 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000642 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100643 DataType::Float16,
644 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100645 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000646 DataType::QAsymmU8,
647 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900648 DataType::Signed32,
649 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100650 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100651
James Conroyd47a0642019-09-17 14:22:06 +0100652 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100653
654 auto inputShape = inputTensorInfo.GetShape();
655 auto outputShape = outputTensorInfo.GetShape();
656
657 auto inputNumDimensions = inputShape.GetNumDimensions();
658 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
659
660 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
661
662 // 1D input shape results in scalar output shape
663 if (inputShape.GetNumDimensions() == 1)
664 {
665 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
666 {
667 throw InvalidArgumentException(descriptorName + outputShapeError);
668 }
669 }
670 else
671 {
672 for (unsigned int i = 0; i < unsignedAxis; ++i)
673 {
674 if (outputShape[i] != inputShape[i])
675 {
676 throw InvalidArgumentException(descriptorName + outputShapeError);
677 }
678 }
679
680 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
681 {
682 if (outputShape[i - 1] != inputShape[i])
683 {
684 throw InvalidArgumentException(descriptorName + outputShapeError);
685 }
686 }
687 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100688}
689
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100690void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
691{
692 const std::string descriptorName{"SoftmaxQueueDescriptor"};
693
694 ValidateNumInputs(workloadInfo, descriptorName, 1);
695 ValidateNumOutputs(workloadInfo, descriptorName, 1);
696
697 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
698 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
699
700 std::vector<DataType> supportedTypes =
701 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000702 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100703 DataType::Float16,
704 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000705 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000706 DataType::QAsymmU8,
707 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100708 };
709
710 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
711 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
712 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
713}
714
telsoa014fcda012018-03-09 14:13:49 +0000715void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
716{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100717 const std::string descriptorName{"SplitterQueueDescriptor"};
718
719 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000720
Ruomei Yan25339c32019-05-28 16:48:20 +0100721 // Check the supported data types
722 std::vector<DataType> supportedTypes =
723 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000724 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100725 DataType::Float32,
726 DataType::Float16,
727 DataType::Boolean,
728 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100729 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000730 DataType::QAsymmU8,
731 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100732 };
733
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100734 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
735 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100736 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100737 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
738 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
739
740 const std::string outputName = "output_" + std::to_string(i);
741 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100742 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100743
telsoa014fcda012018-03-09 14:13:49 +0000744 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
745 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100746 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000747 }
748
749 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
750 {
751 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100752 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000753 "has to match number of workloadInfo.m_OutputTensorInfos. "
754 "Number of windows: " +
755 to_string(m_ViewOrigins.size()) +
756 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
757 }
758
telsoa01c577f2c2018-08-31 09:22:23 +0100759 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000760 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
761 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
762 {
telsoa01c577f2c2018-08-31 09:22:23 +0100763 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000764 ViewOrigin const& e = m_ViewOrigins[w];
765 if (e.m_Origin.size() != inputDims)
766 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100767 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000768 "have the same dimensionality as the input tensor. "
769 "Window origin (index: " +
770 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
771 " dimensions, the input "
772 "tensor has " +
773 to_string(inputDims) + " dimensions.");
774 }
775 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
776 {
777 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
778 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
779 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100780 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000781 "be smaller or equal than the size of the input in that coord.");
782 }
783 }
784 }
785}
786
Jim Flynne242f2d2019-05-22 14:24:13 +0100787void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000788{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100789 const std::string descriptorName{"ConcatQueueDescriptor"};
790
791 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000792
793 if (m_Inputs.size() <= 0)
794 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100795 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000796 }
797 if (m_Outputs.size() <= 0)
798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100799 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000800 }
801
802 if (workloadInfo.m_InputTensorInfos.size() <= 0)
803 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100804 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000805 }
806 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
807 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100808 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000809 }
810
Nikhil Raj8599a412018-11-19 14:51:07 +0000811 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000814 }
815
816 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
817 {
818 return;
819 }
820
telsoa014fcda012018-03-09 14:13:49 +0000821 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
822 {
823 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100824 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000825 "has to match number of workloadInfo.m_InputTensorInfos. "
826 "Number of windows: " +
827 to_string(m_ViewOrigins.size()) +
828 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
829 }
830
telsoa01c577f2c2018-08-31 09:22:23 +0100831 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000832 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
833 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
834 {
telsoa01c577f2c2018-08-31 09:22:23 +0100835 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000836 ViewOrigin const& e = m_ViewOrigins[w];
837 if (e.m_Origin.size() != outputDims)
838 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100839 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000840 "have the same dimensionality as the output tensor. "
841 "Window origin (index: " +
842 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
843 " dimensions, the output "
844 "tensor has " +
845 to_string(outputDims) + " dimensions.");
846 }
telsoa01c577f2c2018-08-31 09:22:23 +0100847 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000848 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
849 {
850 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
851 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
852 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100853 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000854 "be smaller or equal than the size of the output in that coord.");
855 }
856 }
857 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100858
859 // Check the supported data types
860 std::vector<DataType> supportedTypes =
861 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000862 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100863 DataType::Float32,
864 DataType::Float16,
865 DataType::Boolean,
866 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100867 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000868 DataType::QAsymmU8,
869 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100870 };
871
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
873 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100874 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100875 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
876 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
877
878 const std::string inputName = "input_" + std::to_string(i);
879 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100880 }
telsoa014fcda012018-03-09 14:13:49 +0000881}
882
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100883void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
884{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100885 const std::string descriptorName{"StackQueueDescriptor"};
886
887 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100888
889 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
890 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100891 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100892 }
893
894 // All inputs must have the same shape, which is defined in parameters
895 const TensorShape& inputShape = m_Parameters.m_InputShape;
896 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
897 {
898 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
899 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100900 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100901 }
902 }
903
Matthew Jacksondba634f2019-08-15 15:14:18 +0100904 if (inputShape.GetNumDimensions() > 4)
905 {
906 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
907 }
908
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100909 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
910 // since the output tensor has an additional dimension.
911 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100914 "than the number of input dimensions.");
915 }
916
917 // Output shape must be as inferred from the input shape
918 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
919 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
920 {
921 if (outputShape[i] != inputShape[i])
922 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100923 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100924 "match shape inferred from input tensor.");
925 }
926 }
927
928 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
929 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100931 "match shape inferred from input tensor.");
932 }
933
934 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
935 {
936 if (outputShape[i] != inputShape[i-1])
937 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100938 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100939 "match shape inferred from input tensor.");
940 }
941 }
942
Matthew Jacksondba634f2019-08-15 15:14:18 +0100943 if (outputShape.GetNumDimensions() > 5)
944 {
945 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
946 }
947
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100948 // Check the supported data types
949 std::vector<DataType> supportedTypes =
950 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000951 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100952 DataType::Float32,
953 DataType::Float16,
954 DataType::Boolean,
955 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100956 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000957 DataType::QAsymmU8,
958 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100959 };
960
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100962
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100963 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100964 {
965 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
966 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100967 descriptorName,
968 "input_0",
969 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100970 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100971
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100972 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
973 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100974 descriptorName,
975 "input_0",
976 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100977}
978
Ryan OSheaec6c6802020-06-05 17:17:06 +0100979void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
980{
981 const std::string descriptorName{"FillQueueDescriptor"};
982
983 ValidateNumInputs(workloadInfo, descriptorName, 1);
984 ValidateNumOutputs(workloadInfo, descriptorName, 1);
985
986 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
987 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
988
989 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
990
991 std::vector<DataType> supportedTypes =
992 {
993 DataType::BFloat16,
994 DataType::Float32,
995 DataType::Float16,
996 DataType::Signed32
997 };
998
999 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1000}
1001
telsoa014fcda012018-03-09 14:13:49 +00001002void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1003{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001004 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001005
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001006 ValidateNumInputs(workloadInfo, descriptorName, 1);
1007 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1008
1009 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1010 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1011
1012 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1013
1014 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001015 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001016 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001017 }
1018
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001019 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001020
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001021 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1022 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001023
1024 if (m_Parameters.m_BiasEnabled)
1025 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001026 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001027
telsoa01c577f2c2018-08-31 09:22:23 +01001028 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001029 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1030 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001031
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001032 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1033 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001034 }
1035
Francis Murtagh46c09d02019-05-28 08:15:28 +01001036 // Check the supported data types
1037 std::vector<DataType> supportedTypes =
1038 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001039 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001040 DataType::Float32,
1041 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001042 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001043 DataType::QAsymmU8,
1044 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001045 };
1046
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001047 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001048
1049 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1050 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1051 {
1052 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1053 {
1054 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1055 "for BFloat16 input.");
1056 }
1057 }
1058 else
1059 {
1060 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1061 }
telsoa014fcda012018-03-09 14:13:49 +00001062}
1063
telsoa014fcda012018-03-09 14:13:49 +00001064void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1065{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 const std::string descriptorName{"NormalizationQueueDescriptor"};
1067
1068 ValidateNumInputs(workloadInfo, descriptorName, 1);
1069 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1070
1071 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1072 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001073
1074 // Check the supported data types
1075 std::vector<DataType> supportedTypes =
1076 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001077 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001078 DataType::Float16,
1079 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001080 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001081 DataType::QAsymmU8,
1082 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001083 };
1084
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001085 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001086
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001087 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001088
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001089 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001090}
1091
1092void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1093{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001094 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001095
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001096 ValidateNumInputs(workloadInfo, descriptorName, 2);
1097 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1098
1099 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1100 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1101 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1102
1103 std::vector<DataType> supportedTypes =
1104 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001105 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001106 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001107 DataType::Float16,
1108 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001109 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001110 DataType::QSymmS16,
1111 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001112 };
1113
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001114 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1115 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1116 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1119 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001120
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001121 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1122 inputTensorInfo1,
1123 outputTensorInfo,
1124 descriptorName,
1125 "input_0",
1126 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001127}
1128
telsoa014fcda012018-03-09 14:13:49 +00001129void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1130{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001131 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001132
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001133 ValidateNumInputs(workloadInfo, descriptorName, 2);
1134 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1135
1136 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1137 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1138 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1139
1140 std::vector<DataType> supportedTypes =
1141 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001142 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001143 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001144 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001145 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001146 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001147 DataType::QSymmS16,
1148 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001149 };
1150
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001151 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1152 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1153 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001154
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1156 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001157
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001158 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1159 inputTensorInfo1,
1160 outputTensorInfo,
1161 descriptorName,
1162 "input_0",
1163 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001164}
1165
1166void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1167{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001168 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001169
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001170 ValidateNumInputs(workloadInfo, descriptorName, 1);
1171 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1172
1173 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1174 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001175
1176 std::vector<DataType> supportedTypes =
1177 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001178 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001179 DataType::Float16,
1180 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001181 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001182 DataType::QAsymmU8,
1183 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001184 };
1185
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001186 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1187 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001188
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001189 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001190 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 ValidatePointer(m_Mean, descriptorName, "mean");
1193 ValidatePointer(m_Variance, descriptorName, "variance");
1194 ValidatePointer(m_Beta, descriptorName, "beta");
1195 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001196
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001197 const TensorInfo& mean = m_Mean->GetTensorInfo();
1198 const TensorInfo& variance = m_Variance->GetTensorInfo();
1199 const TensorInfo& beta = m_Beta->GetTensorInfo();
1200 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001201
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001202 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1203 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1204 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1205 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001206
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001207 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1208 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1209 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001210}
1211
1212void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1213{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001214 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001215
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001216 ValidateNumInputs(workloadInfo, descriptorName, 1);
1217 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1220 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1223 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001226
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001227 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1228 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001229
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001230 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001231
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001232 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001233 if (m_Parameters.m_BiasEnabled)
1234 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001235 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001236
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001237 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1238 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001239
1240 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1241 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001242 }
1243
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001244 ValidatePerAxisQuantization(inputTensorInfo,
1245 outputTensorInfo,
1246 weightTensorInfo,
1247 optionalBiasTensorInfo,
1248 descriptorName);
1249
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 std::vector<DataType> supportedTypes =
1251 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001252 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001253 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001254 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001255 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001256 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001257 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001258 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001259 };
1260
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001261 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001262
1263 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1264 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1265 {
1266 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1267 {
1268 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1269 "for BFloat16 input.");
1270 }
1271 }
1272 else
1273 {
1274 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1275 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001276}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001277
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001278void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1279{
1280 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1281
1282 ValidateNumInputs(workloadInfo, descriptorName, 1);
1283 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1284
1285 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1286 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1287
1288 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1289 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1290
1291 ValidatePointer(m_Weight, descriptorName, "weight");
1292
1293 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1294 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1295
1296 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1297 {
1298 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001299 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1300 "cannot be smaller than 1.",
1301 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 }
1303
1304 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1305
1306 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1307 // inputChannels * channelMultiplier should be equal to outputChannels.
1308 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1309 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1310 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1311 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1312 {
James Ward47fce872020-09-10 11:57:28 +01001313 throw InvalidArgumentException(fmt::format(
1314 "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) "
1315 "multiplied by channel_multiplier (provided {3}).",
1316 descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001317 }
1318
Teresa Charlind8df0262019-11-11 12:28:15 +00001319 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001320
Teresa Charlind8df0262019-11-11 12:28:15 +00001321 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001322 if (m_Parameters.m_BiasEnabled)
1323 {
1324 ValidatePointer(m_Bias, descriptorName, "bias");
1325
Teresa Charlind8df0262019-11-11 12:28:15 +00001326 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1327 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001328
1329 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1330 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1331 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001332 ValidatePerAxisQuantization(inputTensorInfo,
1333 outputTensorInfo,
1334 weightTensorInfo,
1335 optionalBiasTensorInfo,
1336 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001337
1338 std::vector<DataType> supportedTypes =
1339 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001340 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001341 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001342 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001343 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001344 DataType::QAsymmU8,
1345 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001346 };
1347
1348 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1349 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001350}
1351
1352void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1353{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001354 const std::string descriptorName{"PermuteQueueDescriptor"};
1355
1356 ValidateNumInputs(workloadInfo, descriptorName, 1);
1357 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001358
1359 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1360
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001361 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1362 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001363
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001364 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1365 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001366
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001367 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001368 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001369 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001370 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001371 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1372 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1373 "must match dst dimension " + to_string(mapping[i]) +
1374 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001375 }
1376 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001377
1378 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001379}
1380
1381void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1382{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001383 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001384
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001385 ValidateNumInputs(workloadInfo, descriptorName, 1);
1386 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1387
1388 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1389 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1390
1391 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1392 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001393
1394 std::vector<DataType> supportedTypes =
1395 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001396 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001397 DataType::Float32,
1398 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001399 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001400 DataType::QAsymmU8,
1401 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001402 };
1403
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001404 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1405 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001406}
1407
1408void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1409{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001410 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001411
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001412 ValidateNumInputs(workloadInfo, descriptorName, 1);
1413 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1414
1415 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1416 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1417
1418 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1419 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001420
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001421 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001422 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001423 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001424 DataType::Float16,
1425 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001426 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001427 DataType::QAsymmU8,
1428 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001429 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001430
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001431 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1432 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001433
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001434 // ResizeBilinear only changes width and height: batch and channel count must match.
1435 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1436 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001437 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001438 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001439 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001440 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1441 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001442 }
1443
Teresa Charlin970f43b2019-07-01 13:51:07 +01001444 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001445 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1446 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001447 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001448 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001449 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001450 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1451 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001452 }
1453}
1454
1455void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1456{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001459 ValidateNumInputs(workloadInfo, descriptorName, 1);
1460 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1461
1462 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1463 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1464
1465 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1466 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001467
1468 std::vector<DataType> supportedTypes =
1469 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001470 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001471 DataType::Float16,
1472 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001473 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001474 DataType::QAsymmU8,
1475 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001476 };
1477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001478 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1479 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001480
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001481 // Resize only changes width and height: batch and channel count must match.
1482 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1483 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001484 if (inputBatchSize != outputBatchSize)
1485 {
1486 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001487 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1488 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001489 }
1490
1491 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001492 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1493 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001494 if (inputChannelCount != outputChannelCount)
1495 {
1496 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001497 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1498 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001499 }
1500}
1501
1502void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1503{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001505
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 ValidateNumInputs(workloadInfo, descriptorName, 1);
1507 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1508
1509 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1510 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1511
1512 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1513 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1514
1515 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1516
telsoa014fcda012018-03-09 14:13:49 +00001517 if (m_Parameters.m_Min > m_Parameters.m_Max)
1518 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001519 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001520 }
telsoa014fcda012018-03-09 14:13:49 +00001521}
1522
Kevin Mayce5045a2019-10-02 14:07:47 +01001523void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1524{
1525 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1526
1527 ValidateNumInputs(workloadInfo, descriptorName, 1);
1528 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1529
1530 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1531 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1532
1533 if (inputTensorInfo.GetNumDimensions() > 4)
1534 {
1535 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1536 }
1537
1538 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1539
1540 // Check the supported data types
1541 std::vector<DataType> supportedTypes =
1542 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001543 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001544 DataType::Float32,
1545 DataType::Float16
1546 };
1547
1548 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001549 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001550}
1551
telsoa014fcda012018-03-09 14:13:49 +00001552void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1553{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001554 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001555
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001556 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001557 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1558
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001559 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1560 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1561
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001562 if (inputTensorInfo.GetNumDimensions() > 4)
1563 {
1564 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1565 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001566
1567 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001568
1569 // Check the supported data types
1570 std::vector<DataType> supportedTypes =
1571 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001572 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001573 DataType::Float32,
1574 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001575 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001576 DataType::QAsymmU8,
1577 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001578 };
1579
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001580 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001581 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1582}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001583
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001584void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1585{
1586 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1587
1588 ValidateNumInputs(workloadInfo, descriptorName, 1);
1589 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1590
1591 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1592 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1593
1594 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1595
1596 std::vector<DataType> supportedTypes =
1597 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001598 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001599 DataType::Float32,
1600 DataType::Float16,
1601 };
1602
1603 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001604 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001605}
1606
1607void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1608{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001609 const std::string descriptorName{"ConstantQueueDescriptor"};
1610
1611 ValidateNumInputs(workloadInfo, descriptorName, 0);
1612 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001613
1614 if (!m_LayerOutput)
1615 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001616 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001617 }
1618
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001619 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1620 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001621
1622 // Check the supported data types
1623 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001624 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001625 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001626 DataType::Float32,
1627 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001628 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001629 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001630 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001631 DataType::QSymmS16,
1632 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001633 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001634
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001635 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001636}
1637
1638void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1639{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001640 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001641
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001642 ValidateNumInputs(workloadInfo, descriptorName, 1);
1643 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1644
1645 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1646 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1647
1648 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001649
1650 // Check the supported data types
1651 std::vector<DataType> supportedTypes =
1652 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001653 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001654 DataType::Float32,
1655 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001656 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001657 DataType::QAsymmU8,
1658 DataType::QSymmS16,
1659 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001660 };
1661
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001662 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1663 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001664}
1665
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001666void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1667{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001668 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001669
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001670 ValidateNumInputs(workloadInfo, descriptorName, 1);
1671 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1672
1673 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1674 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1675
1676 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1677 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001678
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001679 if (m_Parameters.m_BlockShape.size() != 2)
1680 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001681 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001682 }
1683
1684 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1685 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001686 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1687 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001688 }
1689
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001690 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001691
1692 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001693 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001694
Matthew Bentham8800c002018-11-19 13:19:28 +00001695 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001696
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001697 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1698 widthPad.first + widthPad.second;
1699 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1700 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001701
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001702 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1703 inputShape[dimensionIndices.GetChannelsIndex()];
1704 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001705
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001706 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001707 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001708 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001709 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001710 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001711 }
1712
1713 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001714 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001715 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1716 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001717 }
nikraj01120522a2019-05-31 11:33:07 +01001718
1719 std::vector<DataType> supportedTypes =
1720 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001721 DataType::BFloat16,
1722 DataType::Float16,
1723 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001724 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001725 DataType::QAsymmU8,
1726 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001727 };
1728
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001729 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1730 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001731}
1732
Keith Davisa57eccb2019-06-14 17:33:22 +01001733void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1734{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001735 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001736
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001737 ValidateNumInputs(workloadInfo, descriptorName, 1);
1738 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001739
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001740 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1741 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1742
1743 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1744 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001745
1746 std::vector<DataType> supportedTypes =
1747 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001748 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001749 DataType::Float32,
1750 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001751 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001752 DataType::QAsymmU8,
1753 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001754 };
1755
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001756 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1757 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001758
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001759 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1760
1761 if (m_Parameters.m_BlockSize == 0)
1762 {
1763 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1764 }
1765
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001766 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1767 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1768 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1769 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001770
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001771 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001772 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001773 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1775 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001776 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001777
1778 const TensorShape& outputShape = outputTensorInfo.GetShape();
1779 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1780 {
1781 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1782 "must be divisible by the square of block size." );
1783 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001784}
1785
telsoa014fcda012018-03-09 14:13:49 +00001786void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1787{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001788 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001790 ValidateNumInputs(workloadInfo, descriptorName, 1);
1791 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1792
1793 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1794 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001795
1796 std::vector<DataType> supportedTypes =
1797 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001798 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001799 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001800 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001801 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001802 };
1803
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001804 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001805
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001806 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001807 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001808 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001809 }
1810}
1811
telsoa01c577f2c2018-08-31 09:22:23 +01001812void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1813{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001814 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1815
1816 const std::string descriptorName{"LstmQueueDescriptor"};
1817
1818 // check dimensions of all inputs and outputs
1819 if (workloadInfo.m_InputTensorInfos.size() != 3)
1820 {
1821 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1822 }
1823 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1824 {
1825 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1826 }
1827
1828 std::vector<DataType> supportedTypes =
1829 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001830 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001831 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001832 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001833 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001834 };
1835
Jan Eilers38e05bd2019-06-26 13:10:09 +01001836 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001837 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1838
Jan Eilers38e05bd2019-06-26 13:10:09 +01001839 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001840 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001841 {
1842 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1843 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 descriptorName,
1845 "input_0",
1846 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001847 }
1848 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001849 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001850 {
1851 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1852 workloadInfo.m_OutputTensorInfos[i],
1853 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001854 "input_0",
1855 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001856 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001857
janeil0117d8d852019-11-15 15:00:16 +00001858 // Making sure clipping parameters have valid values.
1859 // == 0 means no clipping
1860 // > 0 means clipping
1861 if (m_Parameters.m_ClippingThresCell < 0.0f)
1862 {
1863 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1864 }
1865 if (m_Parameters.m_ClippingThresProj < 0.0f)
1866 {
1867 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1868 }
1869
Jan Eilers38e05bd2019-06-26 13:10:09 +01001870
1871 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001872 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1873 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1874 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1875 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1876 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1877 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1878
Jan Eilers38e05bd2019-06-26 13:10:09 +01001879 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001880 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1881 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001882 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001883 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1884 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001885 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001886 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1887 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001888 // scratchBufferTensor
1889 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001890 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1891 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001892 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1894 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001895 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001896 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1897 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001898 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001899 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1900 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001901
1902
1903 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1904 if ( m_InputToInputWeights )
1905 {
1906 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1907 (n_cell * n_input), "InputLayerNormWeights");
1908 }
1909
1910 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1911 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1912 (n_cell * n_input), "InputToForgetWeights");
1913
1914 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1915 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1916 (n_cell * n_input), "InputToCellWeights");
1917
1918 if ( m_RecurrentToInputWeights )
1919 {
1920 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1921 (n_cell * n_output), "RecurrentToInputWeights");
1922 }
1923
1924 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1925 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1926 (n_cell * n_output), "RecurrentToForgetWeights");
1927
1928 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1929 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1930 (n_cell * n_output), "RecurrentToCellWeights");
1931
1932 // Make sure the input-gate's parameters are either both present (regular
1933 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1934 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1935 !m_Parameters.m_CifgEnabled) ||
1936 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1937 m_Parameters.m_CifgEnabled));
1938 if (!cifg_weights_all_or_none)
1939 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001940 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1941 "RecurrentToInputWeights must either both be present (regular LSTM) "
1942 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1943 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001944 }
1945
1946 if ( m_CellToInputWeights )
1947 {
1948 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1949 n_cell, "CellToInputWeights");
1950 }
1951 if ( m_CellToForgetWeights )
1952 {
1953 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1954 n_cell, "CellToForgetWeights");
1955 }
1956 if ( m_CellToOutputWeights )
1957 {
1958 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1959 n_cell, "CellToOutputWeights");
1960 }
1961
1962 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1963 bool peephole_weights_all_or_none =
1964 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1965 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1966 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1967 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1968 if (!peephole_weights_all_or_none)
1969 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001970 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001971 }
1972
1973 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1974 if (m_Parameters.m_CifgEnabled)
1975 {
1976 if (m_InputGateBias)
1977 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001978 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001979 }
1980 }
1981 else
1982 {
1983 if (!m_InputGateBias)
1984 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001985 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1986 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001987 }
1988 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1989 n_cell, "InputGateBias");
1990 }
1991
1992 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1993 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1994
1995 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1996 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1997
1998 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1999 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2000
2001 if (m_ProjectionWeights)
2002 {
2003 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2004 (n_cell * n_output), "ProjectionWeights");
2005 }
2006 if (m_ProjectionBias)
2007 {
2008 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2009 }
2010
2011 // Making sure the projection tensors are consistent:
2012 // 1) If projection weight is not present, then projection bias should not be
2013 // present.
2014 // 2) If projection weight is present, then projection bias is optional.
2015 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2016 !m_Parameters.m_ProjectionEnabled)
2017 || (m_ProjectionWeights && !m_ProjectionBias &&
2018 m_Parameters.m_ProjectionEnabled)
2019 || (m_ProjectionWeights && m_ProjectionBias &&
2020 m_Parameters.m_ProjectionEnabled));
2021 if (!projecton_tensors_consistent)
2022 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002023 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002024 }
2025
2026 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2027 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2028 // either all have values or none of them have values. Layer normalization is used when the values of all the
2029 // layer normalization weights are present
2030 if (m_InputLayerNormWeights)
2031 {
2032 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2033 }
2034 if (m_ForgetLayerNormWeights)
2035 {
2036 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2037 }
2038 if (m_CellLayerNormWeights)
2039 {
2040 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2041 }
2042 if (m_OutputLayerNormWeights)
2043 {
2044 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2045 }
2046
Jan Eilers38e05bd2019-06-26 13:10:09 +01002047 if (m_Parameters.m_LayerNormEnabled)
2048 {
2049 if (!m_Parameters.m_CifgEnabled)
2050 {
2051 if (!m_InputLayerNormWeights)
2052 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002053 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2054 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002055 }
2056 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2057 1, n_cell, "InputLayerNormWeights");
2058 }
2059 else if (m_InputLayerNormWeights)
2060 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002061 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2062 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002063 }
2064
2065 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2066 "ForgetLayerNormWeights");
2067 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2068
2069 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2070 "OutputLayerNormWeights");
2071 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2072
2073 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2074 "CellLayerNormWeights");
2075 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2076 }
2077 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2078 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002079 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2080 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002081 }
telsoa01c577f2c2018-08-31 09:22:23 +01002082}
2083
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002084void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2085{
2086 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2087
2088 ValidateNumInputs(workloadInfo, descriptorName, 1);
2089 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2090
2091 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2092 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2093
2094 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2095 {
2096 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2097 }
2098
2099 if (outputTensorInfo.GetDataType() != DataType::Float32)
2100 {
2101 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2102 }
2103
2104 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2105}
2106
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002107void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2108{
2109 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2110
2111 ValidateNumInputs(workloadInfo, descriptorName, 1);
2112 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2113
2114 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2115 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2116
2117 if (inputTensorInfo.GetDataType() != DataType::Float32)
2118 {
2119 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2120 }
2121
2122 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2123 {
2124 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2125 }
2126
2127 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2128}
2129
telsoa01c577f2c2018-08-31 09:22:23 +01002130void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2131{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002132 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002133
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002134 ValidateNumInputs(workloadInfo, descriptorName, 1);
2135 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2136
2137 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2138 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2139
2140 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002141 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002142 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002143 }
2144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002145 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002146 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002147 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002148 }
2149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002150 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002151}
2152
2153void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2154{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002155 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002157 ValidateNumInputs(workloadInfo, descriptorName, 1);
2158 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2159
2160 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2161 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2162
2163 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002164 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002165 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002166 }
2167
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002168 if (outputTensorInfo.GetDataType() != DataType::Float32)
2169 {
2170 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2171 }
2172
2173 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002174}
2175
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002176void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2177{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002178 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002179
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002180 ValidateNumInputs(workloadInfo, descriptorName, 2);
2181 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2182
2183 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2184 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2185 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2186
2187 std::vector<DataType> supportedTypes =
2188 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002189 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002190 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002191 DataType::Float32,
2192 DataType::QAsymmS8,
2193 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002194 DataType::QSymmS16,
2195 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002196 };
2197
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002198 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2199 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2200 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002201
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002202 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2203 inputTensorInfo1,
2204 outputTensorInfo,
2205 descriptorName,
2206 "input_0",
2207 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002208}
2209
David Beckc2044fe2018-09-05 15:00:38 +01002210void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2211{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002212 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002213
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002214 ValidateNumInputs(workloadInfo, descriptorName, 2);
2215 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2216
2217 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2218 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2219 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2220
2221 std::vector<DataType> supportedTypes =
2222 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002223 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002224 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002225 DataType::Float32,
2226 DataType::QAsymmS8,
2227 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002228 DataType::QSymmS16,
2229 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002230 };
2231
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002232 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2233 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2234 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002235
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2237 inputTensorInfo1,
2238 outputTensorInfo,
2239 descriptorName,
2240 "input_0",
2241 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002242}
2243
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002244void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2245{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002246 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002247
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002248 ValidateNumInputs(workloadInfo, descriptorName, 2);
2249 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2250
2251 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2252 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2253 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2254
2255 std::vector<DataType> supportedTypes =
2256 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002257 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002258 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002259 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002260 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002261 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002262 DataType::QSymmS16,
2263 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002264 };
2265
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002266 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2267 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2268 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002269
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002270 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2271 inputTensorInfo1,
2272 outputTensorInfo,
2273 descriptorName,
2274 "input_0",
2275 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002276}
2277
narpra01a6bf9122018-09-10 09:50:09 +01002278void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2279{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002280 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002281
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002282 ValidateNumInputs(workloadInfo, descriptorName, 1);
2283 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2284
2285 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2286 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002287
2288 std::vector<DataType> supportedTypes =
2289 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002290 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002291 DataType::Float32,
2292 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002293 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002294 DataType::QAsymmU8,
2295 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002296 };
narpra01eb061912018-09-10 17:35:27 +01002297
James Conroy4d1ff582019-06-10 17:06:39 +01002298 // First check if input tensor data type is supported, then
2299 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2301 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002302
narpra0132b90462018-09-13 11:07:48 +01002303 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002304 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002305 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002306 }
narpra0132b90462018-09-13 11:07:48 +01002307 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002308 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002309 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002310 }
2311 else
2312 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002313 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002314 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 ValidateTensorNumDimensions(outputTensorInfo,
2316 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002317 outputDim > 0 ? outputDim : 1,
2318 "output");
2319 }
narpra01a6bf9122018-09-10 09:50:09 +01002320}
2321
jimfly012c9322a2018-09-19 10:59:49 +01002322void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2323{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002324 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002326 ValidateNumInputs(workloadInfo, descriptorName, 1);
2327 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2328
2329 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2330 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002331
jimfly012c9322a2018-09-19 10:59:49 +01002332 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2334
jimfly012c9322a2018-09-19 10:59:49 +01002335 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002336 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2337 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2338 "as there are dimensions in the input tensor that is " +
2339 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2340 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002341 }
2342}
2343
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002344void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2345{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002347
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002348 ValidateNumInputs(workloadInfo, descriptorName, 1);
2349 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002350
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002351 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2352 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2353
Sadik Armagan2208b602019-07-31 16:36:27 +01002354 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002355 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002356 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002357 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002358 DataType::Float16,
2359 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002360 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002361 DataType::QAsymmU8,
2362 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002363 };
2364
2365 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002366
Keith Davis0c2eeac2020-02-11 16:51:50 +00002367 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002368 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002369 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002370 }
2371}
2372
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002373void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2374{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002375 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 ValidateNumInputs(workloadInfo, descriptorName, 1);
2378 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002379
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002380 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2381 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002382
2383 std::vector<DataType> supportedTypes =
2384 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002385 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002386 DataType::Float32,
2387 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002388 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002389 DataType::QAsymmU8,
2390 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002391 };
2392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2394 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002395}
2396
Conor Kennedy430b5d82018-11-14 15:28:28 +00002397void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2398{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002399 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002400
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002401 ValidateNumInputs(workloadInfo, descriptorName, 1);
2402 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2403
2404 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2405 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002406
2407 std::vector<DataType> supportedTypes =
2408 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002409 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002410 DataType::Float16,
2411 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002412 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002413 DataType::QAsymmU8,
2414 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002415 };
2416
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2418 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002419
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002421
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002422 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002423 if (rank > 4)
2424 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002425 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002426 }
2427
Conor Kennedy430b5d82018-11-14 15:28:28 +00002428 // Begin, End & Stride length must be of rank(input0)
2429 if (m_Parameters.m_Begin.size() != rank)
2430 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002432 }
2433
2434 if (m_Parameters.m_End.size() != rank)
2435 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002436 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002437 }
2438
2439 if (m_Parameters.m_Stride.size() != rank)
2440 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002441 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002442 }
2443
2444 // Stride entries must be non-zero
2445 for (auto& stride : m_Parameters.m_Stride)
2446 {
2447 if (stride == 0)
2448 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002449 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002450 }
2451 }
2452}
2453
kevmay0190539692018-11-29 08:40:19 +00002454void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2455{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002456 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002457
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002458 ValidateNumInputs(workloadInfo, descriptorName, 2);
2459 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2460
2461 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2462 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2463 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2464
2465 std::vector<DataType> supportedTypes =
2466 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002467 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002468 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002469 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002470 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002471 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002472 DataType::QSymmS16,
2473 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002474 };
2475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002476 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2477 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2478 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002479
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002480 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2481 inputTensorInfo1,
2482 outputTensorInfo,
2483 descriptorName,
2484 "input_0",
2485 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002486}
2487
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002488void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2489{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002490 const std::string descriptorName{"DebugQueueDescriptor"};
2491
2492 ValidateNumInputs(workloadInfo, descriptorName, 1);
2493 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002494}
2495
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002496void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2497{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002500 ValidateNumInputs(workloadInfo, descriptorName, 2);
2501 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002503 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2504 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2505 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2506
2507 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2508 inputTensorInfo1,
2509 outputTensorInfo,
2510 descriptorName,
2511 "input_0",
2512 "input_1");
2513
2514 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002515 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002517 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002518}
2519
FrancisMurtagh878f0232018-12-19 10:56:15 +00002520void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2521{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002522 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002523
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002524 ValidateNumInputs(workloadInfo, descriptorName, 2);
2525 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002526
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002527 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2528 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2529 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2530
2531 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2532 inputTensorInfo1,
2533 outputTensorInfo,
2534 descriptorName,
2535 "input_0",
2536 "input_1");
2537
2538 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002539 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002541 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002542}
2543
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002544void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2545{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002546 const std::string descriptorName{"RsqrtQueueDescriptor"};
2547
2548 ValidateNumInputs(workloadInfo, descriptorName, 1);
2549 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2550
2551 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2552 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2553
2554 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002555
2556 std::vector<DataType> supportedTypes =
2557 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002558 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002559 DataType::Float16,
2560 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002561 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002562 DataType::QAsymmU8,
2563 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002564 };
2565
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002566 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2567 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002568}
2569
narpra01b89b05f2019-01-16 09:53:09 +00002570void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2571{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002574 ValidateNumInputs(workloadInfo, descriptorName, 2);
2575 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002577 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2578 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002579 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002581 }
2582
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002583 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2584 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2585
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002586 std::vector<DataType> supportedTypes =
2587 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002588 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002589 DataType::Float16,
2590 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002591 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002592 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002593 DataType::QSymmS16,
2594 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002595 };
2596
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002597 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002599 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002600
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002601 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2602 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002603}
2604
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002605void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2606{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002607 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2608
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002609 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002610
2611 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2612 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002613 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002614 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2615 }
2616
2617 if (m_Anchors == nullptr)
2618 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002619 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002620 }
2621
2622 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002623 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2624 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2625
2626 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002627 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002628 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2629 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002630
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002631 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2632 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2633 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002634
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002635 const std::vector<DataType> supportedInputTypes =
2636 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002637 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002638 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002639 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002640 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002641 DataType::QAsymmU8,
2642 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002643 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002644
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002645 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2646 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2647 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2648
2649 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2650 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2651 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2652 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2653
2654 // NOTE: Output is always Float32 regardless of input type
2655 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2656 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2657 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2658 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002659
2660 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2661 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002662 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002663 "must be positive and less than or equal to 1.");
2664 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002665
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002666 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2667 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002668 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002669 "should be equal to number of classes + 1.");
2670 }
2671}
2672
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002673void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2674{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002675 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002676
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002677 ValidateNumInputs(workloadInfo, descriptorName, 1);
2678 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2679
2680 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2681 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2682
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002683 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002684 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002685 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002686 }
2687
Sadik Armagan2208b602019-07-31 16:36:27 +01002688 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002689 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002690 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002691 DataType::Float32,
2692 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002693 };
2694
2695 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002696}
2697
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002698void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2699{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002700 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002701
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002702 ValidateNumInputs(workloadInfo, descriptorName, 2);
2703 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002704
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002705 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2706 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2707 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002708
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002709 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2710 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2711
2712 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2713 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002714}
2715
Sadik Armaganeff363d2019-04-05 15:25:46 +01002716void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2717{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002718 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002720 ValidateNumInputs(workloadInfo, descriptorName, 2);
2721 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2722
2723 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2724 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2725
2726 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2727 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2728
2729 std::vector<DataType> supportedTypes =
2730 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002731 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002732 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002733 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002734 DataType::QAsymmU8,
2735 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002736 };
2737
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002738 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2739 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002740
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002741 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2742 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002743
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002744 ValidateTensorShapesMatch(inputTensorInfo0,
2745 outputTensorInfo0,
2746 descriptorName,
2747 "input_0",
2748 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002750 ValidateTensorShapesMatch(inputTensorInfo0,
2751 outputTensorInfo1,
2752 descriptorName,
2753 "input_0",
2754 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002755}
2756
Derek Lamberti901ea112019-12-10 22:07:09 +00002757void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002758{
2759 // This is internally generated so it should not need validation.
2760}
2761
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002762void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2763{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002764 const std::string& descriptorName{"PreluQueueDescriptor"};
2765
2766 ValidateNumInputs(workloadInfo, descriptorName, 2);
2767 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2768
2769 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2770 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2771 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002772
2773 std::vector<DataType> supportedTypes
2774 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002775 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002776 DataType::Float16,
2777 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002778 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002779 DataType::QAsymmU8,
2780 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002781 };
2782
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002783 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2784 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002785
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002786 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002787
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002788 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2789 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002790
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002791 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2792 alphaTensorInfo,
2793 outputTensorInfo,
2794 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002795 "input",
2796 "alpha");
2797}
2798
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002799void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2800{
2801 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2802
2803 ValidateNumInputs(workloadInfo, descriptorName, 1);
2804 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2805
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002806 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2807 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2808
2809 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2810 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002811
2812 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002813
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002814 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2815 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002816
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002817 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2818
2819 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002820 if (m_Parameters.m_BiasEnabled)
2821 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002822 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002823
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002824 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2825 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002826
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002827 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002828 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002829 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002830
2831 ValidatePerAxisQuantization(inputTensorInfo,
2832 outputTensorInfo,
2833 weightTensorInfo,
2834 optionalBiasTensorInfo,
2835 descriptorName);
2836
2837 std::vector<DataType> supportedTypes =
2838 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002839 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002840 DataType::Float32,
2841 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002842 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002843 DataType::QAsymmU8,
2844 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002845 };
2846
2847 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2848 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002849}
2850
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002851void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2852{
2853 const std::string descriptorName{"TransposeQueueDescriptor"};
2854
2855 ValidateNumInputs(workloadInfo, descriptorName, 1);
2856 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2857
2858 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2859
2860 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2861 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2862
2863 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2864 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2865
2866 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2867 {
2868 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2869 {
2870 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2871 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2872 "must match dst dimension " + to_string(i) +
2873 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2874 }
2875 }
2876
2877 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2878}
2879
James Conroy4f1f8992020-04-29 20:01:10 +01002880void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2881{
2882 const std::string descriptorName{"QLstmQueueDescriptor"};
2883
2884 // Validate number of inputs/outputs
2885 ValidateNumInputs(workloadInfo, descriptorName, 3);
2886 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2887
2888 // Input/output tensor info
2889 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2890 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2891 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2892
2893 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2894 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2895 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2896
2897 // Supported types for various tensors in QLSTM
2898 std::vector<DataType> inputOutputSupportedTypes =
2899 {
2900 DataType::QAsymmS8
2901 };
2902
2903 std::vector<DataType> cellStateSupportedTypes =
2904 {
2905 DataType::QSymmS16
2906 };
2907
2908 std::vector<DataType> weightsSupportedTypes =
2909 {
2910 DataType::QSymmS8
2911 };
2912
2913 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2914 {
2915 DataType::QSymmS16
2916 };
2917
2918 std::vector<DataType> biasSupportedTypes =
2919 {
2920 DataType::Signed32
2921 };
2922
2923 // Validate types of input/output tensors
2924 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2925 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2926 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2927
2928 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2929 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2930 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
2931
2932 // Validate matching types of input/output tensors
2933 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2934 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2935 "outputStateIn", "outputStateOut");
2936 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2937
2938 // Infer number of batches, number of units, input size and output size from tensor dimensions
2939 const uint32_t numBatches = inputInfo.GetShape()[0];
2940 const uint32_t inputSize = inputInfo.GetShape()[1];
2941 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
2942 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
2943
2944 // Validate number of dimensions and number of elements for input/output tensors
2945 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2946 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2947 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
2948
2949 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2950 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
2951 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
2952
2953 // Validate number of dimensions and number of elements for MANDATORY weight tensors
2954 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2955 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2956 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
2957
2958 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2959 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2960 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
2961
2962 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2963 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2964 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
2965
2966 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2967 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2968 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
2969 " RecurrentToForgetWeights");
2970
2971 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2972 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2973 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2974
2975 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2976 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2977 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
2978
2979 // Validate data types for MANDATORY weights tensors (all should match each other)
2980 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
2981
2982 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
2983 "inputToForgetWeights", "inputToCellWeights");
2984 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2985 "inputToForgetWeights", "inputToOutputWeights");
2986
2987 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2988 "inputToForgetWeights", "recurrentToForgeteights");
2989 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2990 "inputToForgetWeights", "recurrentToCellWeights");
2991 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2992 "inputToForgetWeights", "recurrentToOutputWeights");
2993
2994 // Validate number of dimensions and number of elements for MANDATORY bias tensors
2995 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2996 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2997 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
2998
2999 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3000 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3001 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3002
3003 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3004 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3005 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3006
3007 // Validate data types for MANDATORY bias tensors
3008 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3009
3010 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3011 "forgetGateBias", "cellBias");
3012 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3013 "forgetGateBias", "outputGateBias");
3014
3015 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3016 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3017 !m_Parameters.m_CifgEnabled) ||
3018 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3019 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3020
3021 if (!allCifgParamsPresentOrNot)
3022 {
3023 throw InvalidArgumentException(descriptorName +
3024 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3025 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3026 "set appropriately.");
3027 }
3028
3029 if (!m_Parameters.m_CifgEnabled)
3030 {
3031 // Validate number of dimensions and number of elements
3032 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3033 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3034
3035 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3036 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3037 " RecurrentToInputWeights");
3038
3039 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3040 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3041
3042 // Validate data types
3043 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3044 "inputToForgetWeights", "inputToInputWeights");
3045 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3046 "inputToForgetWeights", "recurrentToInputWeights");
3047 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3048 "forgetGateBias", "inputGateBias");
3049 }
3050
3051 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3052 bool allPeepholeWeightsPresentOrNot =
3053 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3054 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3055 || (!m_CellToInputWeights && !m_CellToForgetWeights
3056 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3057
3058 if (!allPeepholeWeightsPresentOrNot)
3059 {
3060 throw InvalidArgumentException(descriptorName +
3061 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3062 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3063 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3064 "appropriately.");
3065 }
3066
3067 if (m_Parameters.m_PeepholeEnabled)
3068 {
3069 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3070 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3071 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3072
3073 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3074 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3075 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3076 "cellToForgetWeight", "cellToOutputWeights");
3077
3078 if (!m_Parameters.m_CifgEnabled)
3079 {
3080 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3081 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3082 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3083 "cellToForgetWeights", "cellToInputWeights");
3084 }
3085 }
3086
3087 // Validate OPTIONAL params: Layer Norm Weights
3088 bool allLayerNormWeightsPresentOrNot =
3089 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3090 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3091 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3092 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3093
3094 if (!allLayerNormWeightsPresentOrNot)
3095 {
3096 throw InvalidArgumentException(descriptorName +
3097 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3098 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3099 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3100 "only be present when Layer Norm is enabled and CIFG is disabled. "
3101 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3102 }
3103
3104 if (m_Parameters.m_LayerNormEnabled)
3105 {
3106 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3107 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3108 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3109
3110 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3111 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3112 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3113 "forgetLayerNormWeights", "cellLayerNormWeights");
3114
3115 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3116 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3117 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3118 "forgetLayerNormWeights", "outputLayerNormWeights");
3119
3120 if (!m_Parameters.m_CifgEnabled)
3121 {
3122 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3123 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3124 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3125 "forgetLayerNormWeights", "inputLayerNormWeights");
3126 }
3127 }
3128
3129 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3130 bool correctProjectionTensorsPresent =
3131 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3132 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3133 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3134
3135 if (!correctProjectionTensorsPresent)
3136 {
3137 throw InvalidArgumentException(descriptorName +
3138 ": If projection is enabled, ProjectionWeights should be present and "
3139 "ProjectionBias is optional. If projection is disabled, neither "
3140 "ProjectionWeights nor ProjectionBias should be present.");
3141 }
3142
3143 if (m_Parameters.m_ProjectionEnabled)
3144 {
3145 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3146 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3147 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3148
3149 if (m_ProjectionBias)
3150 {
3151 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003152 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003153 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3154 }
3155
3156 }
3157 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3158 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3159 throw InvalidArgumentException(descriptorName +
3160 ": If projection is disabled, output quantization info (scale, offset) "
3161 "should match HiddenStateScale and HiddenStateZeroPoint.");
3162 }
3163
3164}
3165
James Conroy9c3cae82019-08-01 16:01:48 +01003166void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3167{
3168 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3169
3170 // Validate number of inputs/outputs
3171 ValidateNumInputs(workloadInfo, descriptorName, 3);
3172 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3173
3174 // Input/output tensor infos
3175 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3176 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3177 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3178
3179 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3180 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3181
3182 std::vector<DataType> inputOutputSupportedTypes =
3183 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003184 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003185 };
3186
3187 std::vector<DataType> cellStateSupportedTypes =
3188 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003189 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003190 };
3191
3192 std::vector<DataType> weightsSupportedTypes =
3193 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003194 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003195 };
3196
3197 std::vector<DataType> biasSupportedTypes =
3198 {
3199 DataType::Signed32
3200 };
3201
3202 // Validate types of input/output tensors
3203 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3204 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3205 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3206
3207 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3208 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3209
3210 // Validate matching types of input/output tensors
3211 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3212 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3213 "outputStateIn", "outputStateOut");
3214 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3215
3216 // Validate matching quantization info for input/output tensors
3217 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3218 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3219 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003220
James Conroy9c3cae82019-08-01 16:01:48 +01003221 // Infer number of batches, input size and output size from tensor dimensions
3222 const uint32_t numBatches = inputInfo.GetShape()[0];
3223 const uint32_t inputSize = inputInfo.GetShape()[1];
3224 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3225
3226 // Validate number of dimensions and number of elements for input/output tensors
3227 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3228 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3229 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3230 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3231 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3232
3233 // Validate number of dimensions and number of elements for weights tensors
3234 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3235 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3236 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3237
3238 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3239 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3240 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3241
3242 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3243 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3244 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3245
3246 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3247 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3248 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3249
3250 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3251 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3252 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3253
3254 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3255 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3256 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3257 " RecurrentToForgetWeights");
3258
3259 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3260 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3261 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3262
3263 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3264 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3265 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3266
3267 // Validate data types for weights tensors (all should match each other)
3268 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3269
3270 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3271 "inputToInputWeights", "inputToForgetWeights");
3272 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3273 "inputToInputWeights", "inputToCellWeights");
3274 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3275 "inputToInputWeights", "inputToOutputWeights");
3276
3277 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3278 "inputToInputWeights", "recurrentToInputWeights");
3279 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3280 "inputToInputWeights", "recurrentToForgeteights");
3281 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3282 "inputToInputWeights", "recurrentToCellWeights");
3283 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3284 "inputToInputWeights", "recurrentToOutputWeights");
3285
3286 // Validate matching quantization info for weight tensors (all should match each other)
3287 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3288 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3289 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3290 descriptorName, "inputToInputWeights", "inputToCellWeights");
3291 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3292 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3293
3294 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3295 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3296 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3297 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3298 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3299 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3300 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3301 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3302
3303 // Validate number of dimensions and number of elements in bias tensors
3304 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3305 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3306 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3307
3308 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3309 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3310 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3311
3312 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3313 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3314 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3315
3316 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3317 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3318 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3319
3320 // Validate data types for bias tensors (all should match each other)
3321 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3322
3323 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3324 "inputGateBias", "forgetGateBias");
3325 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3326 "inputGateBias", "cellBias");
3327 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3328 "inputGateBias", "outputGateBias");
3329
3330 // Validate bias tensor quantization info
3331 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3332 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3333 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3334 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3335}
3336
Kevin May868eb142019-09-04 17:29:31 +01003337void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3338{
3339 const std::string descriptorName{"AbsQueueDescriptor"};
3340
3341 ValidateNumInputs(workloadInfo, descriptorName, 1);
3342 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3343
3344 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3345 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3346
3347 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3348
3349 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003350 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003351 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003352 DataType::Float16,
3353 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003354 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003355 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003356 DataType::QSymmS16,
3357 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003358 };
Kevin May868eb142019-09-04 17:29:31 +01003359
3360 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3361 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3362}
3363
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003364void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3365{
3366 const std::string descriptorName{"SliceQueueDescriptor"};
3367
3368 ValidateNumInputs(workloadInfo, descriptorName, 1);
3369 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3370
3371 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3372 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3373
3374 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3375
3376 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3377 if (rank > 4)
3378 {
3379 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3380 }
3381
3382 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3383
3384 // Check if m_Begin and m_Size have the expected length
3385 if (m_Parameters.m_Begin.size() != rank)
3386 {
3387 throw InvalidArgumentException(descriptorName +
3388 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3389 }
3390 if (m_Parameters.m_Size.size() != rank)
3391 {
3392 throw InvalidArgumentException(descriptorName +
3393 ": Length of size descriptor must equal rank " + std::to_string(rank));
3394 }
3395
3396 // Check if the shape of the output tensor matches m_Size
3397 const TensorShape& outputShape = outputTensorInfo.GetShape();
3398 for (unsigned int i = 0u; i < rank; ++i)
3399 {
3400 if (m_Parameters.m_Size[i] != outputShape[i])
3401 {
3402 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3403 }
3404 }
3405
3406 // Check if the sum of begin offset and size in a given dimension
3407 // does not exceed the size of corresponding input
3408 const TensorShape& inputShape = inputTensorInfo.GetShape();
3409 for(unsigned int i = 0u; i < rank; ++i)
3410 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003411 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003412 {
3413 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3414 std::to_string(i) + " exceeds input size.");
3415 }
3416 }
3417}
3418
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003419void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3420{
3421 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3422
3423 ValidateNumInputs(workloadInfo, descriptorName, 1);
3424 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3425
3426 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3427 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3428
3429 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3430 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3431
3432 std::vector<DataType> supportedTypes =
3433 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003434 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003435 DataType::Float32,
3436 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003437 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003438 DataType::QAsymmU8,
3439 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003440 };
3441
3442 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3443 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3444
3445 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3446
3447 if (m_Parameters.m_BlockSize == 0)
3448 {
3449 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3450 }
3451
3452 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3453 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3454 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3455 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3456
3457 const TensorShape& outputShape = outputInfo.GetShape();
3458 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3459 {
3460 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3461 "must be divisible by block size.");
3462 }
3463
3464 const TensorShape& inputShape = inputInfo.GetShape();
3465 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3466 {
3467 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3468 "must be divisible by the square of block size." );
3469 }
3470}
3471
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003472void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3473{
3474 const std::string descriptorName{"ComparisonQueueDescriptor"};
3475
3476 ValidateNumInputs(workloadInfo, descriptorName, 2);
3477 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3478
3479 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3480 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3481 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3482
3483 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3484 inputTensorInfo1,
3485 outputTensorInfo,
3486 descriptorName,
3487 "input_0",
3488 "input_1");
3489
3490 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3491 {
3492 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3493 }
3494}
3495
josh minor4a3c6102020-01-06 16:40:46 -06003496void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3497{
3498 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3499
3500 ValidateNumInputs(workloadInfo, descriptorName, 1);
3501 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3502
3503 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3504 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3505
3506 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3507
3508 std::vector<DataType> supportedTypes =
3509 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003510 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003511 DataType::Float16,
3512 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003513 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003514 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003515 DataType::QSymmS16,
3516 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003517 };
3518
3519 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3520 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3521}
3522
Finn Williams2605b232020-06-10 15:53:46 +01003523void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3524{
3525 const std::string descriptorName{"RankQueueDescriptor"};
3526
3527 ValidateNumInputs(workloadInfo, descriptorName, 1);
3528 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3529
3530 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3531 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3532
3533 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3534 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3535
3536 std::vector<DataType> supportedTypes =
3537 {
3538 DataType::BFloat16,
3539 DataType::Float16,
3540 DataType::Float32,
3541 DataType::QAsymmS8,
3542 DataType::QAsymmU8,
3543 DataType::QSymmS8,
3544 DataType::QSymmS16,
3545 DataType::Signed32
3546 };
3547
3548 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3549 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3550}
3551
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003552} // namespace armnn