blob: bb0c21ffba525864d38c302c9fa7e66f7da42909 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000033 case DataType::QAsymmS8:
34 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000035 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000036 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000037 case DataType::QSymmS8:
38 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000039 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010040 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000041 default:
42 BOOST_ASSERT_MSG(false, "Invalid input data type");
43 return DataType::Float32;
44 }
45}
46
47namespace
48{
49
50//---------------------------------------------------------------
// Stream-based replacement for std::to_string, which the Android NDK does not provide.
// Formats any value that supports operator<< into a std::string.
template <typename T>
std::string to_string(T value)
{
    std::stringstream buffer;
    buffer << value;
    return buffer.str();
}
59
60//---------------------------------------------------------------
61void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
62{
63 if (!ptr)
64 {
65 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
66 paramName + " parameter must be set.");
67 }
68}
69
70//---------------------------------------------------------------
71void ValidateTensorShapesMatch(const TensorInfo& first,
72 const TensorInfo& second,
73 std::string const& descName,
74 std::string const& firstName,
75 std::string const& secondName)
76{
77 if (first.GetShape() != second.GetShape())
78 {
79 throw InvalidArgumentException(descName + ": "
80 + firstName + " & " + secondName + " must have identical shapes");
81 }
82}
83
84//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010085void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000086{
Sadik Armaganeff363d2019-04-05 15:25:46 +010087 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088 {
89 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010090 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000091 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
92 }
93}
94
95//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010096void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000097{
Sadik Armaganeff363d2019-04-05 15:25:46 +010098 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099 {
100 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100101 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000102 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
103 }
104}
105
106//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100107void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000108 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000110 std::string const& tensorName)
111{
112 if (tensor.GetNumDimensions() != numDimensions)
113 {
114 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
115 to_string(tensor.GetNumDimensions()) + " dimensions for " +
116 tensorName + " tensor.");
117 }
118}
119
120//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100121void ValidateTensorNumElements(const TensorInfo& tensor,
122 std::string const& descName,
123 unsigned int numElements,
124 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100125{
126 if (tensor.GetNumElements() != numElements)
127 {
128 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100129 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100130 tensorName + " tensor.");
131 }
132}
133
134//---------------------------------------------------------------
135void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100136 unsigned int numDimension,
137 unsigned int numElements,
138 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100139{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100140 const std::string functionName{"ValidateTensorNumDimNumElem"};
141 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
142 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100143}
144
145//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000146void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
147 const std::string& descName, std::string const& tensorName)
148{
149 if (tensor.GetDataType() != dataType)
150 {
151 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
152 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
153 }
154}
155
Derek Lambertid466a542020-01-22 15:37:29 +0000156void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
157{
158 ARMNN_NO_DEPRECATE_WARN_BEGIN
159 if (tensor.GetDataType() != DataType::QSymmS8 &&
160 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
161 {
162 throw InvalidArgumentException(descName +
163 ": Expected data type which supports per-axis quantization scheme but got " +
164 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
165 }
166 ARMNN_NO_DEPRECATE_WARN_END
167}
168
telsoa014fcda012018-03-09 14:13:49 +0000169//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100170void ValidateTensorQuantizationSpace(const TensorInfo& first,
171 const TensorInfo& second,
172 const std::string& descName,
173 std::string const& firstName,
174 std::string const& secondName)
175{
176 if (!first.IsQuantized() ||
177 !second.IsQuantized())
178 {
179 // Not a quantized type, ignore the validation
180 return;
181 }
182
183 DataType firstDataType = first.GetDataType();
184 DataType secondDataType = second.GetDataType();
185
186 if (firstDataType != secondDataType)
187 {
188 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
189 " must be of the same quantized type, " +
190 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
191 secondName + " is " + GetDataTypeName(secondDataType));
192 }
193
194 if (!first.IsTypeSpaceMatch(second))
195 {
196 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
197 " must have the same quantization space, " +
198 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
199 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
200 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
201 " and scale " + to_string(second.GetQuantizationScale()));
202 }
203}
204
205//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100206void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
207 const TensorInfo& inputTensorInfo,
208 const TensorInfo& weightsTensorInfo,
209 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000210{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000211 // Helper lambda function to validate a single bias quantization scale value
212 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
213 {
ricbur013f4d7102019-10-31 16:22:18 +0000214 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000215 if (std::abs(biasScale - expectedScale) > tolerance)
216 {
217 // Print the float values with extra precision to see very small differences
218 std::stringstream msg;
219 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
220 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
221 biasScale;
222 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
223 }
224 };
225
telsoa014fcda012018-03-09 14:13:49 +0000226 if (biasTensor.GetQuantizationOffset() != 0)
227 {
228 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
229 to_string(biasTensor.GetQuantizationOffset()));
230 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000231
232 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000233 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000234 // Validate per-axis quantization scales
235 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
236 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
237
238 if (weightScales.size() != biasScales.size())
239 {
240 std::stringstream msg;
241 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
242 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
243 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
244 }
245
246 for (size_t i = 0ul; i < biasScales.size(); ++i)
247 {
248 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
249 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
250 }
251 }
252 else
253 {
254 // Validate per-tensor quantization scale
255 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
256 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000257 }
258}
259
260//---------------------------------------------------------------
261void ValidateTensors(const std::vector<ITensorHandle*>& vec,
262 unsigned int numExpected,
263 const std::string& descName,
264 const std::string& varName)
265{
266 if (vec.empty() && numExpected > 0)
267 {
268 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
269 }
270
271 for (unsigned int i = 0; i < numExpected; ++i)
272 {
273 if (!vec[i])
274 {
275 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
276 }
277 }
278}
279
280//---------------------------------------------------------------
281void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
282 const TensorInfo& second,
283 const TensorInfo& output,
284 std::string const& descName,
285 std::string const& firstName,
286 std::string const& secondName)
287{
288 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
289 // broadcasted.
290 if (first.GetNumDimensions() != second.GetNumDimensions())
291 {
292 throw InvalidArgumentException(descName + ": Tensors "
293 + firstName + " & " + secondName
294 + " must have the same number of dimensions in order to be broadcasted");
295 }
296 uint32_t numDims = first.GetNumDimensions();
297 std::vector<uint32_t> outputDims(numDims, 0u);
298 for (uint32_t i = 0; i < numDims; i++)
299 {
300 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
301 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
302 if (dimsNotEqual && dimsNotOne)
303 {
304 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
305 }
306 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
307 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100308 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000309 if (broadcastShape != output.GetShape())
310 {
311 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
312 + firstName + " & " + secondName
313 + " does not match the output shape");
314 }
315}
316
317//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100318void ValidateDataTypes(const TensorInfo& info,
319 const std::vector<armnn::DataType>& supportedTypes,
320 std::string const& descName)
321{
322 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
323 if (iterator == supportedTypes.end())
324 {
325 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
326 }
327}
328
James Conroy4d1ff582019-06-10 17:06:39 +0100329//---------------------------------------------------------------
330void ValidateTensorDataTypesMatch(const TensorInfo& first,
331 const TensorInfo& second,
332 std::string const& descName,
333 std::string const& firstName,
334 std::string const& secondName)
335{
336 if (first.GetDataType() != second.GetDataType())
337 {
338 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
339 " must have identical data types.");
340 }
341}
342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100343//---------------------------------------------------------------
344void ValidateTensorNumElementsMatch(const TensorInfo& first,
345 const TensorInfo& second,
346 std::string const& descName,
347 std::string const& firstName,
348 std::string const& secondName)
349{
350 if (first.GetNumElements() != second.GetNumElements())
351 {
352 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
353 " must have the same number of elements.");
354 }
355}
356
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000357void ValidateWeightDataType(const TensorInfo& inputInfo,
358 const TensorInfo& weightInfo,
359 const std::string& descName)
360{
361 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000362 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000363 {
Derek Lambertid466a542020-01-22 15:37:29 +0000364 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365 const std::vector<DataType> validTypes =
366 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000367 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000368 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000369 DataType::QSymmS8,
370 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000371 };
Derek Lambertid466a542020-01-22 15:37:29 +0000372 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000373
374 ValidateDataTypes(weightInfo, validTypes, descName);
375 }
376 else
377 {
378 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
379 }
380}
381
382void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
383 const std::string& descName,
384 const std::string& tensorName)
385{
386 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
387 if (!quantizationDim.has_value())
388 {
389 throw InvalidArgumentException(boost::str(
390 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
391 % descName % tensorName));
392 }
393
394 if (quantizationDim.value() != 0)
395 {
396 throw InvalidArgumentException(boost::str(
397 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
398 "but got: %3%") % descName % tensorName % quantizationDim.value()));
399 }
400}
401
402void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
403 const std::string& descName,
404 const std::string& tensorName)
405{
406 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
407 if (quantizationOffset != 0)
408 {
409 throw InvalidArgumentException(boost::str(
410 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
411 "but got: %3%") % descName % tensorName % quantizationOffset));
412 }
413}
414
415void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
416 const TensorInfo& outputInfo,
417 const TensorInfo& weightInfo,
418 const Optional<TensorInfo>& optionalBiasInfo,
419 const std::string& descName)
420{
421 if (weightInfo.HasPerAxisQuantization())
422 {
423 const DataType inputDataType = inputInfo.GetDataType();
424 const DataType outputDataType = outputInfo.GetDataType();
425
Keith Davis0c2eeac2020-02-11 16:51:50 +0000426 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000427
428 if (!canHavePerAxisQuantization)
429 {
430 throw InvalidArgumentException(boost::str(
431 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
432 "but data type does not support per-axis quantization.") % descName % "weight"));
433 }
434
Derek Lambertid466a542020-01-22 15:37:29 +0000435
436 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000437 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
438 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
439
440 if (optionalBiasInfo.has_value())
441 {
442 const TensorInfo& biasInfo = optionalBiasInfo.value();
443 if (!biasInfo.HasPerAxisQuantization())
444 {
445 throw InvalidArgumentException(boost::str(
446 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
447 "weight tensor.") % descName));
448 }
449
450 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
451 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
452 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
453 }
454 }
455}
456
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100457} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000458
459void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
460 unsigned int numExpectedIn, unsigned int numExpectedOut) const
461{
462 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
463 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
464}
465
466//---------------------------------------------------------------
467void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
468{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100469 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000470
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100471 ValidateNumInputs(workloadInfo, descriptorName, 1);
472 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000473
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100474 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
475 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
476
477 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
478 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000479
480 if (m_Inputs.size() != m_Outputs.size())
481 {
482 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100483 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
484 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000485 }
486
487 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
488 {
489 if (!m_Inputs[i])
490 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100491 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
492 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000493 }
494
495 if (!m_Outputs[i])
496 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100497 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
498 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000499 }
500 }
501}
502
Derek Lambertif674aa02019-08-01 15:56:25 +0100503//---------------------------------------------------------------
504void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
505{
506 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
507 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
508
509 if (workloadInfo.m_InputTensorInfos.size() != 1)
510 {
511 throw InvalidArgumentException(boost::str(
512 boost::format("Number of input infos (%1%) is not 1.")
513 % workloadInfo.m_InputTensorInfos.size()));
514
515 }
516
517 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
518 {
519 throw InvalidArgumentException(boost::str(
520 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
521 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
522 }
523
524 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
525 {
526 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
527 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
528 {
529 throw InvalidArgumentException(boost::str(
530 boost::format("Number of elements for tensor input and output %1% does not match")
531 % i ));
532 }
533 }
534
535 if (m_Inputs.size() != 1)
536 {
537 throw InvalidArgumentException(boost::str(
538 boost::format("Number of inputs (%1%) is not 1.")
539 % m_Inputs.size()));
540 }
541
542 if (m_Inputs.size() != m_Outputs.size())
543 {
544 throw InvalidArgumentException(boost::str(
545 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
546 % m_Inputs.size() % m_Outputs.size()));
547 }
548
549 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
550 {
551 if (!m_Inputs[i])
552 {
553 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
554 }
555
556 if (!m_Outputs[i])
557 {
558 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
559 }
560 }
561}
562
563//---------------------------------------------------------------
564void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
565{
566 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
567 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
568
Derek Lambertif674aa02019-08-01 15:56:25 +0100569 if (m_Inputs.size() != 1)
570 {
571 throw InvalidArgumentException(boost::str(
572 boost::format("Number of inputs (%1%) is not 1.")
573 % m_Inputs.size()));
574 }
575
576 if (m_Outputs.size() != 0)
577 {
578 throw InvalidArgumentException(boost::str(
579 boost::format("Number of outputs (%1%) is not 0.")
580 % m_Inputs.size() % m_Outputs.size()));
581 }
582
583 if (!m_Inputs[0])
584 {
585 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
586 }
587}
588
589//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000590void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
591{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100592 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100593
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100594 ValidateNumInputs(workloadInfo, descriptorName, 1);
595 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100596
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100597 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
598 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100599
600 std::vector<DataType> supportedTypes =
601 {
James Conroyd47a0642019-09-17 14:22:06 +0100602 DataType::Float16,
603 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000604 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000605 DataType::QAsymmU8,
606 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100607 };
608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100609 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
610 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
611 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000612}
613
Nikhil Rajee391d52019-09-05 17:50:44 +0100614void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
615{
616 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
617
618 ValidateNumInputs(workloadInfo, descriptorName, 1);
619 ValidateNumOutputs(workloadInfo, descriptorName, 1);
620
621 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
622 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
623
Nikhil Raj68c2c902019-09-19 11:21:11 +0100624 if (outputTensorInfo.GetDataType() != DataType::Signed32)
625 {
626 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
627 }
628
James Conroyd47a0642019-09-17 14:22:06 +0100629 std::vector<DataType> supportedInputTypes =
630 {
631 DataType::Float16,
632 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000633 DataType::QAsymmU8,
634 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000635 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100636 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100637
James Conroyd47a0642019-09-17 14:22:06 +0100638 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100639
640 auto inputShape = inputTensorInfo.GetShape();
641 auto outputShape = outputTensorInfo.GetShape();
642
643 auto inputNumDimensions = inputShape.GetNumDimensions();
644 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
645
646 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
647
648 // 1D input shape results in scalar output shape
649 if (inputShape.GetNumDimensions() == 1)
650 {
651 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
652 {
653 throw InvalidArgumentException(descriptorName + outputShapeError);
654 }
655 }
656 else
657 {
658 for (unsigned int i = 0; i < unsignedAxis; ++i)
659 {
660 if (outputShape[i] != inputShape[i])
661 {
662 throw InvalidArgumentException(descriptorName + outputShapeError);
663 }
664 }
665
666 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
667 {
668 if (outputShape[i - 1] != inputShape[i])
669 {
670 throw InvalidArgumentException(descriptorName + outputShapeError);
671 }
672 }
673 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100674}
675
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100676void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
677{
678 const std::string descriptorName{"SoftmaxQueueDescriptor"};
679
680 ValidateNumInputs(workloadInfo, descriptorName, 1);
681 ValidateNumOutputs(workloadInfo, descriptorName, 1);
682
683 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
684 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
685
686 std::vector<DataType> supportedTypes =
687 {
James Conroyd47a0642019-09-17 14:22:06 +0100688 DataType::Float16,
689 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000690 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000691 DataType::QAsymmU8,
692 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100693 };
694
695 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
696 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
697 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
698}
699
telsoa014fcda012018-03-09 14:13:49 +0000700void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
701{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100702 const std::string descriptorName{"SplitterQueueDescriptor"};
703
704 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000705
Ruomei Yan25339c32019-05-28 16:48:20 +0100706 // Check the supported data types
707 std::vector<DataType> supportedTypes =
708 {
James Conroyd47a0642019-09-17 14:22:06 +0100709 DataType::Float32,
710 DataType::Float16,
711 DataType::Boolean,
712 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000713 DataType::QAsymmU8,
714 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100715 };
716
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100717 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
718 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100719 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100720 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
721 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
722
723 const std::string outputName = "output_" + std::to_string(i);
724 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100725 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100726
telsoa014fcda012018-03-09 14:13:49 +0000727 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
728 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100729 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000730 }
731
732 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
733 {
734 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100735 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000736 "has to match number of workloadInfo.m_OutputTensorInfos. "
737 "Number of windows: " +
738 to_string(m_ViewOrigins.size()) +
739 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
740 }
741
telsoa01c577f2c2018-08-31 09:22:23 +0100742 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000743 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
744 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
745 {
telsoa01c577f2c2018-08-31 09:22:23 +0100746 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000747 ViewOrigin const& e = m_ViewOrigins[w];
748 if (e.m_Origin.size() != inputDims)
749 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100750 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000751 "have the same dimensionality as the input tensor. "
752 "Window origin (index: " +
753 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
754 " dimensions, the input "
755 "tensor has " +
756 to_string(inputDims) + " dimensions.");
757 }
758 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
759 {
760 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
761 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
762 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100763 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000764 "be smaller or equal than the size of the input in that coord.");
765 }
766 }
767 }
768}
769
Jim Flynne242f2d2019-05-22 14:24:13 +0100770void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000771{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100772 const std::string descriptorName{"ConcatQueueDescriptor"};
773
774 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000775
776 if (m_Inputs.size() <= 0)
777 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000779 }
780 if (m_Outputs.size() <= 0)
781 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100782 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000783 }
784
785 if (workloadInfo.m_InputTensorInfos.size() <= 0)
786 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100787 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000788 }
789 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
790 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100791 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000792 }
793
Nikhil Raj8599a412018-11-19 14:51:07 +0000794 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
795 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100796 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000797 }
798
799 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
800 {
801 return;
802 }
803
telsoa014fcda012018-03-09 14:13:49 +0000804 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
805 {
806 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100807 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000808 "has to match number of workloadInfo.m_InputTensorInfos. "
809 "Number of windows: " +
810 to_string(m_ViewOrigins.size()) +
811 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
812 }
813
telsoa01c577f2c2018-08-31 09:22:23 +0100814 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000815 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
816 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
817 {
telsoa01c577f2c2018-08-31 09:22:23 +0100818 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000819 ViewOrigin const& e = m_ViewOrigins[w];
820 if (e.m_Origin.size() != outputDims)
821 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100822 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000823 "have the same dimensionality as the output tensor. "
824 "Window origin (index: " +
825 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
826 " dimensions, the output "
827 "tensor has " +
828 to_string(outputDims) + " dimensions.");
829 }
telsoa01c577f2c2018-08-31 09:22:23 +0100830 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000831 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
832 {
833 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
834 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
835 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100836 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000837 "be smaller or equal than the size of the output in that coord.");
838 }
839 }
840 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100841
842 // Check the supported data types
843 std::vector<DataType> supportedTypes =
844 {
James Conroyd47a0642019-09-17 14:22:06 +0100845 DataType::Float32,
846 DataType::Float16,
847 DataType::Boolean,
848 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000849 DataType::QAsymmU8,
850 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100851 };
852
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100853 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
854 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100855 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100856 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
857 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
858
859 const std::string inputName = "input_" + std::to_string(i);
860 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100861 }
telsoa014fcda012018-03-09 14:13:49 +0000862}
863
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100864void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
865{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100866 const std::string descriptorName{"StackQueueDescriptor"};
867
868 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100869
870 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
871 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100873 }
874
875 // All inputs must have the same shape, which is defined in parameters
876 const TensorShape& inputShape = m_Parameters.m_InputShape;
877 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
878 {
879 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
880 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100881 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100882 }
883 }
884
Matthew Jacksondba634f2019-08-15 15:14:18 +0100885 if (inputShape.GetNumDimensions() > 4)
886 {
887 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
888 }
889
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100890 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
891 // since the output tensor has an additional dimension.
892 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
893 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100894 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100895 "than the number of input dimensions.");
896 }
897
898 // Output shape must be as inferred from the input shape
899 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
900 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
901 {
902 if (outputShape[i] != inputShape[i])
903 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100904 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100905 "match shape inferred from input tensor.");
906 }
907 }
908
909 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
910 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100911 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100912 "match shape inferred from input tensor.");
913 }
914
915 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
916 {
917 if (outputShape[i] != inputShape[i-1])
918 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100919 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100920 "match shape inferred from input tensor.");
921 }
922 }
923
Matthew Jacksondba634f2019-08-15 15:14:18 +0100924 if (outputShape.GetNumDimensions() > 5)
925 {
926 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
927 }
928
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100929 // Check the supported data types
930 std::vector<DataType> supportedTypes =
931 {
James Conroyd47a0642019-09-17 14:22:06 +0100932 DataType::Float32,
933 DataType::Float16,
934 DataType::Boolean,
935 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000936 DataType::QAsymmU8,
937 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100938 };
939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100940 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100941
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100942 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100943 {
944 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
945 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100946 descriptorName,
947 "input_0",
948 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100949 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100951 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
952 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 descriptorName,
954 "input_0",
955 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100956}
957
telsoa014fcda012018-03-09 14:13:49 +0000958void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
959{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100960 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000961
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100962 ValidateNumInputs(workloadInfo, descriptorName, 1);
963 ValidateNumOutputs(workloadInfo, descriptorName, 1);
964
965 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
966 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
967
968 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
969
970 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000971 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100972 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000973 }
974
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100975 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000976
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100977 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
978 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000979
980 if (m_Parameters.m_BiasEnabled)
981 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100982 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000983
telsoa01c577f2c2018-08-31 09:22:23 +0100984 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100985 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
986 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000987
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100988 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
989 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000990 }
991
Francis Murtagh46c09d02019-05-28 08:15:28 +0100992 // Check the supported data types
993 std::vector<DataType> supportedTypes =
994 {
James Conroyd47a0642019-09-17 14:22:06 +0100995 DataType::Float32,
996 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +0000997 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000998 DataType::QAsymmU8,
999 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001000 };
1001
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001002 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1003 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001004}
1005
telsoa014fcda012018-03-09 14:13:49 +00001006void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1007{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001008 const std::string descriptorName{"NormalizationQueueDescriptor"};
1009
1010 ValidateNumInputs(workloadInfo, descriptorName, 1);
1011 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1012
1013 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1014 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001015
1016 // Check the supported data types
1017 std::vector<DataType> supportedTypes =
1018 {
1019 DataType::Float16,
1020 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001021 DataType::QAsymmU8,
1022 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001023 };
1024
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001025 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001026
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001027 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001028
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001029 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001030}
1031
1032void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1033{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001035
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001036 ValidateNumInputs(workloadInfo, descriptorName, 2);
1037 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1038
1039 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1040 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1041 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1042
1043 std::vector<DataType> supportedTypes =
1044 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001045 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001046 DataType::Float16,
1047 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001048 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001049 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001050 };
1051
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001052 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1053 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1054 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001055
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001056 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1057 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001058
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001059 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1060 inputTensorInfo1,
1061 outputTensorInfo,
1062 descriptorName,
1063 "input_0",
1064 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001065}
1066
telsoa014fcda012018-03-09 14:13:49 +00001067void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1068{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001069 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001070
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001071 ValidateNumInputs(workloadInfo, descriptorName, 2);
1072 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1073
1074 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1075 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1076 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1077
1078 std::vector<DataType> supportedTypes =
1079 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001080 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001081 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001082 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001083 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001084 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001085 };
1086
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001087 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1088 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1089 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001090
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001091 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1092 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001093
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001094 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1095 inputTensorInfo1,
1096 outputTensorInfo,
1097 descriptorName,
1098 "input_0",
1099 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001100}
1101
1102void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1103{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001104 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001105
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001106 ValidateNumInputs(workloadInfo, descriptorName, 1);
1107 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1108
1109 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1110 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001111
1112 std::vector<DataType> supportedTypes =
1113 {
1114 DataType::Float16,
1115 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001116 DataType::QAsymmU8,
1117 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001118 };
1119
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001120 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1121 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001122
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001124 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001125
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 ValidatePointer(m_Mean, descriptorName, "mean");
1127 ValidatePointer(m_Variance, descriptorName, "variance");
1128 ValidatePointer(m_Beta, descriptorName, "beta");
1129 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001130
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001131 const TensorInfo& mean = m_Mean->GetTensorInfo();
1132 const TensorInfo& variance = m_Variance->GetTensorInfo();
1133 const TensorInfo& beta = m_Beta->GetTensorInfo();
1134 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1137 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1138 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1139 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001140
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001141 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1142 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1143 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001144}
1145
1146void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1147{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 ValidateNumInputs(workloadInfo, descriptorName, 1);
1151 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001152
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001153 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1154 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001155
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001156 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1157 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001158
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001159 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001160
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001161 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1162 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001163
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001164 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001165
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001166 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001167 if (m_Parameters.m_BiasEnabled)
1168 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001170
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001171 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1172 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001173
1174 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1175 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001176 }
1177
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001178 ValidatePerAxisQuantization(inputTensorInfo,
1179 outputTensorInfo,
1180 weightTensorInfo,
1181 optionalBiasTensorInfo,
1182 descriptorName);
1183
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001184 std::vector<DataType> supportedTypes =
1185 {
Ruomei Yan88d44b82019-05-23 14:29:06 +01001186 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001187 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001188 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001189 DataType::QSymmS16,
Keith Davis5204aa82020-01-27 15:24:59 +00001190 DataType::QSymmS8,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001191 DataType::Float16
1192 };
1193
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001194 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1195 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1196}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001197
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001198void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1199{
1200 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1201
1202 ValidateNumInputs(workloadInfo, descriptorName, 1);
1203 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1204
1205 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1206 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1207
1208 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1209 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1210
1211 ValidatePointer(m_Weight, descriptorName, "weight");
1212
1213 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1214 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1215
1216 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1217 {
1218 throw InvalidArgumentException(
1219 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1220 "cannot be smaller than 1.") % descriptorName %
1221 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1222 }
1223
1224 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1225
1226 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1227 // inputChannels * channelMultiplier should be equal to outputChannels.
1228 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1229 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1230 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1231 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1232 {
1233 throw InvalidArgumentException(
1234 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1235 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1236 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1237 numWeightInputChannels % numWeightChannelMultiplier));
1238 }
1239
Teresa Charlind8df0262019-11-11 12:28:15 +00001240 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001241
Teresa Charlind8df0262019-11-11 12:28:15 +00001242 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001243 if (m_Parameters.m_BiasEnabled)
1244 {
1245 ValidatePointer(m_Bias, descriptorName, "bias");
1246
Teresa Charlind8df0262019-11-11 12:28:15 +00001247 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1248 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001249
1250 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1251 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1252 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001253 ValidatePerAxisQuantization(inputTensorInfo,
1254 outputTensorInfo,
1255 weightTensorInfo,
1256 optionalBiasTensorInfo,
1257 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001258
1259 std::vector<DataType> supportedTypes =
1260 {
1261 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001262 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001263 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001264 DataType::QSymmS16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001265 DataType::Float16
1266 };
1267
1268 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1269 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001270}
1271
1272void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1273{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274 const std::string descriptorName{"PermuteQueueDescriptor"};
1275
1276 ValidateNumInputs(workloadInfo, descriptorName, 1);
1277 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001278
1279 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1280
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001281 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1282 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001284 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1285 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001287 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001288 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001289 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001290 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001291 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1292 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1293 "must match dst dimension " + to_string(mapping[i]) +
1294 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001295 }
1296 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297
1298 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001299}
1300
1301void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1302{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001303 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001304
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001305 ValidateNumInputs(workloadInfo, descriptorName, 1);
1306 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1307
1308 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1309 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1310
1311 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1312 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001313
1314 std::vector<DataType> supportedTypes =
1315 {
1316 DataType::Float32,
1317 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001318 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001319 DataType::QAsymmU8,
1320 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001321 };
1322
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001323 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1324 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001325}
1326
1327void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1328{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001329 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001330
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001331 ValidateNumInputs(workloadInfo, descriptorName, 1);
1332 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1333
1334 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1335 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1336
1337 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1338 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001339
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001340 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001341 {
1342 DataType::Float16,
1343 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001344 DataType::QAsymmU8,
1345 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001346 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001347
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001348 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1349 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001350
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001351 // ResizeBilinear only changes width and height: batch and channel count must match.
1352 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1353 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001354 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001355 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001356 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001357 boost::str(boost::format("%1%: Input batch size (%2%) "
1358 "does not match output batch size (%3%)") %
1359 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001360 }
1361
Teresa Charlin970f43b2019-07-01 13:51:07 +01001362 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001363 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1364 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001365 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001366 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001367 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001368 boost::str(boost::format("%1%: Input channel count (%2%) "
1369 "does not match output channel count (%3%)") %
1370 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001371 }
1372}
1373
1374void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1375{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001376 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001377
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001378 ValidateNumInputs(workloadInfo, descriptorName, 1);
1379 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1380
1381 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1382 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1383
1384 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1385 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001386
1387 std::vector<DataType> supportedTypes =
1388 {
1389 DataType::Float16,
1390 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001391 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001392 DataType::QAsymmU8,
1393 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001394 };
1395
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001396 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1397 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001398
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001399 // Resize only changes width and height: batch and channel count must match.
1400 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1401 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001402 if (inputBatchSize != outputBatchSize)
1403 {
1404 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001405 boost::str(boost::format("%1%: Input batch size (%2%) "
1406 "does not match output batch size (%3%)") %
1407 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001408 }
1409
1410 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001411 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1412 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001413 if (inputChannelCount != outputChannelCount)
1414 {
1415 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001416 boost::str(boost::format("%1%: Input channel count (%2%) "
1417 "does not match output channel count (%3%)") %
1418 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001419 }
1420}
1421
1422void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1423{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001424 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001425
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001426 ValidateNumInputs(workloadInfo, descriptorName, 1);
1427 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1428
1429 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1430 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1431
1432 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1433 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1434
1435 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1436
telsoa014fcda012018-03-09 14:13:49 +00001437 if (m_Parameters.m_Min > m_Parameters.m_Max)
1438 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001439 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001440 }
telsoa014fcda012018-03-09 14:13:49 +00001441}
1442
Kevin Mayce5045a2019-10-02 14:07:47 +01001443void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1444{
1445 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1446
1447 ValidateNumInputs(workloadInfo, descriptorName, 1);
1448 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1449
1450 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1451 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1452
1453 if (inputTensorInfo.GetNumDimensions() > 4)
1454 {
1455 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1456 }
1457
1458 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1459
1460 // Check the supported data types
1461 std::vector<DataType> supportedTypes =
1462 {
1463 DataType::Float32,
1464 DataType::Float16
1465 };
1466
1467 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001468 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001469}
1470
telsoa014fcda012018-03-09 14:13:49 +00001471void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1472{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001473 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001475 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001476 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001478 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1479 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1480
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001481 if (inputTensorInfo.GetNumDimensions() > 4)
1482 {
1483 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1484 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001485
1486 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001487
1488 // Check the supported data types
1489 std::vector<DataType> supportedTypes =
1490 {
1491 DataType::Float32,
1492 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001493 DataType::QAsymmU8,
1494 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001495 };
1496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001497 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001498 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1499}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001500
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001501void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1502{
1503 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1504
1505 ValidateNumInputs(workloadInfo, descriptorName, 1);
1506 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1507
1508 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1509 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1510
1511 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1512
1513 std::vector<DataType> supportedTypes =
1514 {
1515 DataType::Float32,
1516 DataType::Float16,
1517 };
1518
1519 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001520 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001521}
1522
1523void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1524{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 const std::string descriptorName{"ConstantQueueDescriptor"};
1526
1527 ValidateNumInputs(workloadInfo, descriptorName, 0);
1528 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001529
1530 if (!m_LayerOutput)
1531 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001532 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001533 }
1534
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001535 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1536 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001537
1538 // Check the supported data types
1539 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001540 {
1541 DataType::Float32,
1542 DataType::Float16,
1543 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001544 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001545 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +00001546 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001547 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001548 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001549
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001550 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001551}
1552
1553void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1554{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001555 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001556
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001557 ValidateNumInputs(workloadInfo, descriptorName, 1);
1558 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1559
1560 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1561 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1562
1563 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001564
1565 // Check the supported data types
1566 std::vector<DataType> supportedTypes =
1567 {
1568 DataType::Float32,
1569 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001570 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001571 DataType::QSymmS16,
1572 DataType::QAsymmS8,
Keith Davis67e6c542020-02-19 10:08:33 +00001573 DataType::QAsymmU8
Nina Drozd2f2778f2019-05-27 10:37:05 +01001574 };
1575
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001576 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1577 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001578}
1579
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001580void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1581{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001582 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001583
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001584 ValidateNumInputs(workloadInfo, descriptorName, 1);
1585 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1586
1587 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1588 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1589
1590 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1591 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001592
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001593 if (m_Parameters.m_BlockShape.size() != 2)
1594 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001595 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001596 }
1597
1598 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1599 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001600 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1601 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001602 }
1603
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001604 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001605
1606 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001607 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001608
Matthew Bentham8800c002018-11-19 13:19:28 +00001609 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001610
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001611 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1612 widthPad.first + widthPad.second;
1613 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1614 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001615
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001616 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1617 inputShape[dimensionIndices.GetChannelsIndex()];
1618 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001619
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001620 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001621 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001622 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001623 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001624 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001625 }
1626
1627 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001628 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001629 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1630 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001631 }
nikraj01120522a2019-05-31 11:33:07 +01001632
1633 std::vector<DataType> supportedTypes =
1634 {
1635 DataType::Float16,
1636 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001637 DataType::QAsymmU8,
1638 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001639 };
1640
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001641 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1642 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001643}
1644
Keith Davisa57eccb2019-06-14 17:33:22 +01001645void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1646{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001647 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001648
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001649 ValidateNumInputs(workloadInfo, descriptorName, 1);
1650 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001651
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001652 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1653 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1654
1655 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1656 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001657
1658 std::vector<DataType> supportedTypes =
1659 {
1660 DataType::Float32,
1661 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001662 DataType::QAsymmU8,
1663 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001664 };
1665
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001666 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1667 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001668
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001669 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1670
1671 if (m_Parameters.m_BlockSize == 0)
1672 {
1673 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1674 }
1675
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1677 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1678 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1679 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001680
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001681 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001682 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001683 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001684 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1685 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001686 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001687
1688 const TensorShape& outputShape = outputTensorInfo.GetShape();
1689 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1690 {
1691 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1692 "must be divisible by the square of block size." );
1693 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001694}
1695
telsoa014fcda012018-03-09 14:13:49 +00001696void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1697{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001698 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001699
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001700 ValidateNumInputs(workloadInfo, descriptorName, 1);
1701 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1702
1703 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1704 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001705
1706 std::vector<DataType> supportedTypes =
1707 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001708 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001709 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001710 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001711 };
1712
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001713 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001714
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001715 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001716 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001717 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001718 }
1719}
1720
telsoa01c577f2c2018-08-31 09:22:23 +01001721void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1722{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001723 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1724
1725 const std::string descriptorName{"LstmQueueDescriptor"};
1726
1727 // check dimensions of all inputs and outputs
1728 if (workloadInfo.m_InputTensorInfos.size() != 3)
1729 {
1730 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1731 }
1732 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1733 {
1734 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1735 }
1736
1737 std::vector<DataType> supportedTypes =
1738 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001739 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001740 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001741 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001742 };
1743
Jan Eilers38e05bd2019-06-26 13:10:09 +01001744 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001745 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1746
Jan Eilers38e05bd2019-06-26 13:10:09 +01001747 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001748 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001749 {
1750 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1751 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001752 descriptorName,
1753 "input_0",
1754 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001755 }
1756 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001757 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001758 {
1759 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1760 workloadInfo.m_OutputTensorInfos[i],
1761 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001762 "input_0",
1763 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001764 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001765
janeil0117d8d852019-11-15 15:00:16 +00001766 // Making sure clipping parameters have valid values.
1767 // == 0 means no clipping
1768 // > 0 means clipping
1769 if (m_Parameters.m_ClippingThresCell < 0.0f)
1770 {
1771 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1772 }
1773 if (m_Parameters.m_ClippingThresProj < 0.0f)
1774 {
1775 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1776 }
1777
Jan Eilers38e05bd2019-06-26 13:10:09 +01001778
1779 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001780 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1781 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1782 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1783 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1784 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1785 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1786
Jan Eilers38e05bd2019-06-26 13:10:09 +01001787 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001788 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1789 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001790 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001791 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1792 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001793 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001794 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1795 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001796 // scratchBufferTensor
1797 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001798 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1799 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001800 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001801 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1802 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001803 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001804 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1805 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001806 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001807 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1808 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001809
1810
1811 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1812 if ( m_InputToInputWeights )
1813 {
1814 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1815 (n_cell * n_input), "InputLayerNormWeights");
1816 }
1817
1818 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1819 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1820 (n_cell * n_input), "InputToForgetWeights");
1821
1822 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1823 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1824 (n_cell * n_input), "InputToCellWeights");
1825
1826 if ( m_RecurrentToInputWeights )
1827 {
1828 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1829 (n_cell * n_output), "RecurrentToInputWeights");
1830 }
1831
1832 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1833 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1834 (n_cell * n_output), "RecurrentToForgetWeights");
1835
1836 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1837 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1838 (n_cell * n_output), "RecurrentToCellWeights");
1839
1840 // Make sure the input-gate's parameters are either both present (regular
1841 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1842 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1843 !m_Parameters.m_CifgEnabled) ||
1844 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1845 m_Parameters.m_CifgEnabled));
1846 if (!cifg_weights_all_or_none)
1847 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001848 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1849 "RecurrentToInputWeights must either both be present (regular LSTM) "
1850 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1851 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001852 }
1853
1854 if ( m_CellToInputWeights )
1855 {
1856 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1857 n_cell, "CellToInputWeights");
1858 }
1859 if ( m_CellToForgetWeights )
1860 {
1861 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1862 n_cell, "CellToForgetWeights");
1863 }
1864 if ( m_CellToOutputWeights )
1865 {
1866 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1867 n_cell, "CellToOutputWeights");
1868 }
1869
1870 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1871 bool peephole_weights_all_or_none =
1872 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1873 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1874 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1875 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1876 if (!peephole_weights_all_or_none)
1877 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001878 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001879 }
1880
1881 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1882 if (m_Parameters.m_CifgEnabled)
1883 {
1884 if (m_InputGateBias)
1885 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001886 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001887 }
1888 }
1889 else
1890 {
1891 if (!m_InputGateBias)
1892 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1894 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001895 }
1896 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1897 n_cell, "InputGateBias");
1898 }
1899
1900 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1901 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1902
1903 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1904 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1905
1906 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1907 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1908
1909 if (m_ProjectionWeights)
1910 {
1911 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1912 (n_cell * n_output), "ProjectionWeights");
1913 }
1914 if (m_ProjectionBias)
1915 {
1916 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1917 }
1918
1919 // Making sure the projection tensors are consistent:
1920 // 1) If projection weight is not present, then projection bias should not be
1921 // present.
1922 // 2) If projection weight is present, then projection bias is optional.
1923 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1924 !m_Parameters.m_ProjectionEnabled)
1925 || (m_ProjectionWeights && !m_ProjectionBias &&
1926 m_Parameters.m_ProjectionEnabled)
1927 || (m_ProjectionWeights && m_ProjectionBias &&
1928 m_Parameters.m_ProjectionEnabled));
1929 if (!projecton_tensors_consistent)
1930 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001931 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001932 }
1933
1934 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1935 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1936 // either all have values or none of them have values. Layer normalization is used when the values of all the
1937 // layer normalization weights are present
1938 if (m_InputLayerNormWeights)
1939 {
1940 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1941 }
1942 if (m_ForgetLayerNormWeights)
1943 {
1944 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1945 }
1946 if (m_CellLayerNormWeights)
1947 {
1948 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1949 }
1950 if (m_OutputLayerNormWeights)
1951 {
1952 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1953 }
1954
Jan Eilers38e05bd2019-06-26 13:10:09 +01001955 if (m_Parameters.m_LayerNormEnabled)
1956 {
1957 if (!m_Parameters.m_CifgEnabled)
1958 {
1959 if (!m_InputLayerNormWeights)
1960 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001961 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1962 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001963 }
1964 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1965 1, n_cell, "InputLayerNormWeights");
1966 }
1967 else if (m_InputLayerNormWeights)
1968 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001969 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1970 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001971 }
1972
1973 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1974 "ForgetLayerNormWeights");
1975 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1976
1977 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1978 "OutputLayerNormWeights");
1979 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1980
1981 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1982 "CellLayerNormWeights");
1983 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1984 }
1985 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1986 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001987 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1988 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001989 }
telsoa01c577f2c2018-08-31 09:22:23 +01001990}
1991
1992void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1993{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001994 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 ValidateNumInputs(workloadInfo, descriptorName, 1);
1997 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1998
1999 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2000 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2001
2002 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002003 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002005 }
2006
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002008 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002009 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002010 }
2011
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002012 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002013}
2014
2015void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2016{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002017 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002018
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002019 ValidateNumInputs(workloadInfo, descriptorName, 1);
2020 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2021
2022 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2023 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2024
2025 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002026 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002027 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002028 }
2029
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002030 if (outputTensorInfo.GetDataType() != DataType::Float32)
2031 {
2032 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2033 }
2034
2035 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002036}
2037
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002038void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2039{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002040 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002041
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002042 ValidateNumInputs(workloadInfo, descriptorName, 2);
2043 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2044
2045 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2046 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2047 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2048
2049 std::vector<DataType> supportedTypes =
2050 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002051 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002052 DataType::QAsymmU8,
2053 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002054 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002055 };
2056
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002057 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2058 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2059 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002060
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002061 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2062 inputTensorInfo1,
2063 outputTensorInfo,
2064 descriptorName,
2065 "input_0",
2066 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002067}
2068
David Beckc2044fe2018-09-05 15:00:38 +01002069void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2070{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002071 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002072
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002073 ValidateNumInputs(workloadInfo, descriptorName, 2);
2074 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2075
2076 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2077 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2078 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2079
2080 std::vector<DataType> supportedTypes =
2081 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002082 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002083 DataType::QAsymmU8,
2084 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002085 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002086 };
2087
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002088 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2089 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2090 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002091
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002092 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2093 inputTensorInfo1,
2094 outputTensorInfo,
2095 descriptorName,
2096 "input_0",
2097 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002098}
2099
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002100void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2101{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002102 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002103
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002104 ValidateNumInputs(workloadInfo, descriptorName, 2);
2105 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2106
2107 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2108 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2109 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2110
2111 std::vector<DataType> supportedTypes =
2112 {
Mike Kelly1da02362019-08-01 08:43:57 +01002113 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002114 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002115 DataType::Signed32,
Keith Davis67e6c542020-02-19 10:08:33 +00002116 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002117 DataType::QAsymmU8,
2118 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002119 };
2120
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002121 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2122 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2123 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002125 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2126 inputTensorInfo1,
2127 outputTensorInfo,
2128 descriptorName,
2129 "input_0",
2130 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002131}
2132
narpra01a6bf9122018-09-10 09:50:09 +01002133void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2134{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002135 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002136
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002137 ValidateNumInputs(workloadInfo, descriptorName, 1);
2138 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2139
2140 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2141 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002142
2143 std::vector<DataType> supportedTypes =
2144 {
2145 DataType::Float32,
2146 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002147 DataType::QAsymmU8,
2148 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002149 };
narpra01eb061912018-09-10 17:35:27 +01002150
James Conroy4d1ff582019-06-10 17:06:39 +01002151 // First check if input tensor data type is supported, then
2152 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002153 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2154 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002155
narpra0132b90462018-09-13 11:07:48 +01002156 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002157 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002158 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002159 }
narpra0132b90462018-09-13 11:07:48 +01002160 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002161 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002162 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002163 }
2164 else
2165 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002166 unsigned int outputDim =
2167 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2168 ValidateTensorNumDimensions(outputTensorInfo,
2169 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002170 outputDim > 0 ? outputDim : 1,
2171 "output");
2172 }
narpra01a6bf9122018-09-10 09:50:09 +01002173}
2174
jimfly012c9322a2018-09-19 10:59:49 +01002175void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2176{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002177 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002178
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002179 ValidateNumInputs(workloadInfo, descriptorName, 1);
2180 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2181
2182 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2183 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002184
jimfly012c9322a2018-09-19 10:59:49 +01002185 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002186 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2187
jimfly012c9322a2018-09-19 10:59:49 +01002188 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002189 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2190 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2191 "as there are dimensions in the input tensor that is " +
2192 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2193 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002194 }
2195}
2196
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002197void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2198{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002199 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002200
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002201 ValidateNumInputs(workloadInfo, descriptorName, 1);
2202 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002203
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002204 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2205 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2206
Sadik Armagan2208b602019-07-31 16:36:27 +01002207 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002208 {
James Conroyd47a0642019-09-17 14:22:06 +01002209 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002210 DataType::Float16,
2211 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002212 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002213 DataType::QAsymmU8,
2214 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002215 };
2216
2217 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002218
Keith Davis0c2eeac2020-02-11 16:51:50 +00002219 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002220 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002221 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002222 }
2223}
2224
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002225void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2226{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002227 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002228
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002229 ValidateNumInputs(workloadInfo, descriptorName, 1);
2230 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002231
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002232 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2233 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002234
2235 std::vector<DataType> supportedTypes =
2236 {
James Conroyd47a0642019-09-17 14:22:06 +01002237 DataType::Float32,
2238 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002239 DataType::QAsymmU8,
2240 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002241 };
2242
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2244 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002245}
2246
Conor Kennedy430b5d82018-11-14 15:28:28 +00002247void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2248{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002249 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002250
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002251 ValidateNumInputs(workloadInfo, descriptorName, 1);
2252 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2253
2254 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2255 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002256
2257 std::vector<DataType> supportedTypes =
2258 {
2259 DataType::Float16,
2260 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002261 DataType::QAsymmU8,
2262 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002263 };
2264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002265 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2266 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002267
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002268 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002269
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002270 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002271 if (rank > 4)
2272 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002273 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002274 }
2275
Conor Kennedy430b5d82018-11-14 15:28:28 +00002276 // Begin, End & Stride length must be of rank(input0)
2277 if (m_Parameters.m_Begin.size() != rank)
2278 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002279 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002280 }
2281
2282 if (m_Parameters.m_End.size() != rank)
2283 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002285 }
2286
2287 if (m_Parameters.m_Stride.size() != rank)
2288 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002289 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002290 }
2291
2292 // Stride entries must be non-zero
2293 for (auto& stride : m_Parameters.m_Stride)
2294 {
2295 if (stride == 0)
2296 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002297 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002298 }
2299 }
2300}
2301
kevmay0190539692018-11-29 08:40:19 +00002302void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2303{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002304 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002305
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 ValidateNumInputs(workloadInfo, descriptorName, 2);
2307 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2308
2309 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2310 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2312
2313 std::vector<DataType> supportedTypes =
2314 {
Mike Kelly1da02362019-08-01 08:43:57 +01002315 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002316 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002317 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002318 DataType::QAsymmU8,
2319 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002320 };
2321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002322 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2323 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2324 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002326 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2327 inputTensorInfo1,
2328 outputTensorInfo,
2329 descriptorName,
2330 "input_0",
2331 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002332}
2333
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002334void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2335{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002336 const std::string descriptorName{"DebugQueueDescriptor"};
2337
2338 ValidateNumInputs(workloadInfo, descriptorName, 1);
2339 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002340}
2341
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002342void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2343{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002344 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002345
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 ValidateNumInputs(workloadInfo, descriptorName, 2);
2347 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002348
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002349 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2350 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2351 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2352
2353 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2354 inputTensorInfo1,
2355 outputTensorInfo,
2356 descriptorName,
2357 "input_0",
2358 "input_1");
2359
2360 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002361 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002362 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002363 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002364}
2365
FrancisMurtagh878f0232018-12-19 10:56:15 +00002366void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2367{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002368 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002369
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002370 ValidateNumInputs(workloadInfo, descriptorName, 2);
2371 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002372
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002373 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2374 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2375 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2376
2377 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2378 inputTensorInfo1,
2379 outputTensorInfo,
2380 descriptorName,
2381 "input_0",
2382 "input_1");
2383
2384 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002385 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002386 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002387 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002388}
2389
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002390void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2391{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002392 const std::string descriptorName{"RsqrtQueueDescriptor"};
2393
2394 ValidateNumInputs(workloadInfo, descriptorName, 1);
2395 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2396
2397 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2398 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2399
2400 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002401
2402 std::vector<DataType> supportedTypes =
2403 {
James Conroyd47a0642019-09-17 14:22:06 +01002404 DataType::Float16,
2405 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002406 DataType::QAsymmU8,
2407 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002408 };
2409
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002410 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2411 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002412}
2413
narpra01b89b05f2019-01-16 09:53:09 +00002414void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2415{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002416 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002417
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002418 ValidateNumInputs(workloadInfo, descriptorName, 2);
2419 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002420
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002421 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2422 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002423 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002424 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002425 }
2426
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002427 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2428 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2429
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002430 std::vector<DataType> supportedTypes =
2431 {
James Conroyd47a0642019-09-17 14:22:06 +01002432 DataType::Float16,
2433 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002434 DataType::QAsymmU8,
2435 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002436 };
2437
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002439
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002440 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002441
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002442 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2443 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002444}
2445
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002446void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2447{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002448 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2449
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002450 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002451
2452 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2453 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002454 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002455 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2456 }
2457
2458 if (m_Anchors == nullptr)
2459 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002460 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002461 }
2462
2463 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002464 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2465 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2466
2467 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002468 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002469 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2470 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002471
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002472 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2473 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2474 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002475
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002476 const std::vector<DataType> supportedInputTypes =
2477 {
2478 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002479 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002480 DataType::QAsymmU8,
2481 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002482 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002483
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002484 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2485 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2486 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2487
2488 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2489 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2490 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2491 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2492
2493 // NOTE: Output is always Float32 regardless of input type
2494 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2495 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2496 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2497 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002498
2499 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2500 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002501 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002502 "must be positive and less than or equal to 1.");
2503 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002504
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002505 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2506 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002507 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002508 "should be equal to number of classes + 1.");
2509 }
2510}
2511
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002512void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2513{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002514 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002515
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002516 ValidateNumInputs(workloadInfo, descriptorName, 1);
2517 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2518
2519 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2520 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2521
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002522 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002523 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002524 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002525 }
2526
Sadik Armagan2208b602019-07-31 16:36:27 +01002527 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002528 {
James Conroyd47a0642019-09-17 14:22:06 +01002529 DataType::Float32,
2530 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002531 };
2532
2533 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002534}
2535
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002536void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2537{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002538 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002539
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002540 ValidateNumInputs(workloadInfo, descriptorName, 2);
2541 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002543 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2544 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2545 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002546
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002547 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2548 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2549
2550 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2551 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002552}
2553
Sadik Armaganeff363d2019-04-05 15:25:46 +01002554void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2555{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002556 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002557
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002558 ValidateNumInputs(workloadInfo, descriptorName, 2);
2559 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2560
2561 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2562 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2563
2564 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2565 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2566
2567 std::vector<DataType> supportedTypes =
2568 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002569 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002570 DataType::QAsymmU8,
2571 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002572 };
2573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002574 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2575 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002577 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2578 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002579
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 ValidateTensorShapesMatch(inputTensorInfo0,
2581 outputTensorInfo0,
2582 descriptorName,
2583 "input_0",
2584 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002585
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002586 ValidateTensorShapesMatch(inputTensorInfo0,
2587 outputTensorInfo1,
2588 descriptorName,
2589 "input_0",
2590 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002591}
2592
Derek Lamberti901ea112019-12-10 22:07:09 +00002593void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002594{
2595 // This is internally generated so it should not need validation.
2596}
2597
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002598void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2599{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002600 const std::string& descriptorName{"PreluQueueDescriptor"};
2601
2602 ValidateNumInputs(workloadInfo, descriptorName, 2);
2603 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2604
2605 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2606 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2607 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002608
2609 std::vector<DataType> supportedTypes
2610 {
2611 DataType::Float16,
2612 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002613 DataType::QAsymmU8,
2614 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002615 };
2616
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002617 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2618 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002619
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002620 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002621
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002622 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2623 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002625 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2626 alphaTensorInfo,
2627 outputTensorInfo,
2628 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002629 "input",
2630 "alpha");
2631}
2632
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002633void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2634{
2635 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2636
2637 ValidateNumInputs(workloadInfo, descriptorName, 1);
2638 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2639
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002640 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2641 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2642
2643 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2644 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002645
2646 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002647
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002648 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2649 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002650
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002651 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2652
2653 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002654 if (m_Parameters.m_BiasEnabled)
2655 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002656 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002657
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002658 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2659 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002660
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002661 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002662 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002663 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002664
2665 ValidatePerAxisQuantization(inputTensorInfo,
2666 outputTensorInfo,
2667 weightTensorInfo,
2668 optionalBiasTensorInfo,
2669 descriptorName);
2670
2671 std::vector<DataType> supportedTypes =
2672 {
2673 DataType::Float32,
2674 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002675 DataType::QAsymmU8,
2676 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002677 };
2678
2679 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2680 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002681}
2682
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002683void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2684{
2685 const std::string descriptorName{"TransposeQueueDescriptor"};
2686
2687 ValidateNumInputs(workloadInfo, descriptorName, 1);
2688 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2689
2690 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2691
2692 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2693 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2694
2695 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2696 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2697
2698 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2699 {
2700 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2701 {
2702 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2703 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2704 "must match dst dimension " + to_string(i) +
2705 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2706 }
2707 }
2708
2709 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2710}
2711
James Conroy9c3cae82019-08-01 16:01:48 +01002712void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2713{
2714 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2715
2716 // Validate number of inputs/outputs
2717 ValidateNumInputs(workloadInfo, descriptorName, 3);
2718 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2719
2720 // Input/output tensor infos
2721 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2722 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2723 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2724
2725 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2726 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2727
2728 std::vector<DataType> inputOutputSupportedTypes =
2729 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002730 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002731 };
2732
2733 std::vector<DataType> cellStateSupportedTypes =
2734 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002735 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01002736 };
2737
2738 std::vector<DataType> weightsSupportedTypes =
2739 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002740 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002741 };
2742
2743 std::vector<DataType> biasSupportedTypes =
2744 {
2745 DataType::Signed32
2746 };
2747
2748 // Validate types of input/output tensors
2749 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2750 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2751 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2752
2753 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2754 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2755
2756 // Validate matching types of input/output tensors
2757 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2758 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2759 "outputStateIn", "outputStateOut");
2760 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2761
2762 // Validate matching quantization info for input/output tensors
2763 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2764 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2765 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002766
James Conroy9c3cae82019-08-01 16:01:48 +01002767 // Infer number of batches, input size and output size from tensor dimensions
2768 const uint32_t numBatches = inputInfo.GetShape()[0];
2769 const uint32_t inputSize = inputInfo.GetShape()[1];
2770 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2771
2772 // Validate number of dimensions and number of elements for input/output tensors
2773 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2774 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2775 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2776 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2777 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2778
2779 // Validate number of dimensions and number of elements for weights tensors
2780 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2781 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2782 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2783
2784 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2785 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2786 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2787
2788 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2789 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2790 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2791
2792 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2793 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2794 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2795
2796 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2797 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2798 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2799
2800 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2801 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2802 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2803 " RecurrentToForgetWeights");
2804
2805 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2806 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2807 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2808
2809 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2810 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2811 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2812
2813 // Validate data types for weights tensors (all should match each other)
2814 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2815
2816 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2817 "inputToInputWeights", "inputToForgetWeights");
2818 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2819 "inputToInputWeights", "inputToCellWeights");
2820 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2821 "inputToInputWeights", "inputToOutputWeights");
2822
2823 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2824 "inputToInputWeights", "recurrentToInputWeights");
2825 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2826 "inputToInputWeights", "recurrentToForgeteights");
2827 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2828 "inputToInputWeights", "recurrentToCellWeights");
2829 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2830 "inputToInputWeights", "recurrentToOutputWeights");
2831
2832 // Validate matching quantization info for weight tensors (all should match each other)
2833 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2834 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2835 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2836 descriptorName, "inputToInputWeights", "inputToCellWeights");
2837 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2838 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2839
2840 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2841 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2842 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2843 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2844 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2845 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2846 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2847 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2848
2849 // Validate number of dimensions and number of elements in bias tensors
2850 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2851 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2852 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2853
2854 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2855 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2856 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2857
2858 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2859 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2860 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2861
2862 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2863 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2864 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2865
2866 // Validate data types for bias tensors (all should match each other)
2867 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2868
2869 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2870 "inputGateBias", "forgetGateBias");
2871 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2872 "inputGateBias", "cellBias");
2873 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2874 "inputGateBias", "outputGateBias");
2875
2876 // Validate bias tensor quantization info
2877 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2878 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2879 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2880 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2881}
2882
Kevin May868eb142019-09-04 17:29:31 +01002883void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2884{
2885 const std::string descriptorName{"AbsQueueDescriptor"};
2886
2887 ValidateNumInputs(workloadInfo, descriptorName, 1);
2888 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2889
2890 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2891 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2892
2893 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2894
2895 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002896 {
2897 DataType::Float16,
2898 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002899 DataType::QAsymmU8,
2900 DataType::QSymmS16
James Conroyd47a0642019-09-17 14:22:06 +01002901 };
Kevin May868eb142019-09-04 17:29:31 +01002902
2903 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2904 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2905}
2906
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002907void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2908{
2909 const std::string descriptorName{"SliceQueueDescriptor"};
2910
2911 ValidateNumInputs(workloadInfo, descriptorName, 1);
2912 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2913
2914 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2915 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2916
2917 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2918
2919 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2920 if (rank > 4)
2921 {
2922 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2923 }
2924
2925 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2926
2927 // Check if m_Begin and m_Size have the expected length
2928 if (m_Parameters.m_Begin.size() != rank)
2929 {
2930 throw InvalidArgumentException(descriptorName +
2931 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2932 }
2933 if (m_Parameters.m_Size.size() != rank)
2934 {
2935 throw InvalidArgumentException(descriptorName +
2936 ": Length of size descriptor must equal rank " + std::to_string(rank));
2937 }
2938
2939 // Check if the shape of the output tensor matches m_Size
2940 const TensorShape& outputShape = outputTensorInfo.GetShape();
2941 for (unsigned int i = 0u; i < rank; ++i)
2942 {
2943 if (m_Parameters.m_Size[i] != outputShape[i])
2944 {
2945 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2946 }
2947 }
2948
2949 // Check if the sum of begin offset and size in a given dimension
2950 // does not exceed the size of corresponding input
2951 const TensorShape& inputShape = inputTensorInfo.GetShape();
2952 for(unsigned int i = 0u; i < rank; ++i)
2953 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002954 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002955 {
2956 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2957 std::to_string(i) + " exceeds input size.");
2958 }
2959 }
2960}
2961
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002962void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2963{
2964 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2965
2966 ValidateNumInputs(workloadInfo, descriptorName, 1);
2967 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2968
2969 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2970 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2971
2972 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2973 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2974
2975 std::vector<DataType> supportedTypes =
2976 {
2977 DataType::Float32,
2978 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002979 DataType::QAsymmU8,
2980 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002981 };
2982
2983 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2984 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2985
2986 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2987
2988 if (m_Parameters.m_BlockSize == 0)
2989 {
2990 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2991 }
2992
2993 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2994 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2995 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2996 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2997
2998 const TensorShape& outputShape = outputInfo.GetShape();
2999 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3000 {
3001 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3002 "must be divisible by block size.");
3003 }
3004
3005 const TensorShape& inputShape = inputInfo.GetShape();
3006 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3007 {
3008 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3009 "must be divisible by the square of block size." );
3010 }
3011}
3012
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003013void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3014{
3015 const std::string descriptorName{"ComparisonQueueDescriptor"};
3016
3017 ValidateNumInputs(workloadInfo, descriptorName, 2);
3018 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3019
3020 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3021 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3022 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3023
3024 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3025 inputTensorInfo1,
3026 outputTensorInfo,
3027 descriptorName,
3028 "input_0",
3029 "input_1");
3030
3031 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3032 {
3033 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3034 }
3035}
3036
josh minor4a3c6102020-01-06 16:40:46 -06003037void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3038{
3039 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3040
3041 ValidateNumInputs(workloadInfo, descriptorName, 1);
3042 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3043
3044 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3045 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3046
3047 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3048
3049 std::vector<DataType> supportedTypes =
3050 {
3051 DataType::Float16,
3052 DataType::Float32,
3053 DataType::QAsymmU8,
3054 DataType::QSymmS16
3055 };
3056
3057 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3058 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3059}
3060
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003061} // namespace armnn