blob: be0ac707a8acb213ac89f55988786ed0b68a1651 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
James Conroy1f58f032021-04-27 17:13:27 +01006#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00007#include <backendsCommon/WorkloadData.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
mathad01df9a3222021-04-28 11:42:57 +010011#include <armnn/Logging.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000012
telsoa014fcda012018-03-09 14:13:49 +000013#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000015#include <string>
16#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000017
James Ward47fce872020-09-10 11:57:28 +010018#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +000032 case DataType::BFloat16:
telsoa014fcda012018-03-09 14:13:49 +000033 case DataType::Float32:
34 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000035 case DataType::QAsymmS8:
36 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000037 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000038 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
40 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000041 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010042 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000043 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010044 ARMNN_ASSERT_MSG(false, "Invalid input data type");
telsoa014fcda012018-03-09 14:13:49 +000045 return DataType::Float32;
46 }
47}
48
49namespace
50{
51
52//---------------------------------------------------------------
// Local stand-in for std::to_string, which the Android NDK does not support.
// The value is rendered through its stream insertion operator.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream rendered;
    rendered << value;
    return rendered.str();
}
61
62//---------------------------------------------------------------
63void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
64{
65 if (!ptr)
66 {
67 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
68 paramName + " parameter must be set.");
69 }
70}
71
72//---------------------------------------------------------------
73void ValidateTensorShapesMatch(const TensorInfo& first,
74 const TensorInfo& second,
75 std::string const& descName,
76 std::string const& firstName,
77 std::string const& secondName)
78{
79 if (first.GetShape() != second.GetShape())
80 {
81 throw InvalidArgumentException(descName + ": "
82 + firstName + " & " + secondName + " must have identical shapes");
83 }
84}
85
86//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010087void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088{
Sadik Armaganeff363d2019-04-05 15:25:46 +010089 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090 {
91 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010092 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000093 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
94 }
95}
96
97//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010098void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101 {
102 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100103 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000104 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
105 }
106}
107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000110 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100111 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000112 std::string const& tensorName)
113{
114 if (tensor.GetNumDimensions() != numDimensions)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
117 to_string(tensor.GetNumDimensions()) + " dimensions for " +
118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100123void ValidateTensorNumElements(const TensorInfo& tensor,
124 std::string const& descName,
125 unsigned int numElements,
126 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127{
128 if (tensor.GetNumElements() != numElements)
129 {
130 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100131 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100132 tensorName + " tensor.");
133 }
134}
135
136//---------------------------------------------------------------
137void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100138 unsigned int numDimension,
139 unsigned int numElements,
140 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100141{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100142 const std::string functionName{"ValidateTensorNumDimNumElem"};
143 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
144 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100145}
146
147//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000148void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
149 const std::string& descName, std::string const& tensorName)
150{
151 if (tensor.GetDataType() != dataType)
152 {
153 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
154 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
155 }
156}
157
Derek Lambertid466a542020-01-22 15:37:29 +0000158void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
159{
160 ARMNN_NO_DEPRECATE_WARN_BEGIN
161 if (tensor.GetDataType() != DataType::QSymmS8 &&
162 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
163 {
164 throw InvalidArgumentException(descName +
165 ": Expected data type which supports per-axis quantization scheme but got " +
166 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
167 }
168 ARMNN_NO_DEPRECATE_WARN_END
169}
170
telsoa014fcda012018-03-09 14:13:49 +0000171//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100172void ValidateTensorQuantizationSpace(const TensorInfo& first,
173 const TensorInfo& second,
174 const std::string& descName,
175 std::string const& firstName,
176 std::string const& secondName)
177{
178 if (!first.IsQuantized() ||
179 !second.IsQuantized())
180 {
181 // Not a quantized type, ignore the validation
182 return;
183 }
184
185 DataType firstDataType = first.GetDataType();
186 DataType secondDataType = second.GetDataType();
187
188 if (firstDataType != secondDataType)
189 {
190 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
191 " must be of the same quantized type, " +
192 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
193 secondName + " is " + GetDataTypeName(secondDataType));
194 }
195
196 if (!first.IsTypeSpaceMatch(second))
197 {
198 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
199 " must have the same quantization space, " +
200 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
201 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
202 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
203 " and scale " + to_string(second.GetQuantizationScale()));
204 }
205}
206
207//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100208void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
209 const TensorInfo& inputTensorInfo,
210 const TensorInfo& weightsTensorInfo,
211 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000212{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000213 // Helper lambda function to validate a single bias quantization scale value
214 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
215 {
mathad01df9a3222021-04-28 11:42:57 +0100216 constexpr float tolerance = 0.0001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000217 if (std::abs(biasScale - expectedScale) > tolerance)
218 {
219 // Print the float values with extra precision to see very small differences
mathad01df9a3222021-04-28 11:42:57 +0100220 ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
221 " for bias quantization scale (product of input and weight scales), but got " <<
222 biasScale << ". Using scale provided.";
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000223 }
224 };
225
telsoa014fcda012018-03-09 14:13:49 +0000226 if (biasTensor.GetQuantizationOffset() != 0)
227 {
228 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
229 to_string(biasTensor.GetQuantizationOffset()));
230 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000231
James Conroy8502ade2020-11-12 19:26:29 +0000232 if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000233 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000234 // Validate per-axis quantization scales
235 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
236 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
237
238 if (weightScales.size() != biasScales.size())
239 {
240 std::stringstream msg;
James Conroy8502ade2020-11-12 19:26:29 +0000241 msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
242 << "but got different values. This is currently unsupported: weights=" << weightScales.size()
243 << ", biases=" << biasScales.size();
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000244 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
245 }
246
247 for (size_t i = 0ul; i < biasScales.size(); ++i)
248 {
249 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
250 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
251 }
252 }
253 else
254 {
255 // Validate per-tensor quantization scale
256 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
257 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259}
260
261//---------------------------------------------------------------
262void ValidateTensors(const std::vector<ITensorHandle*>& vec,
263 unsigned int numExpected,
264 const std::string& descName,
265 const std::string& varName)
266{
267 if (vec.empty() && numExpected > 0)
268 {
269 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
270 }
271
272 for (unsigned int i = 0; i < numExpected; ++i)
273 {
274 if (!vec[i])
275 {
276 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
277 }
278 }
279}
280
281//---------------------------------------------------------------
282void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 const TensorInfo& output,
285 std::string const& descName,
286 std::string const& firstName,
287 std::string const& secondName)
288{
289 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
290 // broadcasted.
291 if (first.GetNumDimensions() != second.GetNumDimensions())
292 {
293 throw InvalidArgumentException(descName + ": Tensors "
294 + firstName + " & " + secondName
295 + " must have the same number of dimensions in order to be broadcasted");
296 }
297 uint32_t numDims = first.GetNumDimensions();
298 std::vector<uint32_t> outputDims(numDims, 0u);
299 for (uint32_t i = 0; i < numDims; i++)
300 {
301 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
302 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
303 if (dimsNotEqual && dimsNotOne)
304 {
305 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
306 }
307 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
308 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100309 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000310 if (broadcastShape != output.GetShape())
311 {
312 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
313 + firstName + " & " + secondName
314 + " does not match the output shape");
315 }
316}
317
318//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100319void ValidateDataTypes(const TensorInfo& info,
320 const std::vector<armnn::DataType>& supportedTypes,
321 std::string const& descName)
322{
323 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
324 if (iterator == supportedTypes.end())
325 {
326 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
327 }
328}
329
James Conroy4d1ff582019-06-10 17:06:39 +0100330//---------------------------------------------------------------
331void ValidateTensorDataTypesMatch(const TensorInfo& first,
332 const TensorInfo& second,
333 std::string const& descName,
334 std::string const& firstName,
335 std::string const& secondName)
336{
337 if (first.GetDataType() != second.GetDataType())
338 {
339 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
340 " must have identical data types.");
341 }
342}
343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100344//---------------------------------------------------------------
345void ValidateTensorNumElementsMatch(const TensorInfo& first,
346 const TensorInfo& second,
347 std::string const& descName,
348 std::string const& firstName,
349 std::string const& secondName)
350{
351 if (first.GetNumElements() != second.GetNumElements())
352 {
353 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
354 " must have the same number of elements.");
355 }
356}
357
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000358void ValidateWeightDataType(const TensorInfo& inputInfo,
359 const TensorInfo& weightInfo,
360 const std::string& descName)
361{
362 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000363 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000364 {
Derek Lambertid466a542020-01-22 15:37:29 +0000365 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000366 const std::vector<DataType> validTypes =
367 {
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100369 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000370 DataType::QSymmS8,
371 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000372 };
Derek Lambertid466a542020-01-22 15:37:29 +0000373 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000374
375 ValidateDataTypes(weightInfo, validTypes, descName);
376 }
377 else
378 {
379 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
380 }
381}
382
383void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
384 const std::string& descName,
385 const std::string& tensorName)
386{
387 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
388 if (!quantizationDim.has_value())
389 {
James Ward47fce872020-09-10 11:57:28 +0100390 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
391 "not set on tensor {1}.", descName, tensorName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000392 }
393
394 if (quantizationDim.value() != 0)
395 {
James Ward47fce872020-09-10 11:57:28 +0100396 throw InvalidArgumentException(fmt::format(
397 "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, "
398 "but got: {2}", descName, tensorName, quantizationDim.value()));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000399 }
400}
401
402void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
403 const std::string& descName,
404 const std::string& tensorName)
405{
406 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
407 if (quantizationOffset != 0)
408 {
James Ward47fce872020-09-10 11:57:28 +0100409 throw InvalidArgumentException(fmt::format(
410 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
411 descName, tensorName, quantizationOffset));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000412 }
413}
414
415void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
416 const TensorInfo& outputInfo,
417 const TensorInfo& weightInfo,
418 const Optional<TensorInfo>& optionalBiasInfo,
419 const std::string& descName)
420{
421 if (weightInfo.HasPerAxisQuantization())
422 {
423 const DataType inputDataType = inputInfo.GetDataType();
424 const DataType outputDataType = outputInfo.GetDataType();
425
Keith Davis0c2eeac2020-02-11 16:51:50 +0000426 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000427
428 if (!canHavePerAxisQuantization)
429 {
James Ward47fce872020-09-10 11:57:28 +0100430 throw InvalidArgumentException(fmt::format(
431 "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
432 "per-axis quantization.", descName, "weight"));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000433 }
434
Derek Lambertid466a542020-01-22 15:37:29 +0000435
436 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000437 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
438 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
439
440 if (optionalBiasInfo.has_value())
441 {
442 const TensorInfo& biasInfo = optionalBiasInfo.value();
443 if (!biasInfo.HasPerAxisQuantization())
444 {
James Ward47fce872020-09-10 11:57:28 +0100445 throw InvalidArgumentException(fmt::format(
446 "{}: Per-axis quantization parameters not set on bias tensor, "
447 "despite being set on weight tensor.", descName));
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000448 }
449
450 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
451 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
452 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
453 }
454 }
455}
456
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100457} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000458
459void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
460 unsigned int numExpectedIn, unsigned int numExpectedOut) const
461{
462 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
463 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
464}
465
466//---------------------------------------------------------------
Jim Flynn68db06f2020-10-06 10:14:50 +0100467void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
468{
469 const std::string descriptorName{"MapQueueDescriptor"};
470
471 ValidateNumInputs(workloadInfo, descriptorName, 1);
Jim Flynn3a40ea52020-10-08 11:42:30 +0100472 ValidateNumOutputs(workloadInfo, descriptorName, 0);
473
474 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
475 {
476 if (!m_Inputs[i])
477 {
478 throw InvalidArgumentException(
479 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
480 }
481 }
482}
483
484//---------------------------------------------------------------
485void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
486{
487 const std::string descriptorName{"UnmapQueueDescriptor"};
488
489 ValidateNumInputs(workloadInfo, descriptorName, 1);
490 ValidateNumOutputs(workloadInfo, descriptorName, 0);
Jim Flynn68db06f2020-10-06 10:14:50 +0100491
492 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
493 {
494 if (!m_Inputs[i])
495 {
496 throw InvalidArgumentException(
497 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
498 }
499 }
500}
501
502//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000503void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
504{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100505 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000506
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100507 ValidateNumInputs(workloadInfo, descriptorName, 1);
508 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000509
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100510 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
511 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
512
513 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
514 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000515
516 if (m_Inputs.size() != m_Outputs.size())
517 {
James Ward47fce872020-09-10 11:57:28 +0100518 throw InvalidArgumentException(fmt::format(
519 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
520 descriptorName, m_Inputs.size(), m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000521 }
522
523 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
524 {
525 if (!m_Inputs[i])
526 {
James Ward47fce872020-09-10 11:57:28 +0100527 throw InvalidArgumentException(fmt::format(
528 "{0}: Invalid NULL input {1}.", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000529 }
530
531 if (!m_Outputs[i])
532 {
James Ward47fce872020-09-10 11:57:28 +0100533 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
telsoa014fcda012018-03-09 14:13:49 +0000534 }
535 }
536}
537
Derek Lambertif674aa02019-08-01 15:56:25 +0100538//---------------------------------------------------------------
539void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
540{
541 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
542 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
543
544 if (workloadInfo.m_InputTensorInfos.size() != 1)
545 {
James Ward47fce872020-09-10 11:57:28 +0100546 throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
547 workloadInfo.m_InputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100548
549 }
550
551 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
552 {
James Ward47fce872020-09-10 11:57:28 +0100553 throw InvalidArgumentException(fmt::format(
554 "Number of input infos ({0}) does not match the number of output infos ({1})",
555 workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100556 }
557
558 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
559 {
560 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
561 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
562 {
James Ward47fce872020-09-10 11:57:28 +0100563 throw InvalidArgumentException(fmt::format(
564 "Number of elements for tensor input and output {} does not match", i ));
Derek Lambertif674aa02019-08-01 15:56:25 +0100565 }
566 }
567
568 if (m_Inputs.size() != 1)
569 {
James Ward47fce872020-09-10 11:57:28 +0100570 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 }
572
573 if (m_Inputs.size() != m_Outputs.size())
574 {
James Ward47fce872020-09-10 11:57:28 +0100575 throw InvalidArgumentException(fmt::format(
576 "Number of inputs ({0}) does not match the number of outputs ({1})",
577 m_Inputs.size(), m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100578 }
579
580 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
581 {
582 if (!m_Inputs[i])
583 {
James Ward47fce872020-09-10 11:57:28 +0100584 throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100585 }
586
587 if (!m_Outputs[i])
588 {
James Ward47fce872020-09-10 11:57:28 +0100589 throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
Derek Lambertif674aa02019-08-01 15:56:25 +0100590 }
591 }
592}
593
594//---------------------------------------------------------------
595void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
596{
597 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
598 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
599
Derek Lambertif674aa02019-08-01 15:56:25 +0100600 if (m_Inputs.size() != 1)
601 {
James Ward47fce872020-09-10 11:57:28 +0100602 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100603 }
604
605 if (m_Outputs.size() != 0)
606 {
James Ward47fce872020-09-10 11:57:28 +0100607 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
Derek Lambertif674aa02019-08-01 15:56:25 +0100608 }
609
610 if (!m_Inputs[0])
611 {
James Ward47fce872020-09-10 11:57:28 +0100612 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
Derek Lambertif674aa02019-08-01 15:56:25 +0100613 }
614}
615
616//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000617void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
618{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100619 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100620
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100621 ValidateNumInputs(workloadInfo, descriptorName, 1);
622 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100623
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100624 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
625 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100626
627 std::vector<DataType> supportedTypes =
628 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000629 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100630 DataType::Float16,
631 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000632 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000633 DataType::QAsymmU8,
634 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100635 };
636
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100637 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
638 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
639 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000640}
641
Nikhil Rajee391d52019-09-05 17:50:44 +0100642void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
643{
644 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
645
646 ValidateNumInputs(workloadInfo, descriptorName, 1);
647 ValidateNumOutputs(workloadInfo, descriptorName, 1);
648
649 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
650 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
651
Inki Daed4619e22020-09-10 15:33:54 +0900652 if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
653 outputTensorInfo.GetDataType() != DataType::Signed64)
Nikhil Raj68c2c902019-09-19 11:21:11 +0100654 {
Inki Daed4619e22020-09-10 15:33:54 +0900655 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
Nikhil Raj68c2c902019-09-19 11:21:11 +0100656 }
657
James Conroyd47a0642019-09-17 14:22:06 +0100658 std::vector<DataType> supportedInputTypes =
659 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000660 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100661 DataType::Float16,
662 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100663 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000664 DataType::QAsymmU8,
665 DataType::QSymmS16,
Inki Daed4619e22020-09-10 15:33:54 +0900666 DataType::Signed32,
667 DataType::Signed64
James Conroyd47a0642019-09-17 14:22:06 +0100668 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100669
James Conroyd47a0642019-09-17 14:22:06 +0100670 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100671
672 auto inputShape = inputTensorInfo.GetShape();
673 auto outputShape = outputTensorInfo.GetShape();
674
675 auto inputNumDimensions = inputShape.GetNumDimensions();
676 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
677
678 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
679
680 // 1D input shape results in scalar output shape
681 if (inputShape.GetNumDimensions() == 1)
682 {
683 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
684 {
685 throw InvalidArgumentException(descriptorName + outputShapeError);
686 }
687 }
688 else
689 {
690 for (unsigned int i = 0; i < unsignedAxis; ++i)
691 {
692 if (outputShape[i] != inputShape[i])
693 {
694 throw InvalidArgumentException(descriptorName + outputShapeError);
695 }
696 }
697
698 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
699 {
700 if (outputShape[i - 1] != inputShape[i])
701 {
702 throw InvalidArgumentException(descriptorName + outputShapeError);
703 }
704 }
705 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100706}
707
mathad01b392e982021-04-07 12:07:30 +0100708void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
709{
710 const std::string descriptorName{"CastQueueDescriptor"};
711
712 ValidateNumInputs(workloadInfo, descriptorName, 1);
713 ValidateNumOutputs(workloadInfo, descriptorName, 1);
714
715 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
716 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
717
718 std::vector<DataType> supportedTypes =
719 {
720 DataType::BFloat16,
721 DataType::Float16,
722 DataType::Float32,
723 DataType::QAsymmS8,
724 DataType::QAsymmU8,
725 DataType::QSymmS8,
726 DataType::QSymmS16,
727 DataType::Signed32,
728 DataType::Signed64
729 };
730
731 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
732 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
733}
734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100735void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
736{
737 const std::string descriptorName{"SoftmaxQueueDescriptor"};
738
739 ValidateNumInputs(workloadInfo, descriptorName, 1);
740 ValidateNumOutputs(workloadInfo, descriptorName, 1);
741
742 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
743 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
744
745 std::vector<DataType> supportedTypes =
746 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000747 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100748 DataType::Float16,
749 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000750 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000751 DataType::QAsymmU8,
752 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100753 };
754
755 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
756 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
757 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
758}
759
telsoa014fcda012018-03-09 14:13:49 +0000760void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
761{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100762 const std::string descriptorName{"SplitterQueueDescriptor"};
763
764 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000765
Ruomei Yan25339c32019-05-28 16:48:20 +0100766 // Check the supported data types
767 std::vector<DataType> supportedTypes =
768 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000769 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100770 DataType::Float32,
771 DataType::Float16,
772 DataType::Boolean,
773 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100774 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000775 DataType::QAsymmU8,
776 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100777 };
778
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100779 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
780 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100781 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100782 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
783 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
784
785 const std::string outputName = "output_" + std::to_string(i);
786 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100787 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100788
telsoa014fcda012018-03-09 14:13:49 +0000789 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
790 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100791 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000792 }
793
794 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
795 {
796 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100797 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000798 "has to match number of workloadInfo.m_OutputTensorInfos. "
799 "Number of windows: " +
800 to_string(m_ViewOrigins.size()) +
801 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
802 }
803
telsoa01c577f2c2018-08-31 09:22:23 +0100804 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000805 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
806 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
807 {
telsoa01c577f2c2018-08-31 09:22:23 +0100808 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000809 ViewOrigin const& e = m_ViewOrigins[w];
810 if (e.m_Origin.size() != inputDims)
811 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100812 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000813 "have the same dimensionality as the input tensor. "
814 "Window origin (index: " +
815 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
816 " dimensions, the input "
817 "tensor has " +
818 to_string(inputDims) + " dimensions.");
819 }
820 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
821 {
822 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
823 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
824 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100825 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000826 "be smaller or equal than the size of the input in that coord.");
827 }
828 }
829 }
830}
831
Jim Flynne242f2d2019-05-22 14:24:13 +0100832void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000833{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100834 const std::string descriptorName{"ConcatQueueDescriptor"};
835
836 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000837
838 if (m_Inputs.size() <= 0)
839 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100840 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000841 }
842 if (m_Outputs.size() <= 0)
843 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100844 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000845 }
846
847 if (workloadInfo.m_InputTensorInfos.size() <= 0)
848 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100849 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000850 }
851 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
852 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100853 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000854 }
855
Nikhil Raj8599a412018-11-19 14:51:07 +0000856 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
857 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100858 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000859 }
860
861 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
862 {
863 return;
864 }
865
telsoa014fcda012018-03-09 14:13:49 +0000866 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
867 {
868 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100869 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000870 "has to match number of workloadInfo.m_InputTensorInfos. "
871 "Number of windows: " +
872 to_string(m_ViewOrigins.size()) +
873 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
874 }
875
telsoa01c577f2c2018-08-31 09:22:23 +0100876 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000877 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
878 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
879 {
telsoa01c577f2c2018-08-31 09:22:23 +0100880 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000881 ViewOrigin const& e = m_ViewOrigins[w];
882 if (e.m_Origin.size() != outputDims)
883 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100884 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000885 "have the same dimensionality as the output tensor. "
886 "Window origin (index: " +
887 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
888 " dimensions, the output "
889 "tensor has " +
890 to_string(outputDims) + " dimensions.");
891 }
telsoa01c577f2c2018-08-31 09:22:23 +0100892 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000893 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
894 {
895 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
896 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
897 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100898 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000899 "be smaller or equal than the size of the output in that coord.");
900 }
901 }
902 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100903
904 // Check the supported data types
905 std::vector<DataType> supportedTypes =
906 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000907 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100908 DataType::Float32,
909 DataType::Float16,
910 DataType::Boolean,
911 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100912 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000913 DataType::QAsymmU8,
914 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100915 };
916
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100917 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
918 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100919 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100920 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
921 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
922
923 const std::string inputName = "input_" + std::to_string(i);
924 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100925 }
telsoa014fcda012018-03-09 14:13:49 +0000926}
927
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100928void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
929{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 const std::string descriptorName{"StackQueueDescriptor"};
931
932 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100933
934 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
935 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100936 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100937 }
938
939 // All inputs must have the same shape, which is defined in parameters
940 const TensorShape& inputShape = m_Parameters.m_InputShape;
941 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
942 {
943 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
944 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100945 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100946 }
947 }
948
Matthew Jacksondba634f2019-08-15 15:14:18 +0100949 if (inputShape.GetNumDimensions() > 4)
950 {
951 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
952 }
953
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100954 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
955 // since the output tensor has an additional dimension.
956 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
957 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100958 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100959 "than the number of input dimensions.");
960 }
961
962 // Output shape must be as inferred from the input shape
963 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
964 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
965 {
966 if (outputShape[i] != inputShape[i])
967 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100969 "match shape inferred from input tensor.");
970 }
971 }
972
973 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
974 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100975 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100976 "match shape inferred from input tensor.");
977 }
978
979 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
980 {
981 if (outputShape[i] != inputShape[i-1])
982 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100983 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100984 "match shape inferred from input tensor.");
985 }
986 }
987
Matthew Jacksondba634f2019-08-15 15:14:18 +0100988 if (outputShape.GetNumDimensions() > 5)
989 {
990 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
991 }
992
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100993 // Check the supported data types
994 std::vector<DataType> supportedTypes =
995 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000996 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100997 DataType::Float32,
998 DataType::Float16,
999 DataType::Boolean,
1000 DataType::Signed32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001001 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001002 DataType::QAsymmU8,
1003 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001004 };
1005
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001006 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001007
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001008 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001009 {
1010 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1011 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001012 descriptorName,
1013 "input_0",
1014 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001015 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001016
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001017 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1018 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001019 descriptorName,
1020 "input_0",
1021 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +01001022}
1023
Ryan OSheaec6c6802020-06-05 17:17:06 +01001024void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1025{
1026 const std::string descriptorName{"FillQueueDescriptor"};
1027
1028 ValidateNumInputs(workloadInfo, descriptorName, 1);
1029 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1030
1031 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1032 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1033
1034 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1035
1036 std::vector<DataType> supportedTypes =
1037 {
1038 DataType::BFloat16,
1039 DataType::Float32,
1040 DataType::Float16,
1041 DataType::Signed32
1042 };
1043
1044 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1045}
1046
telsoa014fcda012018-03-09 14:13:49 +00001047void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1048{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001049 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001050
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001051 uint32_t numInputs = 1;
1052 if (!m_Parameters.m_ConstantWeights)
1053 {
1054 numInputs = 2;
1055 if (m_Parameters.m_BiasEnabled)
1056 {
1057 numInputs = 3;
1058 }
1059 }
1060 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001061 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1062
1063 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1064 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1065
1066 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1067
1068 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +00001069 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001070 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +00001071 }
1072
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001073 TensorInfo weightTensorInfo;
1074 if (m_Parameters.m_ConstantWeights)
1075 {
1076 ValidatePointer(m_Weight, descriptorName, "weight");
1077 weightTensorInfo = m_Weight->GetTensorInfo();
1078 }
1079 else
1080 {
1081 weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1082 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001083 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001084
1085 if (m_Parameters.m_BiasEnabled)
1086 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00001087 TensorInfo biasTensorInfo;
1088 if (m_Parameters.m_ConstantWeights)
1089 {
1090 ValidatePointer(m_Bias, descriptorName, "bias");
1091 biasTensorInfo = m_Bias->GetTensorInfo();
1092 }
1093 else
1094 {
1095 biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
1096 }
telsoa01c577f2c2018-08-31 09:22:23 +01001097 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001098 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1100 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001101 }
1102
Francis Murtagh46c09d02019-05-28 08:15:28 +01001103 // Check the supported data types
1104 std::vector<DataType> supportedTypes =
1105 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001106 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001107 DataType::Float32,
1108 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001109 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001110 DataType::QAsymmU8,
1111 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001112 };
1113
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001114 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001115
1116 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1117 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1118 {
1119 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1120 {
1121 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1122 "for BFloat16 input.");
1123 }
1124 }
1125 else
1126 {
1127 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1128 }
telsoa014fcda012018-03-09 14:13:49 +00001129}
1130
telsoa014fcda012018-03-09 14:13:49 +00001131void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1132{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001133 const std::string descriptorName{"NormalizationQueueDescriptor"};
1134
1135 ValidateNumInputs(workloadInfo, descriptorName, 1);
1136 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1137
1138 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1139 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001140
1141 // Check the supported data types
1142 std::vector<DataType> supportedTypes =
1143 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001144 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001145 DataType::Float16,
1146 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001147 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001148 DataType::QAsymmU8,
1149 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001150 };
1151
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001152 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001155
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001156 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001157}
1158
1159void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1160{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001161 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001163 ValidateNumInputs(workloadInfo, descriptorName, 2);
1164 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1165
1166 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1167 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1168 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1169
1170 std::vector<DataType> supportedTypes =
1171 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001172 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001173 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001174 DataType::Float16,
1175 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001176 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001177 DataType::QSymmS16,
1178 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001179 };
1180
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001181 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1182 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1183 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001184
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001185 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1186 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001187
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001188 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1189 inputTensorInfo1,
1190 outputTensorInfo,
1191 descriptorName,
1192 "input_0",
1193 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001194}
1195
telsoa014fcda012018-03-09 14:13:49 +00001196void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1197{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001198 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001199
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001200 ValidateNumInputs(workloadInfo, descriptorName, 2);
1201 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1202
1203 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1204 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1205 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1206
1207 std::vector<DataType> supportedTypes =
1208 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001209 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001210 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001211 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001212 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001213 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001214 DataType::QSymmS16,
1215 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001216 };
1217
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001218 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1219 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1220 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001222 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1223 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1226 inputTensorInfo1,
1227 outputTensorInfo,
1228 descriptorName,
1229 "input_0",
1230 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001231}
1232
1233void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1234{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001235 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001236
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001237 ValidateNumInputs(workloadInfo, descriptorName, 1);
1238 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1239
1240 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1241 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001242
1243 std::vector<DataType> supportedTypes =
1244 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001245 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001246 DataType::Float16,
1247 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001248 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001249 DataType::QAsymmU8,
1250 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001251 };
1252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1254 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001255
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001256 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001258
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001259 ValidatePointer(m_Mean, descriptorName, "mean");
1260 ValidatePointer(m_Variance, descriptorName, "variance");
1261 ValidatePointer(m_Beta, descriptorName, "beta");
1262 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001263
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001264 const TensorInfo& mean = m_Mean->GetTensorInfo();
1265 const TensorInfo& variance = m_Variance->GetTensorInfo();
1266 const TensorInfo& beta = m_Beta->GetTensorInfo();
1267 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001268
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001269 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1270 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1271 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1272 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001273
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1275 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1276 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001277}
1278
1279void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1280{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001281 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001282
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001283 ValidateNumInputs(workloadInfo, descriptorName, 1);
1284 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001286 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1287 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001288
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001289 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1290 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001291
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001292 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001294 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1295 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001296
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001297 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001298
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001299 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001300 if (m_Parameters.m_BiasEnabled)
1301 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001303
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001304 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1305 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001306
1307 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1308 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001309 }
1310
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001311 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1312 {
1313 throw InvalidArgumentException(
1314 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1315 "cannot be either negative or 0.",
1316 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1317 }
1318
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001319 ValidatePerAxisQuantization(inputTensorInfo,
1320 outputTensorInfo,
1321 weightTensorInfo,
1322 optionalBiasTensorInfo,
1323 descriptorName);
1324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001325 std::vector<DataType> supportedTypes =
1326 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001327 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001328 DataType::Float16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001329 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001330 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001331 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001332 DataType::QSymmS16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001333 DataType::QSymmS8
Ruomei Yan88d44b82019-05-23 14:29:06 +01001334 };
1335
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001336 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001337
1338 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1339 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1340 {
1341 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1342 {
1343 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1344 "for BFloat16 input.");
1345 }
1346 }
1347 else
1348 {
1349 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1350 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001351}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001352
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001353void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1354{
1355 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1356
1357 ValidateNumInputs(workloadInfo, descriptorName, 1);
1358 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1359
1360 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1361 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1362
1363 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1364 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1365
1366 ValidatePointer(m_Weight, descriptorName, "weight");
1367
1368 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1369 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1370
1371 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1372 {
1373 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001374 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1375 "cannot be smaller than 1.",
1376 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001377 }
1378
Teresa Charlinf2ed1b82020-11-24 15:11:54 +00001379 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1380 {
1381 throw InvalidArgumentException(
1382 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1383 "cannot be either negative or 0.",
1384 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1385 }
1386
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001387 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1388
1389 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1390 // inputChannels * channelMultiplier should be equal to outputChannels.
1391 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1392 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1393 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1394 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1395 {
James Ward47fce872020-09-10 11:57:28 +01001396 throw InvalidArgumentException(fmt::format(
1397 "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) "
1398 "multiplied by channel_multiplier (provided {3}).",
1399 descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier));
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001400 }
1401
Teresa Charlind8df0262019-11-11 12:28:15 +00001402 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403
Teresa Charlind8df0262019-11-11 12:28:15 +00001404 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001405 if (m_Parameters.m_BiasEnabled)
1406 {
1407 ValidatePointer(m_Bias, descriptorName, "bias");
1408
Teresa Charlind8df0262019-11-11 12:28:15 +00001409 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1410 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001411
1412 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1413 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1414 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001415 ValidatePerAxisQuantization(inputTensorInfo,
1416 outputTensorInfo,
1417 weightTensorInfo,
1418 optionalBiasTensorInfo,
1419 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001420
1421 std::vector<DataType> supportedTypes =
1422 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001423 DataType::BFloat16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001424 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001425 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001426 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001427 DataType::QAsymmU8,
1428 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001429 };
1430
1431 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1432 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001433}
1434
1435void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1436{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 const std::string descriptorName{"PermuteQueueDescriptor"};
1438
1439 ValidateNumInputs(workloadInfo, descriptorName, 1);
1440 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001441
1442 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1443
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001444 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1445 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001446
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001447 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1448 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001449
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001450 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001451 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001452 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001453 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001454 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1455 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1456 "must match dst dimension " + to_string(mapping[i]) +
1457 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001458 }
1459 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001460
1461 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001462}
1463
1464void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1465{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001466 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001467
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001468 ValidateNumInputs(workloadInfo, descriptorName, 1);
1469 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1470
1471 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1472 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1473
1474 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1475 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001476
1477 std::vector<DataType> supportedTypes =
1478 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001479 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001480 DataType::Float32,
1481 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001482 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001483 DataType::QAsymmU8,
1484 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001485 };
1486
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001487 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1488 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001489}
1490
1491void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1492{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001493 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001495 ValidateNumInputs(workloadInfo, descriptorName, 1);
1496 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1497
1498 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1499 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1500
1501 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1502 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001503
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001504 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001505 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001506 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001507 DataType::Float16,
1508 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001509 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001510 DataType::QAsymmU8,
1511 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001512 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001513
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001514 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1515 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001516
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001517 // ResizeBilinear only changes width and height: batch and channel count must match.
1518 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1519 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001520 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001521 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001522 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001523 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1524 descriptorName, inputBatchSize, outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001525 }
1526
Teresa Charlin970f43b2019-07-01 13:51:07 +01001527 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001528 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1529 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001530 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001531 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001532 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001533 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1534 descriptorName, inputChannelCount, outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001535 }
1536}
1537
1538void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1539{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001540 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001541
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001542 ValidateNumInputs(workloadInfo, descriptorName, 1);
1543 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1544
1545 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1546 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1547
1548 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1549 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001550
1551 std::vector<DataType> supportedTypes =
1552 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001553 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001554 DataType::Float16,
1555 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001556 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001557 DataType::QAsymmU8,
1558 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001559 };
1560
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001561 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1562 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001563
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001564 // Resize only changes width and height: batch and channel count must match.
1565 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1566 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001567 if (inputBatchSize != outputBatchSize)
1568 {
1569 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001570 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1571 descriptorName, inputBatchSize, outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001572 }
1573
1574 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001575 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1576 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001577 if (inputChannelCount != outputChannelCount)
1578 {
1579 throw InvalidArgumentException(
James Ward47fce872020-09-10 11:57:28 +01001580 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1581 descriptorName, inputChannelCount, outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001582 }
1583}
1584
1585void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1586{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001587 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001588
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001589 ValidateNumInputs(workloadInfo, descriptorName, 1);
1590 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1591
1592 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1593 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1594
1595 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1596 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1597
1598 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1599
telsoa014fcda012018-03-09 14:13:49 +00001600 if (m_Parameters.m_Min > m_Parameters.m_Max)
1601 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001602 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001603 }
telsoa014fcda012018-03-09 14:13:49 +00001604}
1605
Kevin Mayce5045a2019-10-02 14:07:47 +01001606void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1607{
1608 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1609
1610 ValidateNumInputs(workloadInfo, descriptorName, 1);
1611 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1612
1613 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1614 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1615
1616 if (inputTensorInfo.GetNumDimensions() > 4)
1617 {
1618 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1619 }
1620
1621 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1622
1623 // Check the supported data types
1624 std::vector<DataType> supportedTypes =
1625 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001626 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001627 DataType::Float32,
1628 DataType::Float16
1629 };
1630
1631 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001632 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001633}
1634
telsoa014fcda012018-03-09 14:13:49 +00001635void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1636{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001637 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001640 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1641
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001642 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1643 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1644
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001645 if (inputTensorInfo.GetNumDimensions() > 4)
1646 {
1647 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1648 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001649
1650 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001651
1652 // Check the supported data types
1653 std::vector<DataType> supportedTypes =
1654 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001655 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001656 DataType::Float32,
1657 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001658 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001659 DataType::QAsymmU8,
1660 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001661 };
1662
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001663 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001664 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1665}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001666
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001667void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1668{
1669 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1670
1671 ValidateNumInputs(workloadInfo, descriptorName, 1);
1672 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1673
1674 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1675 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1676
1677 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1678
1679 std::vector<DataType> supportedTypes =
1680 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001681 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001682 DataType::Float32,
1683 DataType::Float16,
1684 };
1685
1686 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001687 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001688}
1689
1690void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1691{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001692 const std::string descriptorName{"ConstantQueueDescriptor"};
1693
1694 ValidateNumInputs(workloadInfo, descriptorName, 0);
1695 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001696
1697 if (!m_LayerOutput)
1698 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001699 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001700 }
1701
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1703 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001704
1705 // Check the supported data types
1706 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001707 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001708 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001709 DataType::Float32,
1710 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001711 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001712 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001713 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001714 DataType::QSymmS16,
1715 DataType::Signed32
Nina Drozd2f2778f2019-05-27 10:37:05 +01001716 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001717
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001718 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001719}
1720
1721void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1722{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001723 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001724
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001725 ValidateNumInputs(workloadInfo, descriptorName, 1);
1726 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1727
1728 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1729 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1730
1731 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001732
1733 // Check the supported data types
1734 std::vector<DataType> supportedTypes =
1735 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001736 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001737 DataType::Float32,
1738 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001739 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001740 DataType::QAsymmU8,
1741 DataType::QSymmS16,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001742 DataType::Signed32,
1743 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001744 };
1745
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001746 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1747 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001748}
1749
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001750void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1751{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001752 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001753
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001754 ValidateNumInputs(workloadInfo, descriptorName, 1);
1755 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1756
1757 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1758 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1759
1760 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1761 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001762
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001763 if (m_Parameters.m_BlockShape.size() != 2)
1764 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001765 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001766 }
1767
1768 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1769 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001770 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1771 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001772 }
1773
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001775
1776 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001777 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001778
Matthew Bentham8800c002018-11-19 13:19:28 +00001779 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001780
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001781 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1782 widthPad.first + widthPad.second;
1783 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1784 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001785
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001786 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1787 inputShape[dimensionIndices.GetChannelsIndex()];
1788 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001790 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001791 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001792 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001793 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001794 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001795 }
1796
1797 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001799 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1800 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001801 }
nikraj01120522a2019-05-31 11:33:07 +01001802
1803 std::vector<DataType> supportedTypes =
1804 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001805 DataType::BFloat16,
1806 DataType::Float16,
1807 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001808 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001809 DataType::QAsymmU8,
1810 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001811 };
1812
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001813 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1814 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001815}
1816
Keith Davisa57eccb2019-06-14 17:33:22 +01001817void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1818{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001819 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001820
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001821 ValidateNumInputs(workloadInfo, descriptorName, 1);
1822 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001824 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1825 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1826
1827 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1828 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001829
1830 std::vector<DataType> supportedTypes =
1831 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001832 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001833 DataType::Float32,
1834 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001835 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001836 DataType::QAsymmU8,
1837 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001838 };
1839
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001840 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1841 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001842
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001843 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1844
1845 if (m_Parameters.m_BlockSize == 0)
1846 {
1847 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1848 }
1849
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001850 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1851 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1852 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1853 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001854
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001855 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001856 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001857 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001858 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1859 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001860 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001861
1862 const TensorShape& outputShape = outputTensorInfo.GetShape();
1863 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1864 {
1865 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1866 "must be divisible by the square of block size." );
1867 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001868}
1869
telsoa014fcda012018-03-09 14:13:49 +00001870void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1871{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001872 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001873
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001874 ValidateNumInputs(workloadInfo, descriptorName, 1);
1875 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1876
1877 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1878 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001879
1880 std::vector<DataType> supportedTypes =
1881 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001882 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001883 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001884 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001885 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001886 };
1887
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001888 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001889
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001890 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001891 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001892 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001893 }
1894}
1895
telsoa01c577f2c2018-08-31 09:22:23 +01001896void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1897{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001898 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1899
1900 const std::string descriptorName{"LstmQueueDescriptor"};
1901
1902 // check dimensions of all inputs and outputs
1903 if (workloadInfo.m_InputTensorInfos.size() != 3)
1904 {
1905 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1906 }
1907 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1908 {
1909 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1910 }
1911
1912 std::vector<DataType> supportedTypes =
1913 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001914 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001915 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001916 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001917 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001918 };
1919
Jan Eilers38e05bd2019-06-26 13:10:09 +01001920 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001921 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1922
Jan Eilers38e05bd2019-06-26 13:10:09 +01001923 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001924 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001925 {
1926 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1927 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001928 descriptorName,
1929 "input_0",
1930 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001931 }
1932 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001933 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001934 {
1935 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1936 workloadInfo.m_OutputTensorInfos[i],
1937 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001938 "input_0",
1939 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001940 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001941
janeil0117d8d852019-11-15 15:00:16 +00001942 // Making sure clipping parameters have valid values.
1943 // == 0 means no clipping
1944 // > 0 means clipping
1945 if (m_Parameters.m_ClippingThresCell < 0.0f)
1946 {
1947 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1948 }
1949 if (m_Parameters.m_ClippingThresProj < 0.0f)
1950 {
1951 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1952 }
1953
Jan Eilers38e05bd2019-06-26 13:10:09 +01001954
1955 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001956 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1957 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1958 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1959 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1960 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1961 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1962
Jan Eilers38e05bd2019-06-26 13:10:09 +01001963 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001964 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1965 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001966 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001967 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1968 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001969 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001970 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1971 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001972 // scratchBufferTensor
1973 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1975 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001976 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001977 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1978 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001979 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001980 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1981 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001982 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001983 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1984 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001985
1986
1987 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1988 if ( m_InputToInputWeights )
1989 {
1990 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1991 (n_cell * n_input), "InputLayerNormWeights");
1992 }
1993
1994 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1995 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1996 (n_cell * n_input), "InputToForgetWeights");
1997
1998 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1999 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2000 (n_cell * n_input), "InputToCellWeights");
2001
2002 if ( m_RecurrentToInputWeights )
2003 {
2004 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2005 (n_cell * n_output), "RecurrentToInputWeights");
2006 }
2007
2008 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2009 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2010 (n_cell * n_output), "RecurrentToForgetWeights");
2011
2012 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2013 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2014 (n_cell * n_output), "RecurrentToCellWeights");
2015
2016 // Make sure the input-gate's parameters are either both present (regular
2017 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2018 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2019 !m_Parameters.m_CifgEnabled) ||
2020 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2021 m_Parameters.m_CifgEnabled));
2022 if (!cifg_weights_all_or_none)
2023 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002024 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2025 "RecurrentToInputWeights must either both be present (regular LSTM) "
2026 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2027 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002028 }
2029
2030 if ( m_CellToInputWeights )
2031 {
2032 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2033 n_cell, "CellToInputWeights");
2034 }
2035 if ( m_CellToForgetWeights )
2036 {
2037 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2038 n_cell, "CellToForgetWeights");
2039 }
2040 if ( m_CellToOutputWeights )
2041 {
2042 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2043 n_cell, "CellToOutputWeights");
2044 }
2045
2046 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2047 bool peephole_weights_all_or_none =
2048 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2049 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2050 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2051 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2052 if (!peephole_weights_all_or_none)
2053 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002054 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002055 }
2056
2057 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2058 if (m_Parameters.m_CifgEnabled)
2059 {
2060 if (m_InputGateBias)
2061 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002062 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002063 }
2064 }
2065 else
2066 {
2067 if (!m_InputGateBias)
2068 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002069 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2070 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002071 }
2072 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2073 n_cell, "InputGateBias");
2074 }
2075
2076 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2077 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2078
2079 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2080 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2081
2082 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2083 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2084
2085 if (m_ProjectionWeights)
2086 {
2087 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2088 (n_cell * n_output), "ProjectionWeights");
2089 }
2090 if (m_ProjectionBias)
2091 {
2092 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2093 }
2094
2095 // Making sure the projection tensors are consistent:
2096 // 1) If projection weight is not present, then projection bias should not be
2097 // present.
2098 // 2) If projection weight is present, then projection bias is optional.
2099 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2100 !m_Parameters.m_ProjectionEnabled)
2101 || (m_ProjectionWeights && !m_ProjectionBias &&
2102 m_Parameters.m_ProjectionEnabled)
2103 || (m_ProjectionWeights && m_ProjectionBias &&
2104 m_Parameters.m_ProjectionEnabled));
2105 if (!projecton_tensors_consistent)
2106 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002107 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002108 }
2109
2110 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2111 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2112 // either all have values or none of them have values. Layer normalization is used when the values of all the
2113 // layer normalization weights are present
2114 if (m_InputLayerNormWeights)
2115 {
2116 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2117 }
2118 if (m_ForgetLayerNormWeights)
2119 {
2120 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2121 }
2122 if (m_CellLayerNormWeights)
2123 {
2124 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2125 }
2126 if (m_OutputLayerNormWeights)
2127 {
2128 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2129 }
2130
Jan Eilers38e05bd2019-06-26 13:10:09 +01002131 if (m_Parameters.m_LayerNormEnabled)
2132 {
2133 if (!m_Parameters.m_CifgEnabled)
2134 {
2135 if (!m_InputLayerNormWeights)
2136 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002137 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2138 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002139 }
2140 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2141 1, n_cell, "InputLayerNormWeights");
2142 }
2143 else if (m_InputLayerNormWeights)
2144 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002145 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2146 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002147 }
2148
2149 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2150 "ForgetLayerNormWeights");
2151 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2152
2153 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2154 "OutputLayerNormWeights");
2155 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2156
2157 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2158 "CellLayerNormWeights");
2159 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2160 }
2161 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2162 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002163 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2164 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002165 }
telsoa01c577f2c2018-08-31 09:22:23 +01002166}
2167
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002168void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2169{
2170 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2171
2172 ValidateNumInputs(workloadInfo, descriptorName, 1);
2173 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2174
2175 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2176 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2177
2178 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2179 {
2180 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2181 }
2182
2183 if (outputTensorInfo.GetDataType() != DataType::Float32)
2184 {
2185 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2186 }
2187
2188 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2189}
2190
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002191void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2192{
2193 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2194
2195 ValidateNumInputs(workloadInfo, descriptorName, 1);
2196 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2197
2198 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2199 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2200
2201 if (inputTensorInfo.GetDataType() != DataType::Float32)
2202 {
2203 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2204 }
2205
2206 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2207 {
2208 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2209 }
2210
2211 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2212}
2213
telsoa01c577f2c2018-08-31 09:22:23 +01002214void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2215{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002216 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002217
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002218 ValidateNumInputs(workloadInfo, descriptorName, 1);
2219 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2220
2221 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2222 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2223
2224 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002225 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002226 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002227 }
2228
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002229 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002230 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002231 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002232 }
2233
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002234 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002235}
2236
2237void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2238{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002239 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002240
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002241 ValidateNumInputs(workloadInfo, descriptorName, 1);
2242 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2243
2244 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2245 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2246
2247 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002248 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002249 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002250 }
2251
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002252 if (outputTensorInfo.GetDataType() != DataType::Float32)
2253 {
2254 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2255 }
2256
2257 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002258}
2259
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002260void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2261{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002262 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002263
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002264 ValidateNumInputs(workloadInfo, descriptorName, 2);
2265 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2266
2267 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2268 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2269 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2270
2271 std::vector<DataType> supportedTypes =
2272 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002273 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002274 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002275 DataType::Float32,
2276 DataType::QAsymmS8,
2277 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002278 DataType::QSymmS16,
2279 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002280 };
2281
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002282 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2283 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2284 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002286 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2287 inputTensorInfo1,
2288 outputTensorInfo,
2289 descriptorName,
2290 "input_0",
2291 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002292}
2293
David Beckc2044fe2018-09-05 15:00:38 +01002294void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2295{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002296 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002297
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 ValidateNumInputs(workloadInfo, descriptorName, 2);
2299 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2300
2301 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2302 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2303 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2304
2305 std::vector<DataType> supportedTypes =
2306 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002307 DataType::BFloat16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002308 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002309 DataType::Float32,
2310 DataType::QAsymmS8,
2311 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002312 DataType::QSymmS16,
2313 DataType::Signed32,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002314 };
2315
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002316 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2317 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2318 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002320 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2321 inputTensorInfo1,
2322 outputTensorInfo,
2323 descriptorName,
2324 "input_0",
2325 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002326}
2327
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002328void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2329{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002331
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002332 ValidateNumInputs(workloadInfo, descriptorName, 2);
2333 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2334
2335 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2336 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2337 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2338
2339 std::vector<DataType> supportedTypes =
2340 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002341 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002342 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002343 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00002344 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002345 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002346 DataType::QSymmS16,
2347 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002348 };
2349
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002350 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2351 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2352 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002353
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002354 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2355 inputTensorInfo1,
2356 outputTensorInfo,
2357 descriptorName,
2358 "input_0",
2359 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002360}
2361
narpra01a6bf9122018-09-10 09:50:09 +01002362void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2363{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002365
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002366 ValidateNumInputs(workloadInfo, descriptorName, 1);
2367 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2368
2369 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2370 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002371
2372 std::vector<DataType> supportedTypes =
2373 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002374 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002375 DataType::Float32,
2376 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002377 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002378 DataType::QAsymmU8,
2379 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002380 };
narpra01eb061912018-09-10 17:35:27 +01002381
James Conroy4d1ff582019-06-10 17:06:39 +01002382 // First check if input tensor data type is supported, then
2383 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002384 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2385 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002386
narpra0132b90462018-09-13 11:07:48 +01002387 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002388 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002389 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002390 }
narpra0132b90462018-09-13 11:07:48 +01002391 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002392 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002394 }
2395 else
2396 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002397 unsigned int outputDim =
Matthew Sloyan171214c2020-09-09 09:07:37 +01002398 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002399 ValidateTensorNumDimensions(outputTensorInfo,
2400 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002401 outputDim > 0 ? outputDim : 1,
2402 "output");
2403 }
narpra01a6bf9122018-09-10 09:50:09 +01002404}
2405
jimfly012c9322a2018-09-19 10:59:49 +01002406void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2407{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002408 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002409
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002410 ValidateNumInputs(workloadInfo, descriptorName, 1);
2411 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2412
2413 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2414 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002415
jimfly012c9322a2018-09-19 10:59:49 +01002416 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2418
jimfly012c9322a2018-09-19 10:59:49 +01002419 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2421 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2422 "as there are dimensions in the input tensor that is " +
2423 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2424 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002425 }
2426}
2427
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002428void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2429{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002430 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002431
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002432 ValidateNumInputs(workloadInfo, descriptorName, 1);
2433 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002434
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002435 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2436 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2437
Sadik Armagan2208b602019-07-31 16:36:27 +01002438 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002439 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002440 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002441 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002442 DataType::Float16,
2443 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002444 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002445 DataType::QAsymmU8,
2446 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002447 };
2448
2449 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002450
Keith Davis0c2eeac2020-02-11 16:51:50 +00002451 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002452 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002453 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002454 }
2455}
2456
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002457void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2458{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002459 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002460
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002461 ValidateNumInputs(workloadInfo, descriptorName, 1);
2462 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002463
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002464 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2465 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002466
2467 std::vector<DataType> supportedTypes =
2468 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002469 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002470 DataType::Float32,
2471 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002472 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002473 DataType::QAsymmU8,
2474 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002475 };
2476
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002477 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2478 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002479}
2480
Conor Kennedy430b5d82018-11-14 15:28:28 +00002481void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2482{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002483 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002484
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002485 ValidateNumInputs(workloadInfo, descriptorName, 1);
2486 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2487
2488 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2489 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002490
2491 std::vector<DataType> supportedTypes =
2492 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002493 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002494 DataType::Float16,
2495 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002496 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002497 DataType::QAsymmU8,
2498 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002499 };
2500
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002501 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2502 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002503
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002504 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002505
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002506 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002507 if (rank > 4)
2508 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002509 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002510 }
2511
Conor Kennedy430b5d82018-11-14 15:28:28 +00002512 // Begin, End & Stride length must be of rank(input0)
2513 if (m_Parameters.m_Begin.size() != rank)
2514 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002515 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002516 }
2517
2518 if (m_Parameters.m_End.size() != rank)
2519 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002520 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002521 }
2522
2523 if (m_Parameters.m_Stride.size() != rank)
2524 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002525 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002526 }
2527
2528 // Stride entries must be non-zero
2529 for (auto& stride : m_Parameters.m_Stride)
2530 {
2531 if (stride == 0)
2532 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002533 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002534 }
2535 }
2536}
2537
kevmay0190539692018-11-29 08:40:19 +00002538void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2539{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002541
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002542 ValidateNumInputs(workloadInfo, descriptorName, 2);
2543 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2544
2545 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2546 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2547 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2548
2549 std::vector<DataType> supportedTypes =
2550 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002551 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002552 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002553 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002554 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002555 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002556 DataType::QSymmS16,
2557 DataType::Signed32
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002558 };
2559
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002560 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2561 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2562 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002563
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002564 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2565 inputTensorInfo1,
2566 outputTensorInfo,
2567 descriptorName,
2568 "input_0",
2569 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002570}
2571
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002572void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2573{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002574 const std::string descriptorName{"DebugQueueDescriptor"};
2575
2576 ValidateNumInputs(workloadInfo, descriptorName, 1);
2577 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002578}
2579
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002580void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2581{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002582 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002583
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002584 ValidateNumInputs(workloadInfo, descriptorName, 2);
2585 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002586
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002587 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2588 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2589 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2590
2591 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2592 inputTensorInfo1,
2593 outputTensorInfo,
2594 descriptorName,
2595 "input_0",
2596 "input_1");
2597
2598 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002599 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002600 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002601 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002602}
2603
FrancisMurtagh878f0232018-12-19 10:56:15 +00002604void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2605{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002606 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002608 ValidateNumInputs(workloadInfo, descriptorName, 2);
2609 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002611 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2612 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2613 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2614
2615 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2616 inputTensorInfo1,
2617 outputTensorInfo,
2618 descriptorName,
2619 "input_0",
2620 "input_1");
2621
2622 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002623 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002624 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002625 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002626}
2627
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002628void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2629{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002630 const std::string descriptorName{"RsqrtQueueDescriptor"};
2631
2632 ValidateNumInputs(workloadInfo, descriptorName, 1);
2633 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2634
2635 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2636 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2637
2638 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002639
2640 std::vector<DataType> supportedTypes =
2641 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002642 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002643 DataType::Float16,
2644 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002645 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002646 DataType::QAsymmU8,
2647 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002648 };
2649
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002650 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2651 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002652}
2653
narpra01b89b05f2019-01-16 09:53:09 +00002654void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2655{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002656 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002657
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002658 ValidateNumInputs(workloadInfo, descriptorName, 2);
2659 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002660
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002661 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2662 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002663 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002664 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002665 }
2666
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002667 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2668 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2669
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002670 std::vector<DataType> supportedTypes =
2671 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002672 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002673 DataType::Float16,
2674 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002675 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002676 DataType::QAsymmU8,
Teresa Charlin93492462020-05-29 13:08:59 +01002677 DataType::QSymmS16,
2678 DataType::Signed32,
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002679 };
2680
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002681 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002682
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002683 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002684
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002685 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2686 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002687}
2688
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002689void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2690{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002691 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2692
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002693 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002694
2695 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2696 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002697 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002698 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2699 }
2700
2701 if (m_Anchors == nullptr)
2702 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002703 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002704 }
2705
2706 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002707 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2708 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2709
2710 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002711 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002712 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2713 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002714
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002715 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2716 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2717 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002718
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002719 const std::vector<DataType> supportedInputTypes =
2720 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002721 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002722 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002723 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002724 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002725 DataType::QAsymmU8,
2726 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002727 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002728
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002729 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2730 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2731 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2732
2733 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2734 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2735 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2736 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2737
2738 // NOTE: Output is always Float32 regardless of input type
2739 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2740 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2741 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2742 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002743
2744 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2745 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002746 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002747 "must be positive and less than or equal to 1.");
2748 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002749
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002750 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2751 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002752 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002753 "should be equal to number of classes + 1.");
2754 }
2755}
2756
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002757void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2758{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002759 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002760
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002761 ValidateNumInputs(workloadInfo, descriptorName, 1);
2762 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2763
2764 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2765 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2766
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002767 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002769 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002770 }
2771
Sadik Armagan2208b602019-07-31 16:36:27 +01002772 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002773 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002774 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002775 DataType::Float32,
2776 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002777 };
2778
2779 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002780}
2781
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002782void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2783{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002784 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002785
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002786 ValidateNumInputs(workloadInfo, descriptorName, 2);
2787 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002788
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002789 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2790 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2791 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002792
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002793 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2794 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2795
2796 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2797 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002798}
2799
Sadik Armaganeff363d2019-04-05 15:25:46 +01002800void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2801{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002802 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002803
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002804 ValidateNumInputs(workloadInfo, descriptorName, 2);
2805 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2806
2807 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2808 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2809
2810 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2811 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2812
2813 std::vector<DataType> supportedTypes =
2814 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002815 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002816 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002817 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002818 DataType::QAsymmU8,
2819 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002820 };
2821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002822 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2823 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002824
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002825 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2826 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002827
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002828 ValidateTensorShapesMatch(inputTensorInfo0,
2829 outputTensorInfo0,
2830 descriptorName,
2831 "input_0",
2832 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002833
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002834 ValidateTensorShapesMatch(inputTensorInfo0,
2835 outputTensorInfo1,
2836 descriptorName,
2837 "input_0",
2838 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002839}
2840
Derek Lamberti901ea112019-12-10 22:07:09 +00002841void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002842{
2843 // This is internally generated so it should not need validation.
2844}
2845
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002846void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2847{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002848 const std::string& descriptorName{"PreluQueueDescriptor"};
2849
2850 ValidateNumInputs(workloadInfo, descriptorName, 2);
2851 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2852
2853 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2854 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2855 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002856
2857 std::vector<DataType> supportedTypes
2858 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002859 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002860 DataType::Float16,
2861 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002862 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002863 DataType::QAsymmU8,
2864 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002865 };
2866
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002867 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2868 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002869
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002870 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002871
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002872 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2873 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002874
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002875 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2876 alphaTensorInfo,
2877 outputTensorInfo,
2878 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002879 "input",
2880 "alpha");
2881}
2882
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002883void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2884{
2885 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2886
2887 ValidateNumInputs(workloadInfo, descriptorName, 1);
2888 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2889
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002890 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2891 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2892
2893 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2894 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002895
2896 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002898 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2899 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002900
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002901 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2902
2903 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002904 if (m_Parameters.m_BiasEnabled)
2905 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002906 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002907
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002908 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2909 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002910
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002911 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002912 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002913 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002914
2915 ValidatePerAxisQuantization(inputTensorInfo,
2916 outputTensorInfo,
2917 weightTensorInfo,
2918 optionalBiasTensorInfo,
2919 descriptorName);
2920
2921 std::vector<DataType> supportedTypes =
2922 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002923 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002924 DataType::Float32,
2925 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002926 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002927 DataType::QAsymmU8,
2928 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002929 };
2930
2931 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2932 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002933}
2934
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002935void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2936{
2937 const std::string descriptorName{"TransposeQueueDescriptor"};
2938
2939 ValidateNumInputs(workloadInfo, descriptorName, 1);
2940 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2941
2942 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2943
2944 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2945 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2946
2947 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2948 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2949
2950 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2951 {
2952 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2953 {
2954 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2955 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2956 "must match dst dimension " + to_string(i) +
2957 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2958 }
2959 }
2960
2961 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2962}
2963
James Conroy4f1f8992020-04-29 20:01:10 +01002964void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2965{
2966 const std::string descriptorName{"QLstmQueueDescriptor"};
2967
2968 // Validate number of inputs/outputs
2969 ValidateNumInputs(workloadInfo, descriptorName, 3);
2970 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2971
2972 // Input/output tensor info
2973 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2974 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
2975 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
2976
2977 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2978 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2979 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
2980
2981 // Supported types for various tensors in QLSTM
2982 std::vector<DataType> inputOutputSupportedTypes =
2983 {
2984 DataType::QAsymmS8
2985 };
2986
2987 std::vector<DataType> cellStateSupportedTypes =
2988 {
2989 DataType::QSymmS16
2990 };
2991
2992 std::vector<DataType> weightsSupportedTypes =
2993 {
2994 DataType::QSymmS8
2995 };
2996
2997 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2998 {
2999 DataType::QSymmS16
3000 };
3001
3002 std::vector<DataType> biasSupportedTypes =
3003 {
3004 DataType::Signed32
3005 };
3006
3007 // Validate types of input/output tensors
3008 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3009 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3010 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3011
3012 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3013 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3014 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3015
3016 // Validate matching types of input/output tensors
3017 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3018 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3019 "outputStateIn", "outputStateOut");
3020 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3021
3022 // Infer number of batches, number of units, input size and output size from tensor dimensions
3023 const uint32_t numBatches = inputInfo.GetShape()[0];
3024 const uint32_t inputSize = inputInfo.GetShape()[1];
3025 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3026 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3027
3028 // Validate number of dimensions and number of elements for input/output tensors
3029 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3030 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3031 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3032
3033 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3034 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3035 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3036
3037 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3038 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3039 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3040 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3041
3042 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3043 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3044 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3045
3046 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3047 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3048 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3049
3050 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3051 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3052 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3053 " RecurrentToForgetWeights");
3054
3055 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3056 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3057 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3058
3059 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3060 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3061 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3062
3063 // Validate data types for MANDATORY weights tensors (all should match each other)
3064 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3065
3066 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3067 "inputToForgetWeights", "inputToCellWeights");
3068 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3069 "inputToForgetWeights", "inputToOutputWeights");
3070
3071 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3072 "inputToForgetWeights", "recurrentToForgeteights");
3073 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3074 "inputToForgetWeights", "recurrentToCellWeights");
3075 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3076 "inputToForgetWeights", "recurrentToOutputWeights");
3077
3078 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3079 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3080 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3081 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3082
3083 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3084 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3085 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3086
3087 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3088 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3089 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3090
3091 // Validate data types for MANDATORY bias tensors
3092 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3093
3094 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3095 "forgetGateBias", "cellBias");
3096 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3097 "forgetGateBias", "outputGateBias");
3098
3099 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3100 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3101 !m_Parameters.m_CifgEnabled) ||
3102 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3103 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3104
3105 if (!allCifgParamsPresentOrNot)
3106 {
3107 throw InvalidArgumentException(descriptorName +
3108 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3109 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3110 "set appropriately.");
3111 }
3112
3113 if (!m_Parameters.m_CifgEnabled)
3114 {
3115 // Validate number of dimensions and number of elements
3116 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3117 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3118
3119 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3120 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3121 " RecurrentToInputWeights");
3122
3123 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3124 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3125
3126 // Validate data types
3127 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3128 "inputToForgetWeights", "inputToInputWeights");
3129 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3130 "inputToForgetWeights", "recurrentToInputWeights");
3131 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3132 "forgetGateBias", "inputGateBias");
3133 }
3134
3135 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3136 bool allPeepholeWeightsPresentOrNot =
3137 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3138 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3139 || (!m_CellToInputWeights && !m_CellToForgetWeights
3140 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3141
3142 if (!allPeepholeWeightsPresentOrNot)
3143 {
3144 throw InvalidArgumentException(descriptorName +
3145 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3146 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3147 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3148 "appropriately.");
3149 }
3150
3151 if (m_Parameters.m_PeepholeEnabled)
3152 {
3153 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3154 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3155 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3156
3157 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3158 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3159 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3160 "cellToForgetWeight", "cellToOutputWeights");
3161
3162 if (!m_Parameters.m_CifgEnabled)
3163 {
3164 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3165 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3166 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3167 "cellToForgetWeights", "cellToInputWeights");
3168 }
3169 }
3170
3171 // Validate OPTIONAL params: Layer Norm Weights
3172 bool allLayerNormWeightsPresentOrNot =
3173 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3174 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3175 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3176 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3177
3178 if (!allLayerNormWeightsPresentOrNot)
3179 {
3180 throw InvalidArgumentException(descriptorName +
3181 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3182 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3183 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3184 "only be present when Layer Norm is enabled and CIFG is disabled. "
3185 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3186 }
3187
3188 if (m_Parameters.m_LayerNormEnabled)
3189 {
3190 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3191 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3192 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3193
3194 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3195 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3196 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3197 "forgetLayerNormWeights", "cellLayerNormWeights");
3198
3199 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3200 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3201 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3202 "forgetLayerNormWeights", "outputLayerNormWeights");
3203
3204 if (!m_Parameters.m_CifgEnabled)
3205 {
3206 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3207 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3208 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3209 "forgetLayerNormWeights", "inputLayerNormWeights");
3210 }
3211 }
3212
3213 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3214 bool correctProjectionTensorsPresent =
3215 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3216 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3217 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3218
3219 if (!correctProjectionTensorsPresent)
3220 {
3221 throw InvalidArgumentException(descriptorName +
3222 ": If projection is enabled, ProjectionWeights should be present and "
3223 "ProjectionBias is optional. If projection is disabled, neither "
3224 "ProjectionWeights nor ProjectionBias should be present.");
3225 }
3226
3227 if (m_Parameters.m_ProjectionEnabled)
3228 {
3229 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3230 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3231 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3232
3233 if (m_ProjectionBias)
3234 {
3235 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
Sadik Armagand6f06492020-05-22 08:36:33 +01003236 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
James Conroy4f1f8992020-04-29 20:01:10 +01003237 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3238 }
3239
3240 }
3241 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3242 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3243 throw InvalidArgumentException(descriptorName +
3244 ": If projection is disabled, output quantization info (scale, offset) "
3245 "should match HiddenStateScale and HiddenStateZeroPoint.");
3246 }
3247
3248}
3249
James Conroy9c3cae82019-08-01 16:01:48 +01003250void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3251{
3252 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3253
3254 // Validate number of inputs/outputs
3255 ValidateNumInputs(workloadInfo, descriptorName, 3);
3256 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3257
3258 // Input/output tensor infos
3259 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3260 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3261 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3262
3263 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3264 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3265
3266 std::vector<DataType> inputOutputSupportedTypes =
3267 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003268 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003269 };
3270
3271 std::vector<DataType> cellStateSupportedTypes =
3272 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003273 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01003274 };
3275
3276 std::vector<DataType> weightsSupportedTypes =
3277 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00003278 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01003279 };
3280
3281 std::vector<DataType> biasSupportedTypes =
3282 {
3283 DataType::Signed32
3284 };
3285
3286 // Validate types of input/output tensors
3287 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3288 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3289 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3290
3291 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3292 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3293
3294 // Validate matching types of input/output tensors
3295 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3296 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3297 "outputStateIn", "outputStateOut");
3298 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3299
3300 // Validate matching quantization info for input/output tensors
3301 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3302 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3303 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003304
James Conroy9c3cae82019-08-01 16:01:48 +01003305 // Infer number of batches, input size and output size from tensor dimensions
3306 const uint32_t numBatches = inputInfo.GetShape()[0];
3307 const uint32_t inputSize = inputInfo.GetShape()[1];
3308 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3309
3310 // Validate number of dimensions and number of elements for input/output tensors
3311 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3312 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3313 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3314 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3315 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3316
3317 // Validate number of dimensions and number of elements for weights tensors
3318 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3319 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3320 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3321
3322 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3323 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3324 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3325
3326 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3327 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3328 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3329
3330 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3331 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3332 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3333
3334 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3335 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3336 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3337
3338 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3339 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3340 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3341 " RecurrentToForgetWeights");
3342
3343 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3344 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3345 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3346
3347 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3348 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3349 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3350
3351 // Validate data types for weights tensors (all should match each other)
3352 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3353
3354 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3355 "inputToInputWeights", "inputToForgetWeights");
3356 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3357 "inputToInputWeights", "inputToCellWeights");
3358 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3359 "inputToInputWeights", "inputToOutputWeights");
3360
3361 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3362 "inputToInputWeights", "recurrentToInputWeights");
3363 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3364 "inputToInputWeights", "recurrentToForgeteights");
3365 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3366 "inputToInputWeights", "recurrentToCellWeights");
3367 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3368 "inputToInputWeights", "recurrentToOutputWeights");
3369
3370 // Validate matching quantization info for weight tensors (all should match each other)
3371 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3372 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3373 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3374 descriptorName, "inputToInputWeights", "inputToCellWeights");
3375 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3376 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3377
3378 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3379 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3380 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3381 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3382 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3383 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3384 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3385 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3386
3387 // Validate number of dimensions and number of elements in bias tensors
3388 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3389 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3390 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3391
3392 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3393 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3394 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3395
3396 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3397 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3398 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3399
3400 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3401 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3402 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3403
3404 // Validate data types for bias tensors (all should match each other)
3405 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3406
3407 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3408 "inputGateBias", "forgetGateBias");
3409 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3410 "inputGateBias", "cellBias");
3411 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3412 "inputGateBias", "outputGateBias");
3413
3414 // Validate bias tensor quantization info
3415 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3416 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3417 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3418 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3419}
3420
Kevin May868eb142019-09-04 17:29:31 +01003421void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3422{
3423 const std::string descriptorName{"AbsQueueDescriptor"};
3424
3425 ValidateNumInputs(workloadInfo, descriptorName, 1);
3426 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3427
3428 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3429 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3430
3431 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3432
3433 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01003434 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003435 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01003436 DataType::Float16,
3437 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003438 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003439 DataType::QAsymmU8,
Kevin Mayec52c3a2020-04-24 09:42:31 +01003440 DataType::QSymmS16,
3441 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +01003442 };
Kevin May868eb142019-09-04 17:29:31 +01003443
3444 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3445 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3446}
3447
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003448void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3449{
3450 const std::string descriptorName{"SliceQueueDescriptor"};
3451
3452 ValidateNumInputs(workloadInfo, descriptorName, 1);
3453 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3454
3455 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3456 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3457
3458 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3459
3460 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3461 if (rank > 4)
3462 {
3463 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3464 }
3465
3466 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3467
3468 // Check if m_Begin and m_Size have the expected length
3469 if (m_Parameters.m_Begin.size() != rank)
3470 {
3471 throw InvalidArgumentException(descriptorName +
3472 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3473 }
3474 if (m_Parameters.m_Size.size() != rank)
3475 {
3476 throw InvalidArgumentException(descriptorName +
3477 ": Length of size descriptor must equal rank " + std::to_string(rank));
3478 }
3479
3480 // Check if the shape of the output tensor matches m_Size
3481 const TensorShape& outputShape = outputTensorInfo.GetShape();
3482 for (unsigned int i = 0u; i < rank; ++i)
3483 {
3484 if (m_Parameters.m_Size[i] != outputShape[i])
3485 {
3486 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3487 }
3488 }
3489
3490 // Check if the sum of begin offset and size in a given dimension
3491 // does not exceed the size of corresponding input
3492 const TensorShape& inputShape = inputTensorInfo.GetShape();
3493 for(unsigned int i = 0u; i < rank; ++i)
3494 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003495 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003496 {
3497 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3498 std::to_string(i) + " exceeds input size.");
3499 }
3500 }
3501}
3502
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003503void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3504{
3505 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3506
3507 ValidateNumInputs(workloadInfo, descriptorName, 1);
3508 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3509
3510 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3511 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3512
3513 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3514 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3515
3516 std::vector<DataType> supportedTypes =
3517 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003518 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003519 DataType::Float32,
3520 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01003521 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003522 DataType::QAsymmU8,
3523 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003524 };
3525
3526 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3527 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3528
3529 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3530
3531 if (m_Parameters.m_BlockSize == 0)
3532 {
3533 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3534 }
3535
3536 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3537 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3538 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3539 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3540
3541 const TensorShape& outputShape = outputInfo.GetShape();
3542 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3543 {
3544 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3545 "must be divisible by block size.");
3546 }
3547
3548 const TensorShape& inputShape = inputInfo.GetShape();
3549 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3550 {
3551 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3552 "must be divisible by the square of block size." );
3553 }
3554}
3555
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003556void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3557{
3558 const std::string descriptorName{"ComparisonQueueDescriptor"};
3559
3560 ValidateNumInputs(workloadInfo, descriptorName, 2);
3561 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3562
3563 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3564 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3565 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3566
3567 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3568 inputTensorInfo1,
3569 outputTensorInfo,
3570 descriptorName,
3571 "input_0",
3572 "input_1");
3573
3574 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3575 {
3576 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3577 }
3578}
3579
josh minor4a3c6102020-01-06 16:40:46 -06003580void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3581{
3582 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3583
3584 ValidateNumInputs(workloadInfo, descriptorName, 1);
3585 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3586
3587 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3588 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3589
3590 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3591
3592 std::vector<DataType> supportedTypes =
3593 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003594 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003595 DataType::Float16,
3596 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01003597 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06003598 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003599 DataType::QSymmS16,
3600 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003601 };
3602
James Conroyaba90cd2020-11-06 16:28:18 +00003603 std::vector<DataType> logicalSupportedTypes =
3604 {
3605 DataType::Boolean
3606 };
3607
3608 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3609 {
3610 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3611 }
3612 else
3613 {
3614 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3615 }
3616
3617
josh minor4a3c6102020-01-06 16:40:46 -06003618 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3619}
3620
Finn Williams2605b232020-06-10 15:53:46 +01003621void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3622{
3623 const std::string descriptorName{"RankQueueDescriptor"};
3624
3625 ValidateNumInputs(workloadInfo, descriptorName, 1);
3626 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3627
3628 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3629 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3630
3631 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3632 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3633
3634 std::vector<DataType> supportedTypes =
3635 {
3636 DataType::BFloat16,
3637 DataType::Float16,
3638 DataType::Float32,
3639 DataType::QAsymmS8,
3640 DataType::QAsymmU8,
3641 DataType::QSymmS8,
3642 DataType::QSymmS16,
3643 DataType::Signed32
3644 };
3645
3646 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3647 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3648}
3649
James Conroyaba90cd2020-11-06 16:28:18 +00003650void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3651{
3652 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3653
3654 ValidateNumInputs(workloadInfo, descriptorName, 2);
3655 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3656
3657 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3658 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3659 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3660
3661 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3662 inputTensorInfo1,
3663 outputTensorInfo,
3664 descriptorName,
3665 "input_0",
3666 "input_1");
3667
3668 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3669 {
3670 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3671 }
3672
3673 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3674 {
3675 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3676 }
3677
3678 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3679 {
3680 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3681 }
3682}
3683
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003684void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3685{
3686 const std::string descriptorName{"ReduceQueueDescriptor"};
3687
3688 ValidateNumInputs(workloadInfo, descriptorName, 1);
3689 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3690
3691 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3692 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3693
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00003694 std::vector<DataType> supportedTypes =
3695 {
3696 DataType::BFloat16,
3697 DataType::Float16,
3698 DataType::Float32,
3699 DataType::QAsymmS8,
3700 DataType::QAsymmU8,
3701 DataType::QSymmS16,
3702 DataType::Signed32
3703 };
3704
3705 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3706 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3707}
3708
mathad01df9a3222021-04-28 11:42:57 +01003709} // namespace armnn