blob: 8bf2b0f988a1f7bd55a402fab0e557c6f224bf82 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000033 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000034 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000035 case DataType::QSymmS8:
36 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000037 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010038 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000039 default:
40 BOOST_ASSERT_MSG(false, "Invalid input data type");
41 return DataType::Float32;
42 }
43}
44
45namespace
46{
47
48//---------------------------------------------------------------
49//android ndk does not support std::to_string function.
// Converts any value that supports operator<< into its textual form via an output string stream.
// Used in place of std::to_string, which the Android NDK does not provide.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream stream;
    stream << value;
    return stream.str();
}
57
58//---------------------------------------------------------------
59void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
60{
61 if (!ptr)
62 {
63 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
64 paramName + " parameter must be set.");
65 }
66}
67
68//---------------------------------------------------------------
69void ValidateTensorShapesMatch(const TensorInfo& first,
70 const TensorInfo& second,
71 std::string const& descName,
72 std::string const& firstName,
73 std::string const& secondName)
74{
75 if (first.GetShape() != second.GetShape())
76 {
77 throw InvalidArgumentException(descName + ": "
78 + firstName + " & " + secondName + " must have identical shapes");
79 }
80}
81
82//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010083void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000084{
Sadik Armaganeff363d2019-04-05 15:25:46 +010085 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000086 {
87 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010088 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000089 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
90 }
91}
92
93//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010094void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000095{
Sadik Armaganeff363d2019-04-05 15:25:46 +010096 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000097 {
98 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010099 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000100 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
101 }
102}
103
104//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100105void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000106 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100107 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000108 std::string const& tensorName)
109{
110 if (tensor.GetNumDimensions() != numDimensions)
111 {
112 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
113 to_string(tensor.GetNumDimensions()) + " dimensions for " +
114 tensorName + " tensor.");
115 }
116}
117
118//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100119void ValidateTensorNumElements(const TensorInfo& tensor,
120 std::string const& descName,
121 unsigned int numElements,
122 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100123{
124 if (tensor.GetNumElements() != numElements)
125 {
126 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100127 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100128 tensorName + " tensor.");
129 }
130}
131
132//---------------------------------------------------------------
133void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100134 unsigned int numDimension,
135 unsigned int numElements,
136 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100137{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100138 const std::string functionName{"ValidateTensorNumDimNumElem"};
139 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
140 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100141}
142
143//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000144void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
145 const std::string& descName, std::string const& tensorName)
146{
147 if (tensor.GetDataType() != dataType)
148 {
149 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
150 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
151 }
152}
153
Derek Lambertid466a542020-01-22 15:37:29 +0000154void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
155{
156 ARMNN_NO_DEPRECATE_WARN_BEGIN
157 if (tensor.GetDataType() != DataType::QSymmS8 &&
158 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
159 {
160 throw InvalidArgumentException(descName +
161 ": Expected data type which supports per-axis quantization scheme but got " +
162 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
163 }
164 ARMNN_NO_DEPRECATE_WARN_END
165}
166
telsoa014fcda012018-03-09 14:13:49 +0000167//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100168void ValidateTensorQuantizationSpace(const TensorInfo& first,
169 const TensorInfo& second,
170 const std::string& descName,
171 std::string const& firstName,
172 std::string const& secondName)
173{
174 if (!first.IsQuantized() ||
175 !second.IsQuantized())
176 {
177 // Not a quantized type, ignore the validation
178 return;
179 }
180
181 DataType firstDataType = first.GetDataType();
182 DataType secondDataType = second.GetDataType();
183
184 if (firstDataType != secondDataType)
185 {
186 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
187 " must be of the same quantized type, " +
188 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
189 secondName + " is " + GetDataTypeName(secondDataType));
190 }
191
192 if (!first.IsTypeSpaceMatch(second))
193 {
194 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
195 " must have the same quantization space, " +
196 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
197 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
198 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
199 " and scale " + to_string(second.GetQuantizationScale()));
200 }
201}
202
203//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100204void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
205 const TensorInfo& inputTensorInfo,
206 const TensorInfo& weightsTensorInfo,
207 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000208{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000209 // Helper lambda function to validate a single bias quantization scale value
210 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
211 {
ricbur013f4d7102019-10-31 16:22:18 +0000212 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000213 if (std::abs(biasScale - expectedScale) > tolerance)
214 {
215 // Print the float values with extra precision to see very small differences
216 std::stringstream msg;
217 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
218 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
219 biasScale;
220 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
221 }
222 };
223
telsoa014fcda012018-03-09 14:13:49 +0000224 if (biasTensor.GetQuantizationOffset() != 0)
225 {
226 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
227 to_string(biasTensor.GetQuantizationOffset()));
228 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000229
230 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000231 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000232 // Validate per-axis quantization scales
233 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
234 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
235
236 if (weightScales.size() != biasScales.size())
237 {
238 std::stringstream msg;
239 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
240 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
241 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
242 }
243
244 for (size_t i = 0ul; i < biasScales.size(); ++i)
245 {
246 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
247 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
248 }
249 }
250 else
251 {
252 // Validate per-tensor quantization scale
253 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
254 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000255 }
256}
257
258//---------------------------------------------------------------
259void ValidateTensors(const std::vector<ITensorHandle*>& vec,
260 unsigned int numExpected,
261 const std::string& descName,
262 const std::string& varName)
263{
264 if (vec.empty() && numExpected > 0)
265 {
266 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
267 }
268
269 for (unsigned int i = 0; i < numExpected; ++i)
270 {
271 if (!vec[i])
272 {
273 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
274 }
275 }
276}
277
278//---------------------------------------------------------------
279void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
280 const TensorInfo& second,
281 const TensorInfo& output,
282 std::string const& descName,
283 std::string const& firstName,
284 std::string const& secondName)
285{
286 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
287 // broadcasted.
288 if (first.GetNumDimensions() != second.GetNumDimensions())
289 {
290 throw InvalidArgumentException(descName + ": Tensors "
291 + firstName + " & " + secondName
292 + " must have the same number of dimensions in order to be broadcasted");
293 }
294 uint32_t numDims = first.GetNumDimensions();
295 std::vector<uint32_t> outputDims(numDims, 0u);
296 for (uint32_t i = 0; i < numDims; i++)
297 {
298 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
299 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
300 if (dimsNotEqual && dimsNotOne)
301 {
302 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
303 }
304 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
305 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100306 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000307 if (broadcastShape != output.GetShape())
308 {
309 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
310 + firstName + " & " + secondName
311 + " does not match the output shape");
312 }
313}
314
315//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100316void ValidateDataTypes(const TensorInfo& info,
317 const std::vector<armnn::DataType>& supportedTypes,
318 std::string const& descName)
319{
320 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
321 if (iterator == supportedTypes.end())
322 {
323 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
324 }
325}
326
James Conroy4d1ff582019-06-10 17:06:39 +0100327//---------------------------------------------------------------
328void ValidateTensorDataTypesMatch(const TensorInfo& first,
329 const TensorInfo& second,
330 std::string const& descName,
331 std::string const& firstName,
332 std::string const& secondName)
333{
334 if (first.GetDataType() != second.GetDataType())
335 {
336 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
337 " must have identical data types.");
338 }
339}
340
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100341//---------------------------------------------------------------
342void ValidateTensorNumElementsMatch(const TensorInfo& first,
343 const TensorInfo& second,
344 std::string const& descName,
345 std::string const& firstName,
346 std::string const& secondName)
347{
348 if (first.GetNumElements() != second.GetNumElements())
349 {
350 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
351 " must have the same number of elements.");
352 }
353}
354
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000355void ValidateWeightDataType(const TensorInfo& inputInfo,
356 const TensorInfo& weightInfo,
357 const std::string& descName)
358{
359 const DataType inputType = inputInfo.GetDataType();
Derek Lambertif90c56d2020-01-10 17:14:08 +0000360 if (inputType == DataType::QAsymmU8)
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000361 {
Derek Lambertid466a542020-01-22 15:37:29 +0000362 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000363 const std::vector<DataType> validTypes =
364 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000365 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000366 DataType::QSymmS8,
367 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000368 };
Derek Lambertid466a542020-01-22 15:37:29 +0000369 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000370
371 ValidateDataTypes(weightInfo, validTypes, descName);
372 }
373 else
374 {
375 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
376 }
377}
378
379void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
380 const std::string& descName,
381 const std::string& tensorName)
382{
383 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
384 if (!quantizationDim.has_value())
385 {
386 throw InvalidArgumentException(boost::str(
387 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
388 % descName % tensorName));
389 }
390
391 if (quantizationDim.value() != 0)
392 {
393 throw InvalidArgumentException(boost::str(
394 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
395 "but got: %3%") % descName % tensorName % quantizationDim.value()));
396 }
397}
398
399void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
400 const std::string& descName,
401 const std::string& tensorName)
402{
403 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
404 if (quantizationOffset != 0)
405 {
406 throw InvalidArgumentException(boost::str(
407 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
408 "but got: %3%") % descName % tensorName % quantizationOffset));
409 }
410}
411
412void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
413 const TensorInfo& outputInfo,
414 const TensorInfo& weightInfo,
415 const Optional<TensorInfo>& optionalBiasInfo,
416 const std::string& descName)
417{
418 if (weightInfo.HasPerAxisQuantization())
419 {
420 const DataType inputDataType = inputInfo.GetDataType();
421 const DataType outputDataType = outputInfo.GetDataType();
422
Keith Davis5204aa82020-01-27 15:24:59 +0000423 const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 ||
424 inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000425
426 if (!canHavePerAxisQuantization)
427 {
428 throw InvalidArgumentException(boost::str(
429 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
430 "but data type does not support per-axis quantization.") % descName % "weight"));
431 }
432
Derek Lambertid466a542020-01-22 15:37:29 +0000433
434 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000435 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
436 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
437
438 if (optionalBiasInfo.has_value())
439 {
440 const TensorInfo& biasInfo = optionalBiasInfo.value();
441 if (!biasInfo.HasPerAxisQuantization())
442 {
443 throw InvalidArgumentException(boost::str(
444 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
445 "weight tensor.") % descName));
446 }
447
448 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
449 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
450 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
451 }
452 }
453}
454
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100455} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000456
457void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
458 unsigned int numExpectedIn, unsigned int numExpectedOut) const
459{
460 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
461 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
462}
463
464//---------------------------------------------------------------
465void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
466{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100467 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100469 ValidateNumInputs(workloadInfo, descriptorName, 1);
470 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000471
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
474
475 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
476 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000477
478 if (m_Inputs.size() != m_Outputs.size())
479 {
480 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100481 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
482 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000483 }
484
485 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
486 {
487 if (!m_Inputs[i])
488 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100489 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
490 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000491 }
492
493 if (!m_Outputs[i])
494 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100495 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
496 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000497 }
498 }
499}
500
Derek Lambertif674aa02019-08-01 15:56:25 +0100501//---------------------------------------------------------------
502void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
503{
504 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
505 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
506
507 if (workloadInfo.m_InputTensorInfos.size() != 1)
508 {
509 throw InvalidArgumentException(boost::str(
510 boost::format("Number of input infos (%1%) is not 1.")
511 % workloadInfo.m_InputTensorInfos.size()));
512
513 }
514
515 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
516 {
517 throw InvalidArgumentException(boost::str(
518 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
519 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
520 }
521
522 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
523 {
524 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
525 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
526 {
527 throw InvalidArgumentException(boost::str(
528 boost::format("Number of elements for tensor input and output %1% does not match")
529 % i ));
530 }
531 }
532
533 if (m_Inputs.size() != 1)
534 {
535 throw InvalidArgumentException(boost::str(
536 boost::format("Number of inputs (%1%) is not 1.")
537 % m_Inputs.size()));
538 }
539
540 if (m_Inputs.size() != m_Outputs.size())
541 {
542 throw InvalidArgumentException(boost::str(
543 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
544 % m_Inputs.size() % m_Outputs.size()));
545 }
546
547 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
548 {
549 if (!m_Inputs[i])
550 {
551 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
552 }
553
554 if (!m_Outputs[i])
555 {
556 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
557 }
558 }
559}
560
561//---------------------------------------------------------------
562void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
563{
564 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
565 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
566
Derek Lambertif674aa02019-08-01 15:56:25 +0100567 if (m_Inputs.size() != 1)
568 {
569 throw InvalidArgumentException(boost::str(
570 boost::format("Number of inputs (%1%) is not 1.")
571 % m_Inputs.size()));
572 }
573
574 if (m_Outputs.size() != 0)
575 {
576 throw InvalidArgumentException(boost::str(
577 boost::format("Number of outputs (%1%) is not 0.")
578 % m_Inputs.size() % m_Outputs.size()));
579 }
580
581 if (!m_Inputs[0])
582 {
583 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
584 }
585}
586
587//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000588void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
589{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100590 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100591
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100592 ValidateNumInputs(workloadInfo, descriptorName, 1);
593 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100595 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
596 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100597
598 std::vector<DataType> supportedTypes =
599 {
James Conroyd47a0642019-09-17 14:22:06 +0100600 DataType::Float16,
601 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000602 DataType::QAsymmU8,
603 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100604 };
605
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
607 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
608 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000609}
610
Nikhil Rajee391d52019-09-05 17:50:44 +0100611void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
612{
613 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
614
615 ValidateNumInputs(workloadInfo, descriptorName, 1);
616 ValidateNumOutputs(workloadInfo, descriptorName, 1);
617
618 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
619 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
620
Nikhil Raj68c2c902019-09-19 11:21:11 +0100621 if (outputTensorInfo.GetDataType() != DataType::Signed32)
622 {
623 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
624 }
625
James Conroyd47a0642019-09-17 14:22:06 +0100626 std::vector<DataType> supportedInputTypes =
627 {
628 DataType::Float16,
629 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000630 DataType::QAsymmU8,
631 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000632 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100633 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100634
James Conroyd47a0642019-09-17 14:22:06 +0100635 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100636
637 auto inputShape = inputTensorInfo.GetShape();
638 auto outputShape = outputTensorInfo.GetShape();
639
640 auto inputNumDimensions = inputShape.GetNumDimensions();
641 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
642
643 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
644
645 // 1D input shape results in scalar output shape
646 if (inputShape.GetNumDimensions() == 1)
647 {
648 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
649 {
650 throw InvalidArgumentException(descriptorName + outputShapeError);
651 }
652 }
653 else
654 {
655 for (unsigned int i = 0; i < unsignedAxis; ++i)
656 {
657 if (outputShape[i] != inputShape[i])
658 {
659 throw InvalidArgumentException(descriptorName + outputShapeError);
660 }
661 }
662
663 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
664 {
665 if (outputShape[i - 1] != inputShape[i])
666 {
667 throw InvalidArgumentException(descriptorName + outputShapeError);
668 }
669 }
670 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100671}
672
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100673void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
674{
675 const std::string descriptorName{"SoftmaxQueueDescriptor"};
676
677 ValidateNumInputs(workloadInfo, descriptorName, 1);
678 ValidateNumOutputs(workloadInfo, descriptorName, 1);
679
680 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
681 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
682
683 std::vector<DataType> supportedTypes =
684 {
James Conroyd47a0642019-09-17 14:22:06 +0100685 DataType::Float16,
686 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000687 DataType::QAsymmU8,
688 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100689 };
690
691 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
692 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
693 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
694}
695
telsoa014fcda012018-03-09 14:13:49 +0000696void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
697{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100698 const std::string descriptorName{"SplitterQueueDescriptor"};
699
700 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000701
Ruomei Yan25339c32019-05-28 16:48:20 +0100702 // Check the supported data types
703 std::vector<DataType> supportedTypes =
704 {
James Conroyd47a0642019-09-17 14:22:06 +0100705 DataType::Float32,
706 DataType::Float16,
707 DataType::Boolean,
708 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000709 DataType::QAsymmU8,
710 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100711 };
712
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100713 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
714 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100715 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100716 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
717 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
718
719 const std::string outputName = "output_" + std::to_string(i);
720 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100721 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100722
telsoa014fcda012018-03-09 14:13:49 +0000723 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
724 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100725 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000726 }
727
728 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
729 {
730 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100731 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000732 "has to match number of workloadInfo.m_OutputTensorInfos. "
733 "Number of windows: " +
734 to_string(m_ViewOrigins.size()) +
735 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
736 }
737
telsoa01c577f2c2018-08-31 09:22:23 +0100738 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000739 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
740 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
741 {
telsoa01c577f2c2018-08-31 09:22:23 +0100742 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000743 ViewOrigin const& e = m_ViewOrigins[w];
744 if (e.m_Origin.size() != inputDims)
745 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100746 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000747 "have the same dimensionality as the input tensor. "
748 "Window origin (index: " +
749 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
750 " dimensions, the input "
751 "tensor has " +
752 to_string(inputDims) + " dimensions.");
753 }
754 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
755 {
756 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
757 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
758 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100759 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000760 "be smaller or equal than the size of the input in that coord.");
761 }
762 }
763 }
764}
765
Jim Flynne242f2d2019-05-22 14:24:13 +0100766void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000767{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 const std::string descriptorName{"ConcatQueueDescriptor"};
769
770 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000771
772 if (m_Inputs.size() <= 0)
773 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100774 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000775 }
776 if (m_Outputs.size() <= 0)
777 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000779 }
780
781 if (workloadInfo.m_InputTensorInfos.size() <= 0)
782 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100783 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000784 }
785 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
786 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100787 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000788 }
789
Nikhil Raj8599a412018-11-19 14:51:07 +0000790 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
791 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100792 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000793 }
794
795 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
796 {
797 return;
798 }
799
telsoa014fcda012018-03-09 14:13:49 +0000800 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
801 {
802 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100803 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000804 "has to match number of workloadInfo.m_InputTensorInfos. "
805 "Number of windows: " +
806 to_string(m_ViewOrigins.size()) +
807 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
808 }
809
telsoa01c577f2c2018-08-31 09:22:23 +0100810 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000811 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
812 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
813 {
telsoa01c577f2c2018-08-31 09:22:23 +0100814 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000815 ViewOrigin const& e = m_ViewOrigins[w];
816 if (e.m_Origin.size() != outputDims)
817 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100818 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000819 "have the same dimensionality as the output tensor. "
820 "Window origin (index: " +
821 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
822 " dimensions, the output "
823 "tensor has " +
824 to_string(outputDims) + " dimensions.");
825 }
telsoa01c577f2c2018-08-31 09:22:23 +0100826 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000827 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
828 {
829 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
830 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
831 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100832 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000833 "be smaller or equal than the size of the output in that coord.");
834 }
835 }
836 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100837
838 // Check the supported data types
839 std::vector<DataType> supportedTypes =
840 {
James Conroyd47a0642019-09-17 14:22:06 +0100841 DataType::Float32,
842 DataType::Float16,
843 DataType::Boolean,
844 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000845 DataType::QAsymmU8,
846 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100847 };
848
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100849 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
850 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100851 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100852 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
853 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
854
855 const std::string inputName = "input_" + std::to_string(i);
856 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100857 }
telsoa014fcda012018-03-09 14:13:49 +0000858}
859
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100860void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
861{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100862 const std::string descriptorName{"StackQueueDescriptor"};
863
864 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100865
866 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
867 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100868 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100869 }
870
871 // All inputs must have the same shape, which is defined in parameters
872 const TensorShape& inputShape = m_Parameters.m_InputShape;
873 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
874 {
875 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
876 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100878 }
879 }
880
Matthew Jacksondba634f2019-08-15 15:14:18 +0100881 if (inputShape.GetNumDimensions() > 4)
882 {
883 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
884 }
885
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100886 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
887 // since the output tensor has an additional dimension.
888 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
889 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100890 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100891 "than the number of input dimensions.");
892 }
893
894 // Output shape must be as inferred from the input shape
895 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
896 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
897 {
898 if (outputShape[i] != inputShape[i])
899 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100900 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100901 "match shape inferred from input tensor.");
902 }
903 }
904
905 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
906 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100907 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100908 "match shape inferred from input tensor.");
909 }
910
911 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
912 {
913 if (outputShape[i] != inputShape[i-1])
914 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100915 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100916 "match shape inferred from input tensor.");
917 }
918 }
919
Matthew Jacksondba634f2019-08-15 15:14:18 +0100920 if (outputShape.GetNumDimensions() > 5)
921 {
922 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
923 }
924
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100925 // Check the supported data types
926 std::vector<DataType> supportedTypes =
927 {
James Conroyd47a0642019-09-17 14:22:06 +0100928 DataType::Float32,
929 DataType::Float16,
930 DataType::Boolean,
931 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000932 DataType::QAsymmU8,
933 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100934 };
935
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100936 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100937
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100938 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100939 {
940 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
941 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100942 descriptorName,
943 "input_0",
944 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100945 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100946
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100947 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
948 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100949 descriptorName,
950 "input_0",
951 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100952}
953
telsoa014fcda012018-03-09 14:13:49 +0000954void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
955{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100956 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000957
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100958 ValidateNumInputs(workloadInfo, descriptorName, 1);
959 ValidateNumOutputs(workloadInfo, descriptorName, 1);
960
961 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
962 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
963
964 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
965
966 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000967 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000969 }
970
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100971 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000972
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100973 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
974 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000975
976 if (m_Parameters.m_BiasEnabled)
977 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100978 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000979
telsoa01c577f2c2018-08-31 09:22:23 +0100980 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100981 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
982 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000983
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100984 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
985 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000986 }
987
Francis Murtagh46c09d02019-05-28 08:15:28 +0100988 // Check the supported data types
989 std::vector<DataType> supportedTypes =
990 {
James Conroyd47a0642019-09-17 14:22:06 +0100991 DataType::Float32,
992 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000993 DataType::QAsymmU8,
994 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100995 };
996
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100997 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
998 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000999}
1000
telsoa014fcda012018-03-09 14:13:49 +00001001void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1002{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001003 const std::string descriptorName{"NormalizationQueueDescriptor"};
1004
1005 ValidateNumInputs(workloadInfo, descriptorName, 1);
1006 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1007
1008 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1009 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001010
1011 // Check the supported data types
1012 std::vector<DataType> supportedTypes =
1013 {
1014 DataType::Float16,
1015 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001016 DataType::QAsymmU8,
1017 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001018 };
1019
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001020 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001021
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001022 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001023
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001024 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001025}
1026
1027void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1028{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001029 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001030
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001031 ValidateNumInputs(workloadInfo, descriptorName, 2);
1032 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1033
1034 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1035 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1036 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1037
1038 std::vector<DataType> supportedTypes =
1039 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001040 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001041 DataType::QAsymmU8,
1042 DataType::QSymmS16,
Keith Davis5204aa82020-01-27 15:24:59 +00001043 DataType::QSymmS8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001044 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001045 };
1046
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001047 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1048 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1049 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001050
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001051 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1052 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001053
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001054 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1055 inputTensorInfo1,
1056 outputTensorInfo,
1057 descriptorName,
1058 "input_0",
1059 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001060}
1061
telsoa014fcda012018-03-09 14:13:49 +00001062void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1063{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001065
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 ValidateNumInputs(workloadInfo, descriptorName, 2);
1067 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1068
1069 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1070 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1071 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1072
1073 std::vector<DataType> supportedTypes =
1074 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001075 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001076 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001077 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001078 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001079 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001080 };
1081
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001082 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1083 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1084 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001085
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1087 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001088
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001089 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1090 inputTensorInfo1,
1091 outputTensorInfo,
1092 descriptorName,
1093 "input_0",
1094 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001095}
1096
1097void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1098{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001100
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001101 ValidateNumInputs(workloadInfo, descriptorName, 1);
1102 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1103
1104 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1105 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001106
1107 std::vector<DataType> supportedTypes =
1108 {
1109 DataType::Float16,
1110 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001111 DataType::QAsymmU8,
1112 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001113 };
1114
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001115 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1116 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1119 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1120 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001121
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001122 ValidatePointer(m_Mean, descriptorName, "mean");
1123 ValidatePointer(m_Variance, descriptorName, "variance");
1124 ValidatePointer(m_Beta, descriptorName, "beta");
1125 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001126
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001127 const TensorInfo& mean = m_Mean->GetTensorInfo();
1128 const TensorInfo& variance = m_Variance->GetTensorInfo();
1129 const TensorInfo& beta = m_Beta->GetTensorInfo();
1130 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001131
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001132 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1133 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1134 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1135 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001136
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001137 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1138 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1139 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001140}
1141
1142void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1143{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001144 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001145
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001146 ValidateNumInputs(workloadInfo, descriptorName, 1);
1147 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001148
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1150 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001151
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001152 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1153 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001154
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001155 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001156
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1158 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001159
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001160 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001161
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001162 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001163 if (m_Parameters.m_BiasEnabled)
1164 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001165 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001166
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001167 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1168 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169
1170 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1171 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001172 }
1173
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001174 ValidatePerAxisQuantization(inputTensorInfo,
1175 outputTensorInfo,
1176 weightTensorInfo,
1177 optionalBiasTensorInfo,
1178 descriptorName);
1179
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001180 std::vector<DataType> supportedTypes =
1181 {
Ruomei Yan88d44b82019-05-23 14:29:06 +01001182 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001183 DataType::QAsymmU8,
1184 DataType::QSymmS16,
Keith Davis5204aa82020-01-27 15:24:59 +00001185 DataType::QSymmS8,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001186 DataType::Float16
1187 };
1188
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001189 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1190 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1191}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001193void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1194{
1195 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1196
1197 ValidateNumInputs(workloadInfo, descriptorName, 1);
1198 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1199
1200 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1201 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1202
1203 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1204 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1205
1206 ValidatePointer(m_Weight, descriptorName, "weight");
1207
1208 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1209 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1210
1211 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1212 {
1213 throw InvalidArgumentException(
1214 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1215 "cannot be smaller than 1.") % descriptorName %
1216 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1217 }
1218
1219 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1220
1221 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1222 // inputChannels * channelMultiplier should be equal to outputChannels.
1223 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1224 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1225 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1226 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1227 {
1228 throw InvalidArgumentException(
1229 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1230 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1231 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1232 numWeightInputChannels % numWeightChannelMultiplier));
1233 }
1234
Teresa Charlind8df0262019-11-11 12:28:15 +00001235 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001236
Teresa Charlind8df0262019-11-11 12:28:15 +00001237 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001238 if (m_Parameters.m_BiasEnabled)
1239 {
1240 ValidatePointer(m_Bias, descriptorName, "bias");
1241
Teresa Charlind8df0262019-11-11 12:28:15 +00001242 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1243 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001244
1245 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1246 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1247 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001248 ValidatePerAxisQuantization(inputTensorInfo,
1249 outputTensorInfo,
1250 weightTensorInfo,
1251 optionalBiasTensorInfo,
1252 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253
1254 std::vector<DataType> supportedTypes =
1255 {
1256 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001257 DataType::QAsymmU8,
1258 DataType::QSymmS16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001259 DataType::Float16
1260 };
1261
1262 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1263 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001264}
1265
1266void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1267{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 const std::string descriptorName{"PermuteQueueDescriptor"};
1269
1270 ValidateNumInputs(workloadInfo, descriptorName, 1);
1271 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001272
1273 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1274
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001275 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1276 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001277
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001278 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1279 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001280
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001281 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001282 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001283 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001284 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001285 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1286 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1287 "must match dst dimension " + to_string(mapping[i]) +
1288 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001289 }
1290 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001291
1292 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001293}
1294
1295void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1296{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001299 ValidateNumInputs(workloadInfo, descriptorName, 1);
1300 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1301
1302 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1303 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1304
1305 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1306 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001307
1308 std::vector<DataType> supportedTypes =
1309 {
1310 DataType::Float32,
1311 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001312 DataType::QAsymmU8,
1313 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001314 };
1315
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001316 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1317 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001318}
1319
1320void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1321{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001322 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001323
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001324 ValidateNumInputs(workloadInfo, descriptorName, 1);
1325 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1326
1327 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1328 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1329
1330 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1331 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001332
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001333 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001334 {
1335 DataType::Float16,
1336 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001337 DataType::QAsymmU8,
1338 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001339 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001340
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001341 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1342 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001344 // ResizeBilinear only changes width and height: batch and channel count must match.
1345 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1346 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001347 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001348 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001349 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001350 boost::str(boost::format("%1%: Input batch size (%2%) "
1351 "does not match output batch size (%3%)") %
1352 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001353 }
1354
Teresa Charlin970f43b2019-07-01 13:51:07 +01001355 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001356 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1357 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001358 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001359 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001360 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001361 boost::str(boost::format("%1%: Input channel count (%2%) "
1362 "does not match output channel count (%3%)") %
1363 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001364 }
1365}
1366
1367void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1368{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001369 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001370
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001371 ValidateNumInputs(workloadInfo, descriptorName, 1);
1372 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1373
1374 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1375 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1376
1377 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1378 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001379
1380 std::vector<DataType> supportedTypes =
1381 {
1382 DataType::Float16,
1383 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001384 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001385 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001386 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001387 };
1388
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001389 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1390 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001391
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001392 // Resize only changes width and height: batch and channel count must match.
1393 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1394 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001395 if (inputBatchSize != outputBatchSize)
1396 {
1397 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001398 boost::str(boost::format("%1%: Input batch size (%2%) "
1399 "does not match output batch size (%3%)") %
1400 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001401 }
1402
1403 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001404 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1405 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001406 if (inputChannelCount != outputChannelCount)
1407 {
1408 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001409 boost::str(boost::format("%1%: Input channel count (%2%) "
1410 "does not match output channel count (%3%)") %
1411 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001412 }
1413}
1414
1415void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1416{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001417 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001418
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001419 ValidateNumInputs(workloadInfo, descriptorName, 1);
1420 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1421
1422 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1423 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1424
1425 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1426 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1427
1428 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1429
telsoa014fcda012018-03-09 14:13:49 +00001430 if (m_Parameters.m_Min > m_Parameters.m_Max)
1431 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001432 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001433 }
telsoa014fcda012018-03-09 14:13:49 +00001434}
1435
Kevin Mayce5045a2019-10-02 14:07:47 +01001436void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1437{
1438 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1439
1440 ValidateNumInputs(workloadInfo, descriptorName, 1);
1441 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1442
1443 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1444 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1445
1446 if (inputTensorInfo.GetNumDimensions() > 4)
1447 {
1448 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1449 }
1450
1451 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1452
1453 // Check the supported data types
1454 std::vector<DataType> supportedTypes =
1455 {
1456 DataType::Float32,
1457 DataType::Float16
1458 };
1459
1460 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001461 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001462}
1463
telsoa014fcda012018-03-09 14:13:49 +00001464void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1465{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001466 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001467
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001468 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001469 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1470
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001471 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1472 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1473
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001474 if (inputTensorInfo.GetNumDimensions() > 4)
1475 {
1476 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1477 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001478
1479 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001480
1481 // Check the supported data types
1482 std::vector<DataType> supportedTypes =
1483 {
1484 DataType::Float32,
1485 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001486 DataType::QAsymmU8,
1487 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001488 };
1489
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001490 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001491 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1492}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001493
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001494void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1495{
1496 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1497
1498 ValidateNumInputs(workloadInfo, descriptorName, 1);
1499 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1500
1501 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1502 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1503
1504 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1505
1506 std::vector<DataType> supportedTypes =
1507 {
1508 DataType::Float32,
1509 DataType::Float16,
1510 };
1511
1512 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001513 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001514}
1515
1516void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1517{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 const std::string descriptorName{"ConstantQueueDescriptor"};
1519
1520 ValidateNumInputs(workloadInfo, descriptorName, 0);
1521 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001522
1523 if (!m_LayerOutput)
1524 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001526 }
1527
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001528 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1529 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001530
1531 // Check the supported data types
1532 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001533 {
1534 DataType::Float32,
1535 DataType::Float16,
1536 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001537 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001538 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001539 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001540 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001541
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001542 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001543}
1544
1545void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1546{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001547 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001548
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001549 ValidateNumInputs(workloadInfo, descriptorName, 1);
1550 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1551
1552 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1553 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1554
1555 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001556
1557 // Check the supported data types
1558 std::vector<DataType> supportedTypes =
1559 {
1560 DataType::Float32,
1561 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001562 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001563 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001564 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001565 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001566 };
1567
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001568 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1569 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001570}
1571
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001572void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1573{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001574 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001575
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001576 ValidateNumInputs(workloadInfo, descriptorName, 1);
1577 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1578
1579 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1580 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1581
1582 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1583 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001584
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001585 if (m_Parameters.m_BlockShape.size() != 2)
1586 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001587 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001588 }
1589
1590 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1591 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001592 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1593 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001594 }
1595
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001596 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001597
1598 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001599 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001600
Matthew Bentham8800c002018-11-19 13:19:28 +00001601 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001603 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1604 widthPad.first + widthPad.second;
1605 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1606 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001607
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001608 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1609 inputShape[dimensionIndices.GetChannelsIndex()];
1610 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001611
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001612 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001613 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001614 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001615 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001616 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001617 }
1618
1619 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001620 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001621 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1622 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001623 }
nikraj01120522a2019-05-31 11:33:07 +01001624
1625 std::vector<DataType> supportedTypes =
1626 {
1627 DataType::Float16,
1628 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001629 DataType::QAsymmU8,
1630 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001631 };
1632
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001633 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1634 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001635}
1636
Keith Davisa57eccb2019-06-14 17:33:22 +01001637void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1638{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001640
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001641 ValidateNumInputs(workloadInfo, descriptorName, 1);
1642 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001643
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001644 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1645 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1646
1647 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1648 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001649
1650 std::vector<DataType> supportedTypes =
1651 {
1652 DataType::Float32,
1653 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001654 DataType::QAsymmU8,
1655 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001656 };
1657
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001658 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1659 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001660
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001661 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1662
1663 if (m_Parameters.m_BlockSize == 0)
1664 {
1665 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1666 }
1667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001668 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1669 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1670 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1671 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001672
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001673 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001674 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001675 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1677 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001678 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001679
1680 const TensorShape& outputShape = outputTensorInfo.GetShape();
1681 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1682 {
1683 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1684 "must be divisible by the square of block size." );
1685 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001686}
1687
telsoa014fcda012018-03-09 14:13:49 +00001688void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1689{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001690 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001691
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001692 ValidateNumInputs(workloadInfo, descriptorName, 1);
1693 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1694
1695 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1696 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001697
1698 std::vector<DataType> supportedTypes =
1699 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001701 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001702 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001703 };
1704
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001705 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001706
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001707 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001708 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001710 }
1711}
1712
telsoa01c577f2c2018-08-31 09:22:23 +01001713void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1714{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001715 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1716
1717 const std::string descriptorName{"LstmQueueDescriptor"};
1718
1719 // check dimensions of all inputs and outputs
1720 if (workloadInfo.m_InputTensorInfos.size() != 3)
1721 {
1722 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1723 }
1724 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1725 {
1726 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1727 }
1728
1729 std::vector<DataType> supportedTypes =
1730 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001731 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001732 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001733 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001734 };
1735
Jan Eilers38e05bd2019-06-26 13:10:09 +01001736 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001737 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1738
Jan Eilers38e05bd2019-06-26 13:10:09 +01001739 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001740 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001741 {
1742 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1743 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001744 descriptorName,
1745 "input_0",
1746 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001747 }
1748 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001749 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001750 {
1751 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1752 workloadInfo.m_OutputTensorInfos[i],
1753 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001754 "input_0",
1755 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001756 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001757
janeil0117d8d852019-11-15 15:00:16 +00001758 // Making sure clipping parameters have valid values.
1759 // == 0 means no clipping
1760 // > 0 means clipping
1761 if (m_Parameters.m_ClippingThresCell < 0.0f)
1762 {
1763 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1764 }
1765 if (m_Parameters.m_ClippingThresProj < 0.0f)
1766 {
1767 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1768 }
1769
Jan Eilers38e05bd2019-06-26 13:10:09 +01001770
1771 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001772 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1773 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1774 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1775 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1776 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1777 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1778
Jan Eilers38e05bd2019-06-26 13:10:09 +01001779 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001780 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1781 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001782 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001783 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1784 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001785 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001786 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1787 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001788 // scratchBufferTensor
1789 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001790 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1791 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001792 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001793 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1794 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001795 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001796 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1797 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001798 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001799 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1800 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001801
1802
1803 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1804 if ( m_InputToInputWeights )
1805 {
1806 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1807 (n_cell * n_input), "InputLayerNormWeights");
1808 }
1809
1810 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1811 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1812 (n_cell * n_input), "InputToForgetWeights");
1813
1814 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1815 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1816 (n_cell * n_input), "InputToCellWeights");
1817
1818 if ( m_RecurrentToInputWeights )
1819 {
1820 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1821 (n_cell * n_output), "RecurrentToInputWeights");
1822 }
1823
1824 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1825 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1826 (n_cell * n_output), "RecurrentToForgetWeights");
1827
1828 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1829 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1830 (n_cell * n_output), "RecurrentToCellWeights");
1831
1832 // Make sure the input-gate's parameters are either both present (regular
1833 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1834 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1835 !m_Parameters.m_CifgEnabled) ||
1836 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1837 m_Parameters.m_CifgEnabled));
1838 if (!cifg_weights_all_or_none)
1839 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001840 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1841 "RecurrentToInputWeights must either both be present (regular LSTM) "
1842 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1843 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001844 }
1845
1846 if ( m_CellToInputWeights )
1847 {
1848 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1849 n_cell, "CellToInputWeights");
1850 }
1851 if ( m_CellToForgetWeights )
1852 {
1853 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1854 n_cell, "CellToForgetWeights");
1855 }
1856 if ( m_CellToOutputWeights )
1857 {
1858 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1859 n_cell, "CellToOutputWeights");
1860 }
1861
1862 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1863 bool peephole_weights_all_or_none =
1864 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1865 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1866 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1867 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1868 if (!peephole_weights_all_or_none)
1869 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001870 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001871 }
1872
1873 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1874 if (m_Parameters.m_CifgEnabled)
1875 {
1876 if (m_InputGateBias)
1877 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001878 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001879 }
1880 }
1881 else
1882 {
1883 if (!m_InputGateBias)
1884 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001885 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1886 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001887 }
1888 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1889 n_cell, "InputGateBias");
1890 }
1891
1892 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1893 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1894
1895 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1896 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1897
1898 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1899 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1900
1901 if (m_ProjectionWeights)
1902 {
1903 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1904 (n_cell * n_output), "ProjectionWeights");
1905 }
1906 if (m_ProjectionBias)
1907 {
1908 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1909 }
1910
1911 // Making sure the projection tensors are consistent:
1912 // 1) If projection weight is not present, then projection bias should not be
1913 // present.
1914 // 2) If projection weight is present, then projection bias is optional.
1915 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1916 !m_Parameters.m_ProjectionEnabled)
1917 || (m_ProjectionWeights && !m_ProjectionBias &&
1918 m_Parameters.m_ProjectionEnabled)
1919 || (m_ProjectionWeights && m_ProjectionBias &&
1920 m_Parameters.m_ProjectionEnabled));
1921 if (!projecton_tensors_consistent)
1922 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001923 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001924 }
1925
1926 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1927 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1928 // either all have values or none of them have values. Layer normalization is used when the values of all the
1929 // layer normalization weights are present
1930 if (m_InputLayerNormWeights)
1931 {
1932 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1933 }
1934 if (m_ForgetLayerNormWeights)
1935 {
1936 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1937 }
1938 if (m_CellLayerNormWeights)
1939 {
1940 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1941 }
1942 if (m_OutputLayerNormWeights)
1943 {
1944 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1945 }
1946
Jan Eilers38e05bd2019-06-26 13:10:09 +01001947 if (m_Parameters.m_LayerNormEnabled)
1948 {
1949 if (!m_Parameters.m_CifgEnabled)
1950 {
1951 if (!m_InputLayerNormWeights)
1952 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001953 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1954 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001955 }
1956 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1957 1, n_cell, "InputLayerNormWeights");
1958 }
1959 else if (m_InputLayerNormWeights)
1960 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001961 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1962 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001963 }
1964
1965 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1966 "ForgetLayerNormWeights");
1967 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1968
1969 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1970 "OutputLayerNormWeights");
1971 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1972
1973 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1974 "CellLayerNormWeights");
1975 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1976 }
1977 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1978 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001979 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1980 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001981 }
telsoa01c577f2c2018-08-31 09:22:23 +01001982}
1983
1984void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1985{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001986 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001987
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 ValidateNumInputs(workloadInfo, descriptorName, 1);
1989 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1990
1991 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1992 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1993
1994 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001995 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001997 }
1998
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001999 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002000 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002001 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002002 }
2003
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002005}
2006
2007void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2008{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002009 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002011 ValidateNumInputs(workloadInfo, descriptorName, 1);
2012 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2013
2014 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2015 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2016
2017 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002018 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002019 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002020 }
2021
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002022 if (outputTensorInfo.GetDataType() != DataType::Float32)
2023 {
2024 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2025 }
2026
2027 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002028}
2029
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002030void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2031{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002032 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002033
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002034 ValidateNumInputs(workloadInfo, descriptorName, 2);
2035 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2036
2037 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2038 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2039 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2040
2041 std::vector<DataType> supportedTypes =
2042 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002043 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002044 DataType::QAsymmU8,
2045 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002046 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002047 };
2048
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002049 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2050 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2051 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002052
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002053 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2054 inputTensorInfo1,
2055 outputTensorInfo,
2056 descriptorName,
2057 "input_0",
2058 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002059}
2060
David Beckc2044fe2018-09-05 15:00:38 +01002061void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2062{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002063 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002064
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002065 ValidateNumInputs(workloadInfo, descriptorName, 2);
2066 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2067
2068 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2069 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2070 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2071
2072 std::vector<DataType> supportedTypes =
2073 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002074 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002075 DataType::QAsymmU8,
2076 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002077 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002078 };
2079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002080 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2081 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2082 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002083
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002084 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2085 inputTensorInfo1,
2086 outputTensorInfo,
2087 descriptorName,
2088 "input_0",
2089 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002090}
2091
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002092void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2093{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002094 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002095
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096 ValidateNumInputs(workloadInfo, descriptorName, 2);
2097 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2098
2099 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2100 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2101 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2102
2103 std::vector<DataType> supportedTypes =
2104 {
Mike Kelly1da02362019-08-01 08:43:57 +01002105 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002106 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002107 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002108 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00002109 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002110 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002111 };
2112
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002113 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2114 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2115 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002116
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002117 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2118 inputTensorInfo1,
2119 outputTensorInfo,
2120 descriptorName,
2121 "input_0",
2122 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002123}
2124
narpra01a6bf9122018-09-10 09:50:09 +01002125void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2126{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002127 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002128
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002129 ValidateNumInputs(workloadInfo, descriptorName, 1);
2130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2131
2132 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2133 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002134
2135 std::vector<DataType> supportedTypes =
2136 {
2137 DataType::Float32,
2138 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002139 DataType::QAsymmU8,
2140 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002141 };
narpra01eb061912018-09-10 17:35:27 +01002142
James Conroy4d1ff582019-06-10 17:06:39 +01002143 // First check if input tensor data type is supported, then
2144 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002145 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2146 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002147
narpra0132b90462018-09-13 11:07:48 +01002148 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002149 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002150 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002151 }
narpra0132b90462018-09-13 11:07:48 +01002152 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002153 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002154 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002155 }
2156 else
2157 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002158 unsigned int outputDim =
2159 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2160 ValidateTensorNumDimensions(outputTensorInfo,
2161 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002162 outputDim > 0 ? outputDim : 1,
2163 "output");
2164 }
narpra01a6bf9122018-09-10 09:50:09 +01002165}
2166
jimfly012c9322a2018-09-19 10:59:49 +01002167void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2168{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002169 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002170
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002171 ValidateNumInputs(workloadInfo, descriptorName, 1);
2172 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2173
2174 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2175 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002176
jimfly012c9322a2018-09-19 10:59:49 +01002177 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002178 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2179
jimfly012c9322a2018-09-19 10:59:49 +01002180 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002181 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2182 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2183 "as there are dimensions in the input tensor that is " +
2184 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2185 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002186 }
2187}
2188
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002189void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2190{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002191 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002193 ValidateNumInputs(workloadInfo, descriptorName, 1);
2194 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002195
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002196 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2197 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2198
Sadik Armagan2208b602019-07-31 16:36:27 +01002199 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002200 {
James Conroyd47a0642019-09-17 14:22:06 +01002201 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002202 DataType::Float16,
2203 DataType::QSymmS8,
2204 DataType::QAsymmU8,
2205 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002206 };
2207
2208 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002209
Derek Lambertif90c56d2020-01-10 17:14:08 +00002210 if (outputTensorInfo.GetDataType() != DataType::QAsymmU8 &&
Finn Williamsfd271062019-12-04 14:27:27 +00002211 outputTensorInfo.GetDataType() != DataType::QSymmS8 &&
Derek Lambertif90c56d2020-01-10 17:14:08 +00002212 outputTensorInfo.GetDataType() != DataType::QSymmS16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002213 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002214 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002215 }
2216}
2217
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002218void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2219{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002220 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002222 ValidateNumInputs(workloadInfo, descriptorName, 1);
2223 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002225 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2226 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002227
2228 std::vector<DataType> supportedTypes =
2229 {
James Conroyd47a0642019-09-17 14:22:06 +01002230 DataType::Float32,
2231 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002232 DataType::QAsymmU8,
2233 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002234 };
2235
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2237 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002238}
2239
Conor Kennedy430b5d82018-11-14 15:28:28 +00002240void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2241{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002242 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002243
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002244 ValidateNumInputs(workloadInfo, descriptorName, 1);
2245 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2246
2247 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2248 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002249
2250 std::vector<DataType> supportedTypes =
2251 {
2252 DataType::Float16,
2253 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002254 DataType::QAsymmU8,
2255 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002256 };
2257
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002258 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2259 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002260
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002261 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002262
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002263 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002264 if (rank > 4)
2265 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002266 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002267 }
2268
Conor Kennedy430b5d82018-11-14 15:28:28 +00002269 // Begin, End & Stride length must be of rank(input0)
2270 if (m_Parameters.m_Begin.size() != rank)
2271 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002272 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002273 }
2274
2275 if (m_Parameters.m_End.size() != rank)
2276 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002277 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002278 }
2279
2280 if (m_Parameters.m_Stride.size() != rank)
2281 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002282 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002283 }
2284
2285 // Stride entries must be non-zero
2286 for (auto& stride : m_Parameters.m_Stride)
2287 {
2288 if (stride == 0)
2289 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002291 }
2292 }
2293}
2294
kevmay0190539692018-11-29 08:40:19 +00002295void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2296{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002297 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002299 ValidateNumInputs(workloadInfo, descriptorName, 2);
2300 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2301
2302 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2303 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2304 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2305
2306 std::vector<DataType> supportedTypes =
2307 {
Mike Kelly1da02362019-08-01 08:43:57 +01002308 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002309 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002310 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002311 DataType::QAsymmU8,
2312 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002313 };
2314
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2316 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2317 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002318
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002319 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2320 inputTensorInfo1,
2321 outputTensorInfo,
2322 descriptorName,
2323 "input_0",
2324 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002325}
2326
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002327void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2328{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002329 const std::string descriptorName{"DebugQueueDescriptor"};
2330
2331 ValidateNumInputs(workloadInfo, descriptorName, 1);
2332 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002333}
2334
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002335void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2336{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002337 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002338
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002339 ValidateNumInputs(workloadInfo, descriptorName, 2);
2340 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002341
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002342 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2343 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2344 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2345
2346 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2347 inputTensorInfo1,
2348 outputTensorInfo,
2349 descriptorName,
2350 "input_0",
2351 "input_1");
2352
2353 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002354 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002355 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002356 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002357}
2358
FrancisMurtagh878f0232018-12-19 10:56:15 +00002359void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2360{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002361 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002362
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002363 ValidateNumInputs(workloadInfo, descriptorName, 2);
2364 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002365
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002366 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2367 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2368 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2369
2370 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2371 inputTensorInfo1,
2372 outputTensorInfo,
2373 descriptorName,
2374 "input_0",
2375 "input_1");
2376
2377 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002378 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002379 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002380 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002381}
2382
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002383void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2384{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002385 const std::string descriptorName{"RsqrtQueueDescriptor"};
2386
2387 ValidateNumInputs(workloadInfo, descriptorName, 1);
2388 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2389
2390 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2391 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2392
2393 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002394
2395 std::vector<DataType> supportedTypes =
2396 {
James Conroyd47a0642019-09-17 14:22:06 +01002397 DataType::Float16,
2398 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002399 DataType::QAsymmU8,
2400 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002401 };
2402
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002403 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2404 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002405}
2406
narpra01b89b05f2019-01-16 09:53:09 +00002407void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2408{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002409 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002410
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002411 ValidateNumInputs(workloadInfo, descriptorName, 2);
2412 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002413
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002414 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2415 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002416 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002418 }
2419
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2421 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2422
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002423 std::vector<DataType> supportedTypes =
2424 {
James Conroyd47a0642019-09-17 14:22:06 +01002425 DataType::Float16,
2426 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002427 DataType::QAsymmU8,
2428 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002429 };
2430
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002432
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002433 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002434
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002435 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2436 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002437}
2438
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002439void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2440{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002441 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2442
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002443 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002444
2445 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2446 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002447 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002448 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2449 }
2450
2451 if (m_Anchors == nullptr)
2452 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002453 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002454 }
2455
2456 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002457 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2458 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2459
2460 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002461 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002462 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2463 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002464
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002465 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2466 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2467 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002468
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002469 const std::vector<DataType> supportedInputTypes =
2470 {
2471 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002472 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002473 DataType::QAsymmU8,
2474 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002475 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002476
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002477 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2478 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2479 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2480
2481 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2482 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2483 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2484 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2485
2486 // NOTE: Output is always Float32 regardless of input type
2487 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2488 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2489 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2490 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002491
2492 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2493 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002494 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002495 "must be positive and less than or equal to 1.");
2496 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002497
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002498 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2499 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002500 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002501 "should be equal to number of classes + 1.");
2502 }
2503}
2504
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002505void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2506{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002507 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002508
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002509 ValidateNumInputs(workloadInfo, descriptorName, 1);
2510 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2511
2512 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2513 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2514
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002515 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002516 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002517 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002518 }
2519
Sadik Armagan2208b602019-07-31 16:36:27 +01002520 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002521 {
James Conroyd47a0642019-09-17 14:22:06 +01002522 DataType::Float32,
2523 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002524 };
2525
2526 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002527}
2528
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002529void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2530{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002531 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002532
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002533 ValidateNumInputs(workloadInfo, descriptorName, 2);
2534 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002535
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002536 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2537 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2538 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002539
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2541 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2542
2543 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2544 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002545}
2546
Sadik Armaganeff363d2019-04-05 15:25:46 +01002547void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2548{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002549 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002550
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002551 ValidateNumInputs(workloadInfo, descriptorName, 2);
2552 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2553
2554 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2555 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2556
2557 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2558 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2559
2560 std::vector<DataType> supportedTypes =
2561 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002562 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002563 DataType::QAsymmU8,
2564 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002565 };
2566
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002567 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2568 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002569
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002570 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2571 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002572
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002573 ValidateTensorShapesMatch(inputTensorInfo0,
2574 outputTensorInfo0,
2575 descriptorName,
2576 "input_0",
2577 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002578
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002579 ValidateTensorShapesMatch(inputTensorInfo0,
2580 outputTensorInfo1,
2581 descriptorName,
2582 "input_0",
2583 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002584}
2585
Derek Lamberti901ea112019-12-10 22:07:09 +00002586void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002587{
2588 // This is internally generated so it should not need validation.
2589}
2590
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002591void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2592{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002593 const std::string& descriptorName{"PreluQueueDescriptor"};
2594
2595 ValidateNumInputs(workloadInfo, descriptorName, 2);
2596 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2597
2598 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2599 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2600 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002601
2602 std::vector<DataType> supportedTypes
2603 {
2604 DataType::Float16,
2605 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002606 DataType::QAsymmU8,
2607 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002608 };
2609
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002610 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2611 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002612
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002613 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002614
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002615 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2616 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002617
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002618 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2619 alphaTensorInfo,
2620 outputTensorInfo,
2621 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002622 "input",
2623 "alpha");
2624}
2625
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002626void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2627{
2628 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2629
2630 ValidateNumInputs(workloadInfo, descriptorName, 1);
2631 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2632
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002633 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2634 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2635
2636 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2637 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002638
2639 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002640
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002641 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2642 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002643
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002644 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2645
2646 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002647 if (m_Parameters.m_BiasEnabled)
2648 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002649 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002650
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002651 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2652 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002653
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002654 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002655 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002656 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002657
2658 ValidatePerAxisQuantization(inputTensorInfo,
2659 outputTensorInfo,
2660 weightTensorInfo,
2661 optionalBiasTensorInfo,
2662 descriptorName);
2663
2664 std::vector<DataType> supportedTypes =
2665 {
2666 DataType::Float32,
2667 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002668 DataType::QAsymmU8,
2669 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002670 };
2671
2672 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2673 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002674}
2675
James Conroy9c3cae82019-08-01 16:01:48 +01002676void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2677{
2678 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2679
2680 // Validate number of inputs/outputs
2681 ValidateNumInputs(workloadInfo, descriptorName, 3);
2682 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2683
2684 // Input/output tensor infos
2685 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2686 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2687 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2688
2689 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2690 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2691
2692 std::vector<DataType> inputOutputSupportedTypes =
2693 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002694 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002695 };
2696
2697 std::vector<DataType> cellStateSupportedTypes =
2698 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002699 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01002700 };
2701
2702 std::vector<DataType> weightsSupportedTypes =
2703 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002704 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002705 };
2706
2707 std::vector<DataType> biasSupportedTypes =
2708 {
2709 DataType::Signed32
2710 };
2711
2712 // Validate types of input/output tensors
2713 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2714 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2715 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2716
2717 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2718 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2719
2720 // Validate matching types of input/output tensors
2721 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2722 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2723 "outputStateIn", "outputStateOut");
2724 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2725
2726 // Validate matching quantization info for input/output tensors
2727 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2728 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2729 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002730
James Conroy9c3cae82019-08-01 16:01:48 +01002731 // Infer number of batches, input size and output size from tensor dimensions
2732 const uint32_t numBatches = inputInfo.GetShape()[0];
2733 const uint32_t inputSize = inputInfo.GetShape()[1];
2734 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2735
2736 // Validate number of dimensions and number of elements for input/output tensors
2737 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2738 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2739 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2740 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2741 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2742
2743 // Validate number of dimensions and number of elements for weights tensors
2744 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2745 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2746 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2747
2748 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2749 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2750 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2751
2752 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2753 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2754 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2755
2756 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2757 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2758 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2759
2760 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2761 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2762 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2763
2764 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2765 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2766 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2767 " RecurrentToForgetWeights");
2768
2769 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2770 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2771 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2772
2773 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2774 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2775 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2776
2777 // Validate data types for weights tensors (all should match each other)
2778 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2779
2780 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2781 "inputToInputWeights", "inputToForgetWeights");
2782 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2783 "inputToInputWeights", "inputToCellWeights");
2784 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2785 "inputToInputWeights", "inputToOutputWeights");
2786
2787 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2788 "inputToInputWeights", "recurrentToInputWeights");
2789 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2790 "inputToInputWeights", "recurrentToForgeteights");
2791 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2792 "inputToInputWeights", "recurrentToCellWeights");
2793 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2794 "inputToInputWeights", "recurrentToOutputWeights");
2795
2796 // Validate matching quantization info for weight tensors (all should match each other)
2797 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2798 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2799 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2800 descriptorName, "inputToInputWeights", "inputToCellWeights");
2801 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2802 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2803
2804 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2805 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2806 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2807 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2808 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2809 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2810 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2811 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2812
2813 // Validate number of dimensions and number of elements in bias tensors
2814 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2815 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2816 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2817
2818 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2819 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2820 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2821
2822 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2823 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2824 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2825
2826 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2827 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2828 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2829
2830 // Validate data types for bias tensors (all should match each other)
2831 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2832
2833 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2834 "inputGateBias", "forgetGateBias");
2835 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2836 "inputGateBias", "cellBias");
2837 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2838 "inputGateBias", "outputGateBias");
2839
2840 // Validate bias tensor quantization info
2841 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2842 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2843 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2844 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2845}
2846
Kevin May868eb142019-09-04 17:29:31 +01002847void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2848{
2849 const std::string descriptorName{"AbsQueueDescriptor"};
2850
2851 ValidateNumInputs(workloadInfo, descriptorName, 1);
2852 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2853
2854 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2855 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2856
2857 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2858
2859 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002860 {
2861 DataType::Float16,
2862 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002863 DataType::QAsymmU8,
2864 DataType::QSymmS16
James Conroyd47a0642019-09-17 14:22:06 +01002865 };
Kevin May868eb142019-09-04 17:29:31 +01002866
2867 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2868 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2869}
2870
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002871void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2872{
2873 const std::string descriptorName{"SliceQueueDescriptor"};
2874
2875 ValidateNumInputs(workloadInfo, descriptorName, 1);
2876 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2877
2878 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2879 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2880
2881 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2882
2883 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2884 if (rank > 4)
2885 {
2886 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2887 }
2888
2889 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2890
2891 // Check if m_Begin and m_Size have the expected length
2892 if (m_Parameters.m_Begin.size() != rank)
2893 {
2894 throw InvalidArgumentException(descriptorName +
2895 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2896 }
2897 if (m_Parameters.m_Size.size() != rank)
2898 {
2899 throw InvalidArgumentException(descriptorName +
2900 ": Length of size descriptor must equal rank " + std::to_string(rank));
2901 }
2902
2903 // Check if the shape of the output tensor matches m_Size
2904 const TensorShape& outputShape = outputTensorInfo.GetShape();
2905 for (unsigned int i = 0u; i < rank; ++i)
2906 {
2907 if (m_Parameters.m_Size[i] != outputShape[i])
2908 {
2909 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2910 }
2911 }
2912
2913 // Check if the sum of begin offset and size in a given dimension
2914 // does not exceed the size of corresponding input
2915 const TensorShape& inputShape = inputTensorInfo.GetShape();
2916 for(unsigned int i = 0u; i < rank; ++i)
2917 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002918 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002919 {
2920 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2921 std::to_string(i) + " exceeds input size.");
2922 }
2923 }
2924}
2925
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002926void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2927{
2928 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2929
2930 ValidateNumInputs(workloadInfo, descriptorName, 1);
2931 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2932
2933 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2934 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2935
2936 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2937 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2938
2939 std::vector<DataType> supportedTypes =
2940 {
2941 DataType::Float32,
2942 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002943 DataType::QAsymmU8,
2944 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002945 };
2946
2947 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2948 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2949
2950 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2951
2952 if (m_Parameters.m_BlockSize == 0)
2953 {
2954 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2955 }
2956
2957 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2958 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2959 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2960 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2961
2962 const TensorShape& outputShape = outputInfo.GetShape();
2963 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
2964 {
2965 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
2966 "must be divisible by block size.");
2967 }
2968
2969 const TensorShape& inputShape = inputInfo.GetShape();
2970 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
2971 {
2972 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
2973 "must be divisible by the square of block size." );
2974 }
2975}
2976
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01002977void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2978{
2979 const std::string descriptorName{"ComparisonQueueDescriptor"};
2980
2981 ValidateNumInputs(workloadInfo, descriptorName, 2);
2982 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2983
2984 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2985 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2986 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2987
2988 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2989 inputTensorInfo1,
2990 outputTensorInfo,
2991 descriptorName,
2992 "input_0",
2993 "input_1");
2994
2995 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2996 {
2997 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2998 }
2999}
3000
josh minor4a3c6102020-01-06 16:40:46 -06003001void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3002{
3003 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3004
3005 ValidateNumInputs(workloadInfo, descriptorName, 1);
3006 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3007
3008 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3009 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3010
3011 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3012
3013 std::vector<DataType> supportedTypes =
3014 {
3015 DataType::Float16,
3016 DataType::Float32,
3017 DataType::QAsymmU8,
3018 DataType::QSymmS16
3019 };
3020
3021 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3022 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3023}
3024
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003025} // namespace armnn