blob: 81aefa94e714350f4395a0c425a26c1fc007408b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +000029 case DataType::BFloat16:
30 return DataType::BFloat16;
telsoa01c577f2c2018-08-31 09:22:23 +010031 case DataType::Float16:
32 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000033 case DataType::Float32:
34 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000035 case DataType::QAsymmS8:
36 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000037 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000038 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
40 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000041 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010042 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000043 default:
44 BOOST_ASSERT_MSG(false, "Invalid input data type");
45 return DataType::Float32;
46 }
47}
48
49namespace
50{
51
52//---------------------------------------------------------------
// Formats any value that supports operator<< as a std::string.
// Used instead of std::to_string because the Android NDK does not provide std::to_string.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    return formatted.str();
}
61
62//---------------------------------------------------------------
63void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
64{
65 if (!ptr)
66 {
67 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
68 paramName + " parameter must be set.");
69 }
70}
71
72//---------------------------------------------------------------
73void ValidateTensorShapesMatch(const TensorInfo& first,
74 const TensorInfo& second,
75 std::string const& descName,
76 std::string const& firstName,
77 std::string const& secondName)
78{
79 if (first.GetShape() != second.GetShape())
80 {
81 throw InvalidArgumentException(descName + ": "
82 + firstName + " & " + secondName + " must have identical shapes");
83 }
84}
85
86//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010087void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088{
Sadik Armaganeff363d2019-04-05 15:25:46 +010089 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090 {
91 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010092 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000093 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
94 }
95}
96
97//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010098void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101 {
102 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100103 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000104 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
105 }
106}
107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000110 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100111 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000112 std::string const& tensorName)
113{
114 if (tensor.GetNumDimensions() != numDimensions)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
117 to_string(tensor.GetNumDimensions()) + " dimensions for " +
118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100123void ValidateTensorNumElements(const TensorInfo& tensor,
124 std::string const& descName,
125 unsigned int numElements,
126 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127{
128 if (tensor.GetNumElements() != numElements)
129 {
130 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100131 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100132 tensorName + " tensor.");
133 }
134}
135
136//---------------------------------------------------------------
137void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100138 unsigned int numDimension,
139 unsigned int numElements,
140 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100141{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100142 const std::string functionName{"ValidateTensorNumDimNumElem"};
143 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
144 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100145}
146
147//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000148void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
149 const std::string& descName, std::string const& tensorName)
150{
151 if (tensor.GetDataType() != dataType)
152 {
153 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
154 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
155 }
156}
157
Derek Lambertid466a542020-01-22 15:37:29 +0000158void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
159{
160 ARMNN_NO_DEPRECATE_WARN_BEGIN
161 if (tensor.GetDataType() != DataType::QSymmS8 &&
162 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
163 {
164 throw InvalidArgumentException(descName +
165 ": Expected data type which supports per-axis quantization scheme but got " +
166 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
167 }
168 ARMNN_NO_DEPRECATE_WARN_END
169}
170
telsoa014fcda012018-03-09 14:13:49 +0000171//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100172void ValidateTensorQuantizationSpace(const TensorInfo& first,
173 const TensorInfo& second,
174 const std::string& descName,
175 std::string const& firstName,
176 std::string const& secondName)
177{
178 if (!first.IsQuantized() ||
179 !second.IsQuantized())
180 {
181 // Not a quantized type, ignore the validation
182 return;
183 }
184
185 DataType firstDataType = first.GetDataType();
186 DataType secondDataType = second.GetDataType();
187
188 if (firstDataType != secondDataType)
189 {
190 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
191 " must be of the same quantized type, " +
192 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
193 secondName + " is " + GetDataTypeName(secondDataType));
194 }
195
196 if (!first.IsTypeSpaceMatch(second))
197 {
198 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
199 " must have the same quantization space, " +
200 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
201 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
202 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
203 " and scale " + to_string(second.GetQuantizationScale()));
204 }
205}
206
207//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100208void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
209 const TensorInfo& inputTensorInfo,
210 const TensorInfo& weightsTensorInfo,
211 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000212{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000213 // Helper lambda function to validate a single bias quantization scale value
214 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
215 {
ricbur013f4d7102019-10-31 16:22:18 +0000216 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000217 if (std::abs(biasScale - expectedScale) > tolerance)
218 {
219 // Print the float values with extra precision to see very small differences
220 std::stringstream msg;
221 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
222 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
223 biasScale;
224 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
225 }
226 };
227
telsoa014fcda012018-03-09 14:13:49 +0000228 if (biasTensor.GetQuantizationOffset() != 0)
229 {
230 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
231 to_string(biasTensor.GetQuantizationOffset()));
232 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000233
234 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000235 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000236 // Validate per-axis quantization scales
237 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
238 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
239
240 if (weightScales.size() != biasScales.size())
241 {
242 std::stringstream msg;
243 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
244 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
245 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
246 }
247
248 for (size_t i = 0ul; i < biasScales.size(); ++i)
249 {
250 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
251 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
252 }
253 }
254 else
255 {
256 // Validate per-tensor quantization scale
257 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
258 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000259 }
260}
261
262//---------------------------------------------------------------
263void ValidateTensors(const std::vector<ITensorHandle*>& vec,
264 unsigned int numExpected,
265 const std::string& descName,
266 const std::string& varName)
267{
268 if (vec.empty() && numExpected > 0)
269 {
270 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
271 }
272
273 for (unsigned int i = 0; i < numExpected; ++i)
274 {
275 if (!vec[i])
276 {
277 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
278 }
279 }
280}
281
282//---------------------------------------------------------------
283void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
284 const TensorInfo& second,
285 const TensorInfo& output,
286 std::string const& descName,
287 std::string const& firstName,
288 std::string const& secondName)
289{
290 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
291 // broadcasted.
292 if (first.GetNumDimensions() != second.GetNumDimensions())
293 {
294 throw InvalidArgumentException(descName + ": Tensors "
295 + firstName + " & " + secondName
296 + " must have the same number of dimensions in order to be broadcasted");
297 }
298 uint32_t numDims = first.GetNumDimensions();
299 std::vector<uint32_t> outputDims(numDims, 0u);
300 for (uint32_t i = 0; i < numDims; i++)
301 {
302 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
303 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
304 if (dimsNotEqual && dimsNotOne)
305 {
306 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
307 }
308 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
309 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100310 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000311 if (broadcastShape != output.GetShape())
312 {
313 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
314 + firstName + " & " + secondName
315 + " does not match the output shape");
316 }
317}
318
319//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100320void ValidateDataTypes(const TensorInfo& info,
321 const std::vector<armnn::DataType>& supportedTypes,
322 std::string const& descName)
323{
324 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
325 if (iterator == supportedTypes.end())
326 {
327 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
328 }
329}
330
James Conroy4d1ff582019-06-10 17:06:39 +0100331//---------------------------------------------------------------
332void ValidateTensorDataTypesMatch(const TensorInfo& first,
333 const TensorInfo& second,
334 std::string const& descName,
335 std::string const& firstName,
336 std::string const& secondName)
337{
338 if (first.GetDataType() != second.GetDataType())
339 {
340 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
341 " must have identical data types.");
342 }
343}
344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100345//---------------------------------------------------------------
346void ValidateTensorNumElementsMatch(const TensorInfo& first,
347 const TensorInfo& second,
348 std::string const& descName,
349 std::string const& firstName,
350 std::string const& secondName)
351{
352 if (first.GetNumElements() != second.GetNumElements())
353 {
354 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
355 " must have the same number of elements.");
356 }
357}
358
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000359void ValidateWeightDataType(const TensorInfo& inputInfo,
360 const TensorInfo& weightInfo,
361 const std::string& descName)
362{
363 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000364 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365 {
Derek Lambertid466a542020-01-22 15:37:29 +0000366 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000367 const std::vector<DataType> validTypes =
368 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000369 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000370 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000371 DataType::QSymmS8,
372 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000373 };
Derek Lambertid466a542020-01-22 15:37:29 +0000374 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000375
376 ValidateDataTypes(weightInfo, validTypes, descName);
377 }
378 else
379 {
380 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
381 }
382}
383
384void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
385 const std::string& descName,
386 const std::string& tensorName)
387{
388 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
389 if (!quantizationDim.has_value())
390 {
391 throw InvalidArgumentException(boost::str(
392 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
393 % descName % tensorName));
394 }
395
396 if (quantizationDim.value() != 0)
397 {
398 throw InvalidArgumentException(boost::str(
399 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
400 "but got: %3%") % descName % tensorName % quantizationDim.value()));
401 }
402}
403
404void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
405 const std::string& descName,
406 const std::string& tensorName)
407{
408 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
409 if (quantizationOffset != 0)
410 {
411 throw InvalidArgumentException(boost::str(
412 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
413 "but got: %3%") % descName % tensorName % quantizationOffset));
414 }
415}
416
417void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
418 const TensorInfo& outputInfo,
419 const TensorInfo& weightInfo,
420 const Optional<TensorInfo>& optionalBiasInfo,
421 const std::string& descName)
422{
423 if (weightInfo.HasPerAxisQuantization())
424 {
425 const DataType inputDataType = inputInfo.GetDataType();
426 const DataType outputDataType = outputInfo.GetDataType();
427
Keith Davis0c2eeac2020-02-11 16:51:50 +0000428 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000429
430 if (!canHavePerAxisQuantization)
431 {
432 throw InvalidArgumentException(boost::str(
433 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
434 "but data type does not support per-axis quantization.") % descName % "weight"));
435 }
436
Derek Lambertid466a542020-01-22 15:37:29 +0000437
438 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000439 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
440 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
441
442 if (optionalBiasInfo.has_value())
443 {
444 const TensorInfo& biasInfo = optionalBiasInfo.value();
445 if (!biasInfo.HasPerAxisQuantization())
446 {
447 throw InvalidArgumentException(boost::str(
448 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
449 "weight tensor.") % descName));
450 }
451
452 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
453 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
454 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
455 }
456 }
457}
458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100459} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000460
461void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
462 unsigned int numExpectedIn, unsigned int numExpectedOut) const
463{
464 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
465 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
466}
467
468//---------------------------------------------------------------
469void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
470{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100471 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000472
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100473 ValidateNumInputs(workloadInfo, descriptorName, 1);
474 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100476 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
477 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
478
479 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
480 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000481
482 if (m_Inputs.size() != m_Outputs.size())
483 {
484 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100485 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
486 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000487 }
488
489 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
490 {
491 if (!m_Inputs[i])
492 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100493 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
494 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000495 }
496
497 if (!m_Outputs[i])
498 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100499 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
500 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000501 }
502 }
503}
504
Derek Lambertif674aa02019-08-01 15:56:25 +0100505//---------------------------------------------------------------
506void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
507{
508 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
509 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
510
511 if (workloadInfo.m_InputTensorInfos.size() != 1)
512 {
513 throw InvalidArgumentException(boost::str(
514 boost::format("Number of input infos (%1%) is not 1.")
515 % workloadInfo.m_InputTensorInfos.size()));
516
517 }
518
519 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
520 {
521 throw InvalidArgumentException(boost::str(
522 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
523 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
524 }
525
526 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
527 {
528 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
529 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
530 {
531 throw InvalidArgumentException(boost::str(
532 boost::format("Number of elements for tensor input and output %1% does not match")
533 % i ));
534 }
535 }
536
537 if (m_Inputs.size() != 1)
538 {
539 throw InvalidArgumentException(boost::str(
540 boost::format("Number of inputs (%1%) is not 1.")
541 % m_Inputs.size()));
542 }
543
544 if (m_Inputs.size() != m_Outputs.size())
545 {
546 throw InvalidArgumentException(boost::str(
547 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
548 % m_Inputs.size() % m_Outputs.size()));
549 }
550
551 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
552 {
553 if (!m_Inputs[i])
554 {
555 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
556 }
557
558 if (!m_Outputs[i])
559 {
560 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
561 }
562 }
563}
564
565//---------------------------------------------------------------
566void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
567{
568 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
569 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
570
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 if (m_Inputs.size() != 1)
572 {
573 throw InvalidArgumentException(boost::str(
574 boost::format("Number of inputs (%1%) is not 1.")
575 % m_Inputs.size()));
576 }
577
578 if (m_Outputs.size() != 0)
579 {
580 throw InvalidArgumentException(boost::str(
581 boost::format("Number of outputs (%1%) is not 0.")
582 % m_Inputs.size() % m_Outputs.size()));
583 }
584
585 if (!m_Inputs[0])
586 {
587 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
588 }
589}
590
591//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000592void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
593{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100594 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100595
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100596 ValidateNumInputs(workloadInfo, descriptorName, 1);
597 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100599 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
600 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100601
602 std::vector<DataType> supportedTypes =
603 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000604 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100605 DataType::Float16,
606 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000607 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000608 DataType::QAsymmU8,
609 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100610 };
611
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100612 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
613 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
614 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000615}
616
Nikhil Rajee391d52019-09-05 17:50:44 +0100617void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
618{
619 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
620
621 ValidateNumInputs(workloadInfo, descriptorName, 1);
622 ValidateNumOutputs(workloadInfo, descriptorName, 1);
623
624 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
625 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
626
Nikhil Raj68c2c902019-09-19 11:21:11 +0100627 if (outputTensorInfo.GetDataType() != DataType::Signed32)
628 {
629 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
630 }
631
James Conroyd47a0642019-09-17 14:22:06 +0100632 std::vector<DataType> supportedInputTypes =
633 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000634 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100635 DataType::Float16,
636 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000637 DataType::QAsymmU8,
638 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000639 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100640 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100641
James Conroyd47a0642019-09-17 14:22:06 +0100642 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100643
644 auto inputShape = inputTensorInfo.GetShape();
645 auto outputShape = outputTensorInfo.GetShape();
646
647 auto inputNumDimensions = inputShape.GetNumDimensions();
648 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
649
650 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
651
652 // 1D input shape results in scalar output shape
653 if (inputShape.GetNumDimensions() == 1)
654 {
655 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
656 {
657 throw InvalidArgumentException(descriptorName + outputShapeError);
658 }
659 }
660 else
661 {
662 for (unsigned int i = 0; i < unsignedAxis; ++i)
663 {
664 if (outputShape[i] != inputShape[i])
665 {
666 throw InvalidArgumentException(descriptorName + outputShapeError);
667 }
668 }
669
670 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
671 {
672 if (outputShape[i - 1] != inputShape[i])
673 {
674 throw InvalidArgumentException(descriptorName + outputShapeError);
675 }
676 }
677 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100678}
679
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100680void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
681{
682 const std::string descriptorName{"SoftmaxQueueDescriptor"};
683
684 ValidateNumInputs(workloadInfo, descriptorName, 1);
685 ValidateNumOutputs(workloadInfo, descriptorName, 1);
686
687 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
688 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
689
690 std::vector<DataType> supportedTypes =
691 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000692 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100693 DataType::Float16,
694 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000695 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000696 DataType::QAsymmU8,
697 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100698 };
699
700 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
701 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
702 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
703}
704
telsoa014fcda012018-03-09 14:13:49 +0000705void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
706{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100707 const std::string descriptorName{"SplitterQueueDescriptor"};
708
709 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000710
Ruomei Yan25339c32019-05-28 16:48:20 +0100711 // Check the supported data types
712 std::vector<DataType> supportedTypes =
713 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000714 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100715 DataType::Float32,
716 DataType::Float16,
717 DataType::Boolean,
718 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000719 DataType::QAsymmU8,
720 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100721 };
722
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100723 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
724 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100725 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100726 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
727 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
728
729 const std::string outputName = "output_" + std::to_string(i);
730 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100731 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100732
telsoa014fcda012018-03-09 14:13:49 +0000733 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
734 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100735 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000736 }
737
738 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
739 {
740 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100741 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000742 "has to match number of workloadInfo.m_OutputTensorInfos. "
743 "Number of windows: " +
744 to_string(m_ViewOrigins.size()) +
745 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
746 }
747
telsoa01c577f2c2018-08-31 09:22:23 +0100748 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000749 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
750 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
751 {
telsoa01c577f2c2018-08-31 09:22:23 +0100752 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000753 ViewOrigin const& e = m_ViewOrigins[w];
754 if (e.m_Origin.size() != inputDims)
755 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100756 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000757 "have the same dimensionality as the input tensor. "
758 "Window origin (index: " +
759 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
760 " dimensions, the input "
761 "tensor has " +
762 to_string(inputDims) + " dimensions.");
763 }
764 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
765 {
766 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
767 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100769 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000770 "be smaller or equal than the size of the input in that coord.");
771 }
772 }
773 }
774}
775
Jim Flynne242f2d2019-05-22 14:24:13 +0100776void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000777{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 const std::string descriptorName{"ConcatQueueDescriptor"};
779
780 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000781
782 if (m_Inputs.size() <= 0)
783 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000785 }
786 if (m_Outputs.size() <= 0)
787 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100788 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000789 }
790
791 if (workloadInfo.m_InputTensorInfos.size() <= 0)
792 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100793 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000794 }
795 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
796 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100797 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000798 }
799
Nikhil Raj8599a412018-11-19 14:51:07 +0000800 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
801 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100802 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000803 }
804
805 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
806 {
807 return;
808 }
809
telsoa014fcda012018-03-09 14:13:49 +0000810 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
811 {
812 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000814 "has to match number of workloadInfo.m_InputTensorInfos. "
815 "Number of windows: " +
816 to_string(m_ViewOrigins.size()) +
817 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
818 }
819
telsoa01c577f2c2018-08-31 09:22:23 +0100820 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000821 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
822 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
823 {
telsoa01c577f2c2018-08-31 09:22:23 +0100824 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000825 ViewOrigin const& e = m_ViewOrigins[w];
826 if (e.m_Origin.size() != outputDims)
827 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100828 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000829 "have the same dimensionality as the output tensor. "
830 "Window origin (index: " +
831 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
832 " dimensions, the output "
833 "tensor has " +
834 to_string(outputDims) + " dimensions.");
835 }
telsoa01c577f2c2018-08-31 09:22:23 +0100836 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000837 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
838 {
839 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
840 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
841 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100842 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000843 "be smaller or equal than the size of the output in that coord.");
844 }
845 }
846 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100847
848 // Check the supported data types
849 std::vector<DataType> supportedTypes =
850 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000851 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100852 DataType::Float32,
853 DataType::Float16,
854 DataType::Boolean,
855 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000856 DataType::QAsymmU8,
857 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100858 };
859
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100860 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
861 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100862 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100863 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
864 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
865
866 const std::string inputName = "input_" + std::to_string(i);
867 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100868 }
telsoa014fcda012018-03-09 14:13:49 +0000869}
870
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100871void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
872{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100873 const std::string descriptorName{"StackQueueDescriptor"};
874
875 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100876
877 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
878 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100879 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100880 }
881
882 // All inputs must have the same shape, which is defined in parameters
883 const TensorShape& inputShape = m_Parameters.m_InputShape;
884 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
885 {
886 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
887 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100888 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100889 }
890 }
891
Matthew Jacksondba634f2019-08-15 15:14:18 +0100892 if (inputShape.GetNumDimensions() > 4)
893 {
894 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
895 }
896
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100897 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
898 // since the output tensor has an additional dimension.
899 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
900 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100901 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100902 "than the number of input dimensions.");
903 }
904
905 // Output shape must be as inferred from the input shape
906 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
907 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
908 {
909 if (outputShape[i] != inputShape[i])
910 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100911 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100912 "match shape inferred from input tensor.");
913 }
914 }
915
916 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
917 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100919 "match shape inferred from input tensor.");
920 }
921
922 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
923 {
924 if (outputShape[i] != inputShape[i-1])
925 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100926 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100927 "match shape inferred from input tensor.");
928 }
929 }
930
Matthew Jacksondba634f2019-08-15 15:14:18 +0100931 if (outputShape.GetNumDimensions() > 5)
932 {
933 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
934 }
935
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100936 // Check the supported data types
937 std::vector<DataType> supportedTypes =
938 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000939 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100940 DataType::Float32,
941 DataType::Float16,
942 DataType::Boolean,
943 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000944 DataType::QAsymmU8,
945 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100946 };
947
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100948 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100949
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100951 {
952 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
953 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100954 descriptorName,
955 "input_0",
956 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100957 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100958
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100959 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
960 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 descriptorName,
962 "input_0",
963 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100964}
965
telsoa014fcda012018-03-09 14:13:49 +0000966void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
967{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100970 ValidateNumInputs(workloadInfo, descriptorName, 1);
971 ValidateNumOutputs(workloadInfo, descriptorName, 1);
972
973 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
974 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
975
976 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
977
978 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000979 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100980 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000981 }
982
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100983 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000984
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100985 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
986 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000987
988 if (m_Parameters.m_BiasEnabled)
989 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100990 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000991
telsoa01c577f2c2018-08-31 09:22:23 +0100992 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100993 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
994 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100996 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
997 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000998 }
999
Francis Murtagh46c09d02019-05-28 08:15:28 +01001000 // Check the supported data types
1001 std::vector<DataType> supportedTypes =
1002 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001003 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001004 DataType::Float32,
1005 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001006 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001007 DataType::QAsymmU8,
1008 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001009 };
1010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001011 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1012 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001013}
1014
telsoa014fcda012018-03-09 14:13:49 +00001015void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1016{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001017 const std::string descriptorName{"NormalizationQueueDescriptor"};
1018
1019 ValidateNumInputs(workloadInfo, descriptorName, 1);
1020 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1021
1022 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1023 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001024
1025 // Check the supported data types
1026 std::vector<DataType> supportedTypes =
1027 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001028 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001029 DataType::Float16,
1030 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001031 DataType::QAsymmU8,
1032 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001033 };
1034
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001035 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001036
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001037 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001038
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001039 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001040}
1041
1042void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1043{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001046 ValidateNumInputs(workloadInfo, descriptorName, 2);
1047 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1048
1049 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1050 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1051 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1052
1053 std::vector<DataType> supportedTypes =
1054 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001055 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001056 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001057 DataType::Float16,
1058 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001059 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001060 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001061 };
1062
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001063 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1064 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1065 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001066
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001067 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1068 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001069
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001070 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1071 inputTensorInfo1,
1072 outputTensorInfo,
1073 descriptorName,
1074 "input_0",
1075 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001076}
1077
telsoa014fcda012018-03-09 14:13:49 +00001078void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1079{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001080 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001081
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001082 ValidateNumInputs(workloadInfo, descriptorName, 2);
1083 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1084
1085 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1086 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1087 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1088
1089 std::vector<DataType> supportedTypes =
1090 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001091 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001092 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001093 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001094 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001095 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001096 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001097 };
1098
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1100 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1101 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001102
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001103 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1104 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001105
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001106 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1107 inputTensorInfo1,
1108 outputTensorInfo,
1109 descriptorName,
1110 "input_0",
1111 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001112}
1113
1114void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1115{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001116 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 ValidateNumInputs(workloadInfo, descriptorName, 1);
1119 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1120
1121 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1122 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001123
1124 std::vector<DataType> supportedTypes =
1125 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001126 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001127 DataType::Float16,
1128 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001129 DataType::QAsymmU8,
1130 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001131 };
1132
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001133 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1134 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001137 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001138
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001139 ValidatePointer(m_Mean, descriptorName, "mean");
1140 ValidatePointer(m_Variance, descriptorName, "variance");
1141 ValidatePointer(m_Beta, descriptorName, "beta");
1142 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001143
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001144 const TensorInfo& mean = m_Mean->GetTensorInfo();
1145 const TensorInfo& variance = m_Variance->GetTensorInfo();
1146 const TensorInfo& beta = m_Beta->GetTensorInfo();
1147 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001148
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1150 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1151 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1152 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1155 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1156 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001157}
1158
1159void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1160{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001161 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001163 ValidateNumInputs(workloadInfo, descriptorName, 1);
1164 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001166 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1167 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001168
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1170 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001173
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001174 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1175 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001176
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001177 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001178
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001179 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001180 if (m_Parameters.m_BiasEnabled)
1181 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001182 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001183
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001184 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1185 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001186
1187 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1188 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001189 }
1190
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001191 ValidatePerAxisQuantization(inputTensorInfo,
1192 outputTensorInfo,
1193 weightTensorInfo,
1194 optionalBiasTensorInfo,
1195 descriptorName);
1196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001197 std::vector<DataType> supportedTypes =
1198 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001199 DataType::BFloat16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001200 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001201 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001202 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001203 DataType::QSymmS16,
Keith Davis5204aa82020-01-27 15:24:59 +00001204 DataType::QSymmS8,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001205 DataType::Float16
1206 };
1207
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001208 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1209 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1210}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001212void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1213{
1214 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1215
1216 ValidateNumInputs(workloadInfo, descriptorName, 1);
1217 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1218
1219 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1220 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1221
1222 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1223 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1224
1225 ValidatePointer(m_Weight, descriptorName, "weight");
1226
1227 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1228 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1229
1230 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1231 {
1232 throw InvalidArgumentException(
1233 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1234 "cannot be smaller than 1.") % descriptorName %
1235 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1236 }
1237
1238 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1239
1240 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1241 // inputChannels * channelMultiplier should be equal to outputChannels.
1242 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1243 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1244 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1245 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1246 {
1247 throw InvalidArgumentException(
1248 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1249 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1250 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1251 numWeightInputChannels % numWeightChannelMultiplier));
1252 }
1253
Teresa Charlind8df0262019-11-11 12:28:15 +00001254 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255
Teresa Charlind8df0262019-11-11 12:28:15 +00001256 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 if (m_Parameters.m_BiasEnabled)
1258 {
1259 ValidatePointer(m_Bias, descriptorName, "bias");
1260
Teresa Charlind8df0262019-11-11 12:28:15 +00001261 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1262 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001263
1264 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1265 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1266 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001267 ValidatePerAxisQuantization(inputTensorInfo,
1268 outputTensorInfo,
1269 weightTensorInfo,
1270 optionalBiasTensorInfo,
1271 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272
1273 std::vector<DataType> supportedTypes =
1274 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001275 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001276 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001277 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001278 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001279 DataType::QSymmS16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001280 DataType::Float16
1281 };
1282
1283 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1284 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001285}
1286
1287void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1288{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001289 const std::string descriptorName{"PermuteQueueDescriptor"};
1290
1291 ValidateNumInputs(workloadInfo, descriptorName, 1);
1292 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001293
1294 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1295
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001296 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1297 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001299 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1300 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001301
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001303 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001304 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001305 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001306 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1307 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1308 "must match dst dimension " + to_string(mapping[i]) +
1309 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001310 }
1311 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001312
1313 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001314}
1315
1316void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1317{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001318 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001320 ValidateNumInputs(workloadInfo, descriptorName, 1);
1321 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1322
1323 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1324 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1325
1326 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1327 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001328
1329 std::vector<DataType> supportedTypes =
1330 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001331 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001332 DataType::Float32,
1333 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001334 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001335 DataType::QAsymmU8,
1336 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001337 };
1338
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001339 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1340 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001341}
1342
1343void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1344{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001345 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001346
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001347 ValidateNumInputs(workloadInfo, descriptorName, 1);
1348 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1349
1350 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1351 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1352
1353 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1354 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001355
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001356 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001357 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001358 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001359 DataType::Float16,
1360 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001361 DataType::QAsymmU8,
1362 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001363 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001364
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001365 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1366 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001367
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001368 // ResizeBilinear only changes width and height: batch and channel count must match.
1369 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1370 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001371 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001372 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001373 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001374 boost::str(boost::format("%1%: Input batch size (%2%) "
1375 "does not match output batch size (%3%)") %
1376 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001377 }
1378
Teresa Charlin970f43b2019-07-01 13:51:07 +01001379 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001380 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1381 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001382 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001383 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001384 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001385 boost::str(boost::format("%1%: Input channel count (%2%) "
1386 "does not match output channel count (%3%)") %
1387 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001388 }
1389}
1390
1391void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1392{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001393 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001394
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001395 ValidateNumInputs(workloadInfo, descriptorName, 1);
1396 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1397
1398 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1399 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1400
1401 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1402 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001403
1404 std::vector<DataType> supportedTypes =
1405 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001406 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001407 DataType::Float16,
1408 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001409 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001410 DataType::QAsymmU8,
1411 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001412 };
1413
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001414 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1415 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001416
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001417 // Resize only changes width and height: batch and channel count must match.
1418 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1419 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001420 if (inputBatchSize != outputBatchSize)
1421 {
1422 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001423 boost::str(boost::format("%1%: Input batch size (%2%) "
1424 "does not match output batch size (%3%)") %
1425 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001426 }
1427
1428 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001429 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1430 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001431 if (inputChannelCount != outputChannelCount)
1432 {
1433 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001434 boost::str(boost::format("%1%: Input channel count (%2%) "
1435 "does not match output channel count (%3%)") %
1436 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001437 }
1438}
1439
1440void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1441{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001443
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001444 ValidateNumInputs(workloadInfo, descriptorName, 1);
1445 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1446
1447 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1448 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1449
1450 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1451 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1452
1453 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1454
telsoa014fcda012018-03-09 14:13:49 +00001455 if (m_Parameters.m_Min > m_Parameters.m_Max)
1456 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001458 }
telsoa014fcda012018-03-09 14:13:49 +00001459}
1460
Kevin Mayce5045a2019-10-02 14:07:47 +01001461void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1462{
1463 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1464
1465 ValidateNumInputs(workloadInfo, descriptorName, 1);
1466 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1467
1468 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1469 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1470
1471 if (inputTensorInfo.GetNumDimensions() > 4)
1472 {
1473 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1474 }
1475
1476 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1477
1478 // Check the supported data types
1479 std::vector<DataType> supportedTypes =
1480 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001481 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001482 DataType::Float32,
1483 DataType::Float16
1484 };
1485
1486 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001487 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001488}
1489
telsoa014fcda012018-03-09 14:13:49 +00001490void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1491{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001492 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001495 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001497 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1498 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1499
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001500 if (inputTensorInfo.GetNumDimensions() > 4)
1501 {
1502 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1503 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504
1505 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001506
1507 // Check the supported data types
1508 std::vector<DataType> supportedTypes =
1509 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001510 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001511 DataType::Float32,
1512 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001513 DataType::QAsymmU8,
1514 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001515 };
1516
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001517 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001518 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1519}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001520
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001521void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1522{
1523 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1524
1525 ValidateNumInputs(workloadInfo, descriptorName, 1);
1526 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1527
1528 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1529 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1530
1531 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1532
1533 std::vector<DataType> supportedTypes =
1534 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001536 DataType::Float32,
1537 DataType::Float16,
1538 };
1539
1540 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001541 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001542}
1543
1544void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1545{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001546 const std::string descriptorName{"ConstantQueueDescriptor"};
1547
1548 ValidateNumInputs(workloadInfo, descriptorName, 0);
1549 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001550
1551 if (!m_LayerOutput)
1552 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001553 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001554 }
1555
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001556 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1557 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001558
1559 // Check the supported data types
1560 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001561 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001562 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001563 DataType::Float32,
1564 DataType::Float16,
1565 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001566 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001567 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +00001568 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001569 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001570 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001572 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001573}
1574
1575void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1576{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001577 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001578
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001579 ValidateNumInputs(workloadInfo, descriptorName, 1);
1580 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1581
1582 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1583 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1584
1585 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001586
1587 // Check the supported data types
1588 std::vector<DataType> supportedTypes =
1589 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001590 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001591 DataType::Float32,
1592 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001593 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001594 DataType::QSymmS16,
1595 DataType::QAsymmS8,
Keith Davis67e6c542020-02-19 10:08:33 +00001596 DataType::QAsymmU8
Nina Drozd2f2778f2019-05-27 10:37:05 +01001597 };
1598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001599 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1600 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001601}
1602
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001603void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1604{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001605 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001606
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001607 ValidateNumInputs(workloadInfo, descriptorName, 1);
1608 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1609
1610 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1611 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1612
1613 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1614 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001615
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001616 if (m_Parameters.m_BlockShape.size() != 2)
1617 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001618 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001619 }
1620
1621 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1622 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001623 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1624 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001625 }
1626
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001627 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001628
1629 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001630 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001631
Matthew Bentham8800c002018-11-19 13:19:28 +00001632 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001633
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001634 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1635 widthPad.first + widthPad.second;
1636 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1637 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1640 inputShape[dimensionIndices.GetChannelsIndex()];
1641 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001642
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001643 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001644 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001645 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001646 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001647 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001648 }
1649
1650 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001651 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001652 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1653 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001654 }
nikraj01120522a2019-05-31 11:33:07 +01001655
1656 std::vector<DataType> supportedTypes =
1657 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001658 DataType::BFloat16,
1659 DataType::Float16,
1660 DataType::Float32,
1661 DataType::QAsymmU8,
1662 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001663 };
1664
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001665 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1666 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001667}
1668
Keith Davisa57eccb2019-06-14 17:33:22 +01001669void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1670{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001671 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001672
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001673 ValidateNumInputs(workloadInfo, descriptorName, 1);
1674 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001675
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1677 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1678
1679 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1680 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001681
1682 std::vector<DataType> supportedTypes =
1683 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001684 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001685 DataType::Float32,
1686 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001687 DataType::QAsymmU8,
1688 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001689 };
1690
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001691 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1692 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001693
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001694 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1695
1696 if (m_Parameters.m_BlockSize == 0)
1697 {
1698 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1699 }
1700
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001701 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1702 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1703 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1704 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001705
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001706 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001707 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001708 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1710 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001711 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001712
1713 const TensorShape& outputShape = outputTensorInfo.GetShape();
1714 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1715 {
1716 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1717 "must be divisible by the square of block size." );
1718 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001719}
1720
telsoa014fcda012018-03-09 14:13:49 +00001721void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1722{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001723 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001724
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001725 ValidateNumInputs(workloadInfo, descriptorName, 1);
1726 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1727
1728 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1729 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001730
1731 std::vector<DataType> supportedTypes =
1732 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001733 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001735 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001736 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001737 };
1738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001740
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001741 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001742 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001743 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001744 }
1745}
1746
telsoa01c577f2c2018-08-31 09:22:23 +01001747void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1748{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001749 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1750
1751 const std::string descriptorName{"LstmQueueDescriptor"};
1752
1753 // check dimensions of all inputs and outputs
1754 if (workloadInfo.m_InputTensorInfos.size() != 3)
1755 {
1756 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1757 }
1758 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1759 {
1760 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1761 }
1762
1763 std::vector<DataType> supportedTypes =
1764 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001765 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001766 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001767 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001768 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001769 };
1770
Jan Eilers38e05bd2019-06-26 13:10:09 +01001771 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001772 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1773
Jan Eilers38e05bd2019-06-26 13:10:09 +01001774 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001775 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001776 {
1777 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1778 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 descriptorName,
1780 "input_0",
1781 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001782 }
1783 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001785 {
1786 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1787 workloadInfo.m_OutputTensorInfos[i],
1788 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001789 "input_0",
1790 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001791 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001792
janeil0117d8d852019-11-15 15:00:16 +00001793 // Making sure clipping parameters have valid values.
1794 // == 0 means no clipping
1795 // > 0 means clipping
1796 if (m_Parameters.m_ClippingThresCell < 0.0f)
1797 {
1798 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1799 }
1800 if (m_Parameters.m_ClippingThresProj < 0.0f)
1801 {
1802 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1803 }
1804
Jan Eilers38e05bd2019-06-26 13:10:09 +01001805
1806 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001807 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1808 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1809 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1810 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1811 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1812 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1813
Jan Eilers38e05bd2019-06-26 13:10:09 +01001814 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1816 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001817 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001818 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1819 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001820 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001821 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1822 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001823 // scratchBufferTensor
1824 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001825 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1826 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001827 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001828 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1829 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001830 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001831 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1832 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001833 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001834 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1835 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001836
1837
1838 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1839 if ( m_InputToInputWeights )
1840 {
1841 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1842 (n_cell * n_input), "InputLayerNormWeights");
1843 }
1844
1845 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1846 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1847 (n_cell * n_input), "InputToForgetWeights");
1848
1849 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1850 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1851 (n_cell * n_input), "InputToCellWeights");
1852
1853 if ( m_RecurrentToInputWeights )
1854 {
1855 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1856 (n_cell * n_output), "RecurrentToInputWeights");
1857 }
1858
1859 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1860 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1861 (n_cell * n_output), "RecurrentToForgetWeights");
1862
1863 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1864 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1865 (n_cell * n_output), "RecurrentToCellWeights");
1866
1867 // Make sure the input-gate's parameters are either both present (regular
1868 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1869 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1870 !m_Parameters.m_CifgEnabled) ||
1871 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1872 m_Parameters.m_CifgEnabled));
1873 if (!cifg_weights_all_or_none)
1874 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001875 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1876 "RecurrentToInputWeights must either both be present (regular LSTM) "
1877 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1878 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001879 }
1880
1881 if ( m_CellToInputWeights )
1882 {
1883 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1884 n_cell, "CellToInputWeights");
1885 }
1886 if ( m_CellToForgetWeights )
1887 {
1888 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1889 n_cell, "CellToForgetWeights");
1890 }
1891 if ( m_CellToOutputWeights )
1892 {
1893 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1894 n_cell, "CellToOutputWeights");
1895 }
1896
1897 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1898 bool peephole_weights_all_or_none =
1899 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1900 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1901 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1902 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1903 if (!peephole_weights_all_or_none)
1904 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001905 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001906 }
1907
1908 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1909 if (m_Parameters.m_CifgEnabled)
1910 {
1911 if (m_InputGateBias)
1912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001913 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001914 }
1915 }
1916 else
1917 {
1918 if (!m_InputGateBias)
1919 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001920 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1921 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001922 }
1923 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1924 n_cell, "InputGateBias");
1925 }
1926
1927 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1928 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1929
1930 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1931 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1932
1933 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1934 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1935
1936 if (m_ProjectionWeights)
1937 {
1938 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1939 (n_cell * n_output), "ProjectionWeights");
1940 }
1941 if (m_ProjectionBias)
1942 {
1943 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1944 }
1945
1946 // Making sure the projection tensors are consistent:
1947 // 1) If projection weight is not present, then projection bias should not be
1948 // present.
1949 // 2) If projection weight is present, then projection bias is optional.
1950 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1951 !m_Parameters.m_ProjectionEnabled)
1952 || (m_ProjectionWeights && !m_ProjectionBias &&
1953 m_Parameters.m_ProjectionEnabled)
1954 || (m_ProjectionWeights && m_ProjectionBias &&
1955 m_Parameters.m_ProjectionEnabled));
1956 if (!projecton_tensors_consistent)
1957 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001958 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001959 }
1960
1961 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1962 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1963 // either all have values or none of them have values. Layer normalization is used when the values of all the
1964 // layer normalization weights are present
1965 if (m_InputLayerNormWeights)
1966 {
1967 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1968 }
1969 if (m_ForgetLayerNormWeights)
1970 {
1971 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1972 }
1973 if (m_CellLayerNormWeights)
1974 {
1975 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1976 }
1977 if (m_OutputLayerNormWeights)
1978 {
1979 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1980 }
1981
Jan Eilers38e05bd2019-06-26 13:10:09 +01001982 if (m_Parameters.m_LayerNormEnabled)
1983 {
1984 if (!m_Parameters.m_CifgEnabled)
1985 {
1986 if (!m_InputLayerNormWeights)
1987 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1989 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001990 }
1991 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1992 1, n_cell, "InputLayerNormWeights");
1993 }
1994 else if (m_InputLayerNormWeights)
1995 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1997 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001998 }
1999
2000 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2001 "ForgetLayerNormWeights");
2002 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2003
2004 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2005 "OutputLayerNormWeights");
2006 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2007
2008 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2009 "CellLayerNormWeights");
2010 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2011 }
2012 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2013 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002014 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2015 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002016 }
telsoa01c577f2c2018-08-31 09:22:23 +01002017}
2018
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002019void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2020{
2021 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2022
2023 ValidateNumInputs(workloadInfo, descriptorName, 1);
2024 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2025
2026 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2027 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2028
2029 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2030 {
2031 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2032 }
2033
2034 if (outputTensorInfo.GetDataType() != DataType::Float32)
2035 {
2036 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2037 }
2038
2039 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2040}
2041
telsoa01c577f2c2018-08-31 09:22:23 +01002042void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2043{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002044 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002046 ValidateNumInputs(workloadInfo, descriptorName, 1);
2047 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2048
2049 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2050 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2051
2052 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002053 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002054 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002055 }
2056
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002057 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002058 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002059 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002060 }
2061
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002062 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002063}
2064
2065void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2066{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002067 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002068
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002069 ValidateNumInputs(workloadInfo, descriptorName, 1);
2070 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2071
2072 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2073 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2074
2075 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002076 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002077 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002078 }
2079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002080 if (outputTensorInfo.GetDataType() != DataType::Float32)
2081 {
2082 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2083 }
2084
2085 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002086}
2087
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002088void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2089{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002091
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002092 ValidateNumInputs(workloadInfo, descriptorName, 2);
2093 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2094
2095 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2096 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2097 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2098
2099 std::vector<DataType> supportedTypes =
2100 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002101 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002102 DataType::QAsymmU8,
2103 DataType::QSymmS16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002104 DataType::Float16,
2105 DataType::BFloat16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002106 };
2107
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002108 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2109 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2110 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002111
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002112 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2113 inputTensorInfo1,
2114 outputTensorInfo,
2115 descriptorName,
2116 "input_0",
2117 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002118}
2119
David Beckc2044fe2018-09-05 15:00:38 +01002120void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2121{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002122 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002123
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002124 ValidateNumInputs(workloadInfo, descriptorName, 2);
2125 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2126
2127 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2128 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2129 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2130
2131 std::vector<DataType> supportedTypes =
2132 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002133 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002134 DataType::QAsymmU8,
2135 DataType::QSymmS16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002136 DataType::Float16,
2137 DataType::BFloat16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002138 };
2139
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002140 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2141 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2142 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002143
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002144 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2145 inputTensorInfo1,
2146 outputTensorInfo,
2147 descriptorName,
2148 "input_0",
2149 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002150}
2151
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002152void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2153{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002154 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002155
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002156 ValidateNumInputs(workloadInfo, descriptorName, 2);
2157 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2158
2159 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2160 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2161 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2162
2163 std::vector<DataType> supportedTypes =
2164 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002165 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002166 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002167 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002168 DataType::Signed32,
Keith Davis67e6c542020-02-19 10:08:33 +00002169 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002170 DataType::QAsymmU8,
2171 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002172 };
2173
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002174 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2175 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2176 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002177
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002178 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2179 inputTensorInfo1,
2180 outputTensorInfo,
2181 descriptorName,
2182 "input_0",
2183 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002184}
2185
narpra01a6bf9122018-09-10 09:50:09 +01002186void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2187{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002188 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002189
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002190 ValidateNumInputs(workloadInfo, descriptorName, 1);
2191 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2192
2193 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2194 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002195
2196 std::vector<DataType> supportedTypes =
2197 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002198 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002199 DataType::Float32,
2200 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002201 DataType::QAsymmU8,
2202 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002203 };
narpra01eb061912018-09-10 17:35:27 +01002204
James Conroy4d1ff582019-06-10 17:06:39 +01002205 // First check if input tensor data type is supported, then
2206 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002207 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2208 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002209
narpra0132b90462018-09-13 11:07:48 +01002210 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002211 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002212 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002213 }
narpra0132b90462018-09-13 11:07:48 +01002214 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002215 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002216 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002217 }
2218 else
2219 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002220 unsigned int outputDim =
2221 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2222 ValidateTensorNumDimensions(outputTensorInfo,
2223 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002224 outputDim > 0 ? outputDim : 1,
2225 "output");
2226 }
narpra01a6bf9122018-09-10 09:50:09 +01002227}
2228
jimfly012c9322a2018-09-19 10:59:49 +01002229void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2230{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002231 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002232
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002233 ValidateNumInputs(workloadInfo, descriptorName, 1);
2234 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2235
2236 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2237 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002238
jimfly012c9322a2018-09-19 10:59:49 +01002239 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002240 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2241
jimfly012c9322a2018-09-19 10:59:49 +01002242 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2244 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2245 "as there are dimensions in the input tensor that is " +
2246 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2247 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002248 }
2249}
2250
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002251void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2252{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002253 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002254
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002255 ValidateNumInputs(workloadInfo, descriptorName, 1);
2256 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002257
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002258 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2259 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2260
Sadik Armagan2208b602019-07-31 16:36:27 +01002261 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002262 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002263 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002264 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002265 DataType::Float16,
2266 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002267 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002268 DataType::QAsymmU8,
2269 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002270 };
2271
2272 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002273
Keith Davis0c2eeac2020-02-11 16:51:50 +00002274 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002275 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002276 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002277 }
2278}
2279
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002280void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2281{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002282 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 ValidateNumInputs(workloadInfo, descriptorName, 1);
2285 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002287 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2288 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002289
2290 std::vector<DataType> supportedTypes =
2291 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002292 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002293 DataType::Float32,
2294 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002295 DataType::QAsymmU8,
2296 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002297 };
2298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002299 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2300 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002301}
2302
Conor Kennedy430b5d82018-11-14 15:28:28 +00002303void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2304{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002305 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002306
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002307 ValidateNumInputs(workloadInfo, descriptorName, 1);
2308 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2309
2310 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002312
2313 std::vector<DataType> supportedTypes =
2314 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002315 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002316 DataType::Float16,
2317 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002318 DataType::QAsymmU8,
2319 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002320 };
2321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002322 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2323 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002325 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002327 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002328 if (rank > 4)
2329 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002331 }
2332
Conor Kennedy430b5d82018-11-14 15:28:28 +00002333 // Begin, End & Stride length must be of rank(input0)
2334 if (m_Parameters.m_Begin.size() != rank)
2335 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002336 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002337 }
2338
2339 if (m_Parameters.m_End.size() != rank)
2340 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002341 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002342 }
2343
2344 if (m_Parameters.m_Stride.size() != rank)
2345 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002347 }
2348
2349 // Stride entries must be non-zero
2350 for (auto& stride : m_Parameters.m_Stride)
2351 {
2352 if (stride == 0)
2353 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002354 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002355 }
2356 }
2357}
2358
kevmay0190539692018-11-29 08:40:19 +00002359void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2360{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002361 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002362
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002363 ValidateNumInputs(workloadInfo, descriptorName, 2);
2364 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2365
2366 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2367 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2368 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2369
2370 std::vector<DataType> supportedTypes =
2371 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002372 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002373 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002374 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002375 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002376 DataType::QAsymmU8,
2377 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002378 };
2379
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002380 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2381 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2382 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002383
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002384 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2385 inputTensorInfo1,
2386 outputTensorInfo,
2387 descriptorName,
2388 "input_0",
2389 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002390}
2391
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002392void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2393{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002394 const std::string descriptorName{"DebugQueueDescriptor"};
2395
2396 ValidateNumInputs(workloadInfo, descriptorName, 1);
2397 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002398}
2399
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002400void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2401{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002402 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002403
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002404 ValidateNumInputs(workloadInfo, descriptorName, 2);
2405 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002406
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002407 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2408 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2409 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2410
2411 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2412 inputTensorInfo1,
2413 outputTensorInfo,
2414 descriptorName,
2415 "input_0",
2416 "input_1");
2417
2418 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002419 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002420 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002421 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002422}
2423
FrancisMurtagh878f0232018-12-19 10:56:15 +00002424void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2425{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002426 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002427
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002428 ValidateNumInputs(workloadInfo, descriptorName, 2);
2429 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002430
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2432 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2433 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2434
2435 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2436 inputTensorInfo1,
2437 outputTensorInfo,
2438 descriptorName,
2439 "input_0",
2440 "input_1");
2441
2442 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002443 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002444 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002445 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002446}
2447
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002448void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2449{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002450 const std::string descriptorName{"RsqrtQueueDescriptor"};
2451
2452 ValidateNumInputs(workloadInfo, descriptorName, 1);
2453 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2454
2455 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2456 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2457
2458 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002459
2460 std::vector<DataType> supportedTypes =
2461 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002462 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002463 DataType::Float16,
2464 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002465 DataType::QAsymmU8,
2466 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002467 };
2468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002469 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2470 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002471}
2472
narpra01b89b05f2019-01-16 09:53:09 +00002473void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2474{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002475 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002476
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002477 ValidateNumInputs(workloadInfo, descriptorName, 2);
2478 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002479
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002480 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2481 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002482 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002483 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002484 }
2485
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002486 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2487 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2488
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002489 std::vector<DataType> supportedTypes =
2490 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002491 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002492 DataType::Float16,
2493 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002494 DataType::QAsymmU8,
2495 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002496 };
2497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002500 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002501
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002502 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2503 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002504}
2505
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002506void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2507{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002508 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2509
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002510 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002511
2512 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2513 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002514 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002515 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2516 }
2517
2518 if (m_Anchors == nullptr)
2519 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002520 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002521 }
2522
2523 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002524 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2525 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2526
2527 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002528 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002529 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2530 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002531
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002532 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2533 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2534 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002535
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002536 const std::vector<DataType> supportedInputTypes =
2537 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002538 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002539 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002540 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002541 DataType::QAsymmU8,
2542 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002543 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002544
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002545 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2546 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2547 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2548
2549 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2550 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2551 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2552 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2553
2554 // NOTE: Output is always Float32 regardless of input type
2555 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2556 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2557 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2558 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002559
2560 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2561 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002562 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002563 "must be positive and less than or equal to 1.");
2564 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002565
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002566 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2567 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002568 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002569 "should be equal to number of classes + 1.");
2570 }
2571}
2572
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002573void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2574{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002575 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002577 ValidateNumInputs(workloadInfo, descriptorName, 1);
2578 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2579
2580 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2581 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2582
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002583 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002584 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002585 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002586 }
2587
Sadik Armagan2208b602019-07-31 16:36:27 +01002588 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002589 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002590 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002591 DataType::Float32,
2592 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002593 };
2594
2595 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002596}
2597
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002598void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2599{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002600 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002601
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002602 ValidateNumInputs(workloadInfo, descriptorName, 2);
2603 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002604
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002605 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2606 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2607 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002609 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2610 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2611
2612 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2613 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002614}
2615
Sadik Armaganeff363d2019-04-05 15:25:46 +01002616void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2617{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002618 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002619
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002620 ValidateNumInputs(workloadInfo, descriptorName, 2);
2621 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2622
2623 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2624 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2625
2626 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2627 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2628
2629 std::vector<DataType> supportedTypes =
2630 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002631 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002632 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002633 DataType::QAsymmU8,
2634 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002635 };
2636
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002637 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2638 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002639
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002640 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2641 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002642
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002643 ValidateTensorShapesMatch(inputTensorInfo0,
2644 outputTensorInfo0,
2645 descriptorName,
2646 "input_0",
2647 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002648
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002649 ValidateTensorShapesMatch(inputTensorInfo0,
2650 outputTensorInfo1,
2651 descriptorName,
2652 "input_0",
2653 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002654}
2655
Derek Lamberti901ea112019-12-10 22:07:09 +00002656void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002657{
2658 // This is internally generated so it should not need validation.
2659}
2660
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002661void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2662{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002663 const std::string& descriptorName{"PreluQueueDescriptor"};
2664
2665 ValidateNumInputs(workloadInfo, descriptorName, 2);
2666 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2667
2668 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2669 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2670 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002671
2672 std::vector<DataType> supportedTypes
2673 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002674 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002675 DataType::Float16,
2676 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002677 DataType::QAsymmU8,
2678 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002679 };
2680
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002681 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2682 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002683
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002684 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002685
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002686 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2687 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002688
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002689 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2690 alphaTensorInfo,
2691 outputTensorInfo,
2692 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002693 "input",
2694 "alpha");
2695}
2696
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002697void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2698{
2699 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2700
2701 ValidateNumInputs(workloadInfo, descriptorName, 1);
2702 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2703
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002704 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2705 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2706
2707 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2708 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002709
2710 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002711
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002712 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2713 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002714
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002715 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2716
2717 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002718 if (m_Parameters.m_BiasEnabled)
2719 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002720 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002721
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002722 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2723 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002724
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002725 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002726 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002727 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002728
2729 ValidatePerAxisQuantization(inputTensorInfo,
2730 outputTensorInfo,
2731 weightTensorInfo,
2732 optionalBiasTensorInfo,
2733 descriptorName);
2734
2735 std::vector<DataType> supportedTypes =
2736 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002737 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002738 DataType::Float32,
2739 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002740 DataType::QAsymmU8,
2741 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002742 };
2743
2744 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2745 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002746}
2747
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002748void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2749{
2750 const std::string descriptorName{"TransposeQueueDescriptor"};
2751
2752 ValidateNumInputs(workloadInfo, descriptorName, 1);
2753 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2754
2755 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2756
2757 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2758 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2759
2760 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2761 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2762
2763 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2764 {
2765 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2766 {
2767 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2768 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2769 "must match dst dimension " + to_string(i) +
2770 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2771 }
2772 }
2773
2774 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2775}
2776
James Conroy9c3cae82019-08-01 16:01:48 +01002777void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2778{
2779 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2780
2781 // Validate number of inputs/outputs
2782 ValidateNumInputs(workloadInfo, descriptorName, 3);
2783 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2784
2785 // Input/output tensor infos
2786 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2787 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2788 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2789
2790 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2791 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2792
2793 std::vector<DataType> inputOutputSupportedTypes =
2794 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002795 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002796 };
2797
2798 std::vector<DataType> cellStateSupportedTypes =
2799 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002800 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01002801 };
2802
2803 std::vector<DataType> weightsSupportedTypes =
2804 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002805 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002806 };
2807
2808 std::vector<DataType> biasSupportedTypes =
2809 {
2810 DataType::Signed32
2811 };
2812
2813 // Validate types of input/output tensors
2814 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2815 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2816 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2817
2818 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2819 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2820
2821 // Validate matching types of input/output tensors
2822 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2823 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2824 "outputStateIn", "outputStateOut");
2825 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2826
2827 // Validate matching quantization info for input/output tensors
2828 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2829 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2830 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002831
James Conroy9c3cae82019-08-01 16:01:48 +01002832 // Infer number of batches, input size and output size from tensor dimensions
2833 const uint32_t numBatches = inputInfo.GetShape()[0];
2834 const uint32_t inputSize = inputInfo.GetShape()[1];
2835 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2836
2837 // Validate number of dimensions and number of elements for input/output tensors
2838 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2839 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2840 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2841 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2842 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2843
2844 // Validate number of dimensions and number of elements for weights tensors
2845 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2846 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2847 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2848
2849 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2850 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2851 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2852
2853 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2854 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2855 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2856
2857 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2858 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2859 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2860
2861 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2862 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2863 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2864
2865 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2866 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2867 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2868 " RecurrentToForgetWeights");
2869
2870 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2871 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2872 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2873
2874 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2875 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2876 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2877
2878 // Validate data types for weights tensors (all should match each other)
2879 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2880
2881 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2882 "inputToInputWeights", "inputToForgetWeights");
2883 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2884 "inputToInputWeights", "inputToCellWeights");
2885 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2886 "inputToInputWeights", "inputToOutputWeights");
2887
2888 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2889 "inputToInputWeights", "recurrentToInputWeights");
2890 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2891 "inputToInputWeights", "recurrentToForgeteights");
2892 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2893 "inputToInputWeights", "recurrentToCellWeights");
2894 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2895 "inputToInputWeights", "recurrentToOutputWeights");
2896
2897 // Validate matching quantization info for weight tensors (all should match each other)
2898 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2899 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2900 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2901 descriptorName, "inputToInputWeights", "inputToCellWeights");
2902 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2903 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2904
2905 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2906 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2907 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2908 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2909 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2910 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2911 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2912 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2913
2914 // Validate number of dimensions and number of elements in bias tensors
2915 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2916 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2917 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2918
2919 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2920 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2921 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2922
2923 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2924 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2925 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2926
2927 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2928 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2929 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2930
2931 // Validate data types for bias tensors (all should match each other)
2932 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2933
2934 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2935 "inputGateBias", "forgetGateBias");
2936 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2937 "inputGateBias", "cellBias");
2938 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2939 "inputGateBias", "outputGateBias");
2940
2941 // Validate bias tensor quantization info
2942 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2943 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2944 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2945 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2946}
2947
Kevin May868eb142019-09-04 17:29:31 +01002948void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2949{
2950 const std::string descriptorName{"AbsQueueDescriptor"};
2951
2952 ValidateNumInputs(workloadInfo, descriptorName, 1);
2953 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2954
2955 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2956 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2957
2958 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2959
2960 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002961 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002962 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002963 DataType::Float16,
2964 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002965 DataType::QAsymmU8,
2966 DataType::QSymmS16
James Conroyd47a0642019-09-17 14:22:06 +01002967 };
Kevin May868eb142019-09-04 17:29:31 +01002968
2969 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2970 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2971}
2972
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002973void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2974{
2975 const std::string descriptorName{"SliceQueueDescriptor"};
2976
2977 ValidateNumInputs(workloadInfo, descriptorName, 1);
2978 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2979
2980 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2981 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2982
2983 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2984
2985 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2986 if (rank > 4)
2987 {
2988 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2989 }
2990
2991 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2992
2993 // Check if m_Begin and m_Size have the expected length
2994 if (m_Parameters.m_Begin.size() != rank)
2995 {
2996 throw InvalidArgumentException(descriptorName +
2997 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2998 }
2999 if (m_Parameters.m_Size.size() != rank)
3000 {
3001 throw InvalidArgumentException(descriptorName +
3002 ": Length of size descriptor must equal rank " + std::to_string(rank));
3003 }
3004
3005 // Check if the shape of the output tensor matches m_Size
3006 const TensorShape& outputShape = outputTensorInfo.GetShape();
3007 for (unsigned int i = 0u; i < rank; ++i)
3008 {
3009 if (m_Parameters.m_Size[i] != outputShape[i])
3010 {
3011 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3012 }
3013 }
3014
3015 // Check if the sum of begin offset and size in a given dimension
3016 // does not exceed the size of corresponding input
3017 const TensorShape& inputShape = inputTensorInfo.GetShape();
3018 for(unsigned int i = 0u; i < rank; ++i)
3019 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003020 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003021 {
3022 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3023 std::to_string(i) + " exceeds input size.");
3024 }
3025 }
3026}
3027
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003028void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3029{
3030 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3031
3032 ValidateNumInputs(workloadInfo, descriptorName, 1);
3033 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3034
3035 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3036 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3037
3038 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3039 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3040
3041 std::vector<DataType> supportedTypes =
3042 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003043 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003044 DataType::Float32,
3045 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003046 DataType::QAsymmU8,
3047 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003048 };
3049
3050 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3051 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3052
3053 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3054
3055 if (m_Parameters.m_BlockSize == 0)
3056 {
3057 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3058 }
3059
3060 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3061 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3062 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3063 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3064
3065 const TensorShape& outputShape = outputInfo.GetShape();
3066 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3067 {
3068 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3069 "must be divisible by block size.");
3070 }
3071
3072 const TensorShape& inputShape = inputInfo.GetShape();
3073 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3074 {
3075 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3076 "must be divisible by the square of block size." );
3077 }
3078}
3079
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003080void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3081{
3082 const std::string descriptorName{"ComparisonQueueDescriptor"};
3083
3084 ValidateNumInputs(workloadInfo, descriptorName, 2);
3085 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3086
3087 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3088 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3089 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3090
3091 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3092 inputTensorInfo1,
3093 outputTensorInfo,
3094 descriptorName,
3095 "input_0",
3096 "input_1");
3097
3098 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3099 {
3100 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3101 }
3102}
3103
josh minor4a3c6102020-01-06 16:40:46 -06003104void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3105{
3106 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3107
3108 ValidateNumInputs(workloadInfo, descriptorName, 1);
3109 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3110
3111 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3112 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3113
3114 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3115
3116 std::vector<DataType> supportedTypes =
3117 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003118 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003119 DataType::Float16,
3120 DataType::Float32,
3121 DataType::QAsymmU8,
3122 DataType::QSymmS16
3123 };
3124
3125 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3126 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3127}
3128
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003129} // namespace armnn