blob: 85c074a500cf82581285204ab02e7f746c199862 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe011d202019-11-28 11:35:47 +00005
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00006#include <backendsCommon/WorkloadData.hpp>
7#include <backendsCommon/CpuTensorHandle.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/TensorUtils.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +000029 case DataType::BFloat16:
30 return DataType::BFloat16;
telsoa01c577f2c2018-08-31 09:22:23 +010031 case DataType::Float16:
32 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000033 case DataType::Float32:
34 return DataType::Float32;
Keith Davis0c2eeac2020-02-11 16:51:50 +000035 case DataType::QAsymmS8:
36 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000037 case DataType::QAsymmU8:
telsoa014fcda012018-03-09 14:13:49 +000038 return DataType::Signed32;
Keith Davis5204aa82020-01-27 15:24:59 +000039 case DataType::QSymmS8:
40 return DataType::Signed32;
Derek Lambertif90c56d2020-01-10 17:14:08 +000041 case DataType::QSymmS16:
Ruomei Yan88d44b82019-05-23 14:29:06 +010042 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000043 default:
44 BOOST_ASSERT_MSG(false, "Invalid input data type");
45 return DataType::Float32;
46 }
47}
48
49namespace
50{
51
52//---------------------------------------------------------------
// Local stand-in for std::to_string, which the Android NDK toolchain does not provide.
// The value is rendered through an std::ostringstream, so the text is exactly what
// operator<< prints (e.g. 1.5 becomes "1.5", not std::to_string's "1.500000").
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatter;
    formatter << value;
    return formatter.str();
}
61
62//---------------------------------------------------------------
63void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
64{
65 if (!ptr)
66 {
67 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
68 paramName + " parameter must be set.");
69 }
70}
71
72//---------------------------------------------------------------
73void ValidateTensorShapesMatch(const TensorInfo& first,
74 const TensorInfo& second,
75 std::string const& descName,
76 std::string const& firstName,
77 std::string const& secondName)
78{
79 if (first.GetShape() != second.GetShape())
80 {
81 throw InvalidArgumentException(descName + ": "
82 + firstName + " & " + secondName + " must have identical shapes");
83 }
84}
85
86//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010087void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000088{
Sadik Armaganeff363d2019-04-05 15:25:46 +010089 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000090 {
91 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010092 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000093 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
94 }
95}
96
97//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010098void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000099{
Sadik Armaganeff363d2019-04-05 15:25:46 +0100100 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +0000101 {
102 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +0100103 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +0000104 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
105 }
106}
107
108//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100109void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000110 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100111 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000112 std::string const& tensorName)
113{
114 if (tensor.GetNumDimensions() != numDimensions)
115 {
116 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
117 to_string(tensor.GetNumDimensions()) + " dimensions for " +
118 tensorName + " tensor.");
119 }
120}
121
122//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100123void ValidateTensorNumElements(const TensorInfo& tensor,
124 std::string const& descName,
125 unsigned int numElements,
126 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127{
128 if (tensor.GetNumElements() != numElements)
129 {
130 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100131 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100132 tensorName + " tensor.");
133 }
134}
135
136//---------------------------------------------------------------
137void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100138 unsigned int numDimension,
139 unsigned int numElements,
140 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100141{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100142 const std::string functionName{"ValidateTensorNumDimNumElem"};
143 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
144 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100145}
146
147//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000148void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
149 const std::string& descName, std::string const& tensorName)
150{
151 if (tensor.GetDataType() != dataType)
152 {
153 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
154 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
155 }
156}
157
Derek Lambertid466a542020-01-22 15:37:29 +0000158void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
159{
160 ARMNN_NO_DEPRECATE_WARN_BEGIN
161 if (tensor.GetDataType() != DataType::QSymmS8 &&
162 tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
163 {
164 throw InvalidArgumentException(descName +
165 ": Expected data type which supports per-axis quantization scheme but got " +
166 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
167 }
168 ARMNN_NO_DEPRECATE_WARN_END
169}
170
telsoa014fcda012018-03-09 14:13:49 +0000171//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100172void ValidateTensorQuantizationSpace(const TensorInfo& first,
173 const TensorInfo& second,
174 const std::string& descName,
175 std::string const& firstName,
176 std::string const& secondName)
177{
178 if (!first.IsQuantized() ||
179 !second.IsQuantized())
180 {
181 // Not a quantized type, ignore the validation
182 return;
183 }
184
185 DataType firstDataType = first.GetDataType();
186 DataType secondDataType = second.GetDataType();
187
188 if (firstDataType != secondDataType)
189 {
190 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
191 " must be of the same quantized type, " +
192 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
193 secondName + " is " + GetDataTypeName(secondDataType));
194 }
195
196 if (!first.IsTypeSpaceMatch(second))
197 {
198 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
199 " must have the same quantization space, " +
200 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
201 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
202 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
203 " and scale " + to_string(second.GetQuantizationScale()));
204 }
205}
206
207//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100208void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
209 const TensorInfo& inputTensorInfo,
210 const TensorInfo& weightsTensorInfo,
211 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000212{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000213 // Helper lambda function to validate a single bias quantization scale value
214 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
215 {
ricbur013f4d7102019-10-31 16:22:18 +0000216 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000217 if (std::abs(biasScale - expectedScale) > tolerance)
218 {
219 // Print the float values with extra precision to see very small differences
220 std::stringstream msg;
221 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
222 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
223 biasScale;
224 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
225 }
226 };
227
telsoa014fcda012018-03-09 14:13:49 +0000228 if (biasTensor.GetQuantizationOffset() != 0)
229 {
230 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
231 to_string(biasTensor.GetQuantizationOffset()));
232 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000233
234 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000235 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000236 // Validate per-axis quantization scales
237 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
238 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
239
240 if (weightScales.size() != biasScales.size())
241 {
242 std::stringstream msg;
243 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
244 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
245 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
246 }
247
248 for (size_t i = 0ul; i < biasScales.size(); ++i)
249 {
250 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
251 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
252 }
253 }
254 else
255 {
256 // Validate per-tensor quantization scale
257 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
258 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000259 }
260}
261
262//---------------------------------------------------------------
263void ValidateTensors(const std::vector<ITensorHandle*>& vec,
264 unsigned int numExpected,
265 const std::string& descName,
266 const std::string& varName)
267{
268 if (vec.empty() && numExpected > 0)
269 {
270 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
271 }
272
273 for (unsigned int i = 0; i < numExpected; ++i)
274 {
275 if (!vec[i])
276 {
277 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
278 }
279 }
280}
281
282//---------------------------------------------------------------
283void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
284 const TensorInfo& second,
285 const TensorInfo& output,
286 std::string const& descName,
287 std::string const& firstName,
288 std::string const& secondName)
289{
290 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
291 // broadcasted.
292 if (first.GetNumDimensions() != second.GetNumDimensions())
293 {
294 throw InvalidArgumentException(descName + ": Tensors "
295 + firstName + " & " + secondName
296 + " must have the same number of dimensions in order to be broadcasted");
297 }
298 uint32_t numDims = first.GetNumDimensions();
299 std::vector<uint32_t> outputDims(numDims, 0u);
300 for (uint32_t i = 0; i < numDims; i++)
301 {
302 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
303 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
304 if (dimsNotEqual && dimsNotOne)
305 {
306 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
307 }
308 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
309 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100310 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000311 if (broadcastShape != output.GetShape())
312 {
313 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
314 + firstName + " & " + secondName
315 + " does not match the output shape");
316 }
317}
318
319//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100320void ValidateDataTypes(const TensorInfo& info,
321 const std::vector<armnn::DataType>& supportedTypes,
322 std::string const& descName)
323{
324 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
325 if (iterator == supportedTypes.end())
326 {
327 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
328 }
329}
330
James Conroy4d1ff582019-06-10 17:06:39 +0100331//---------------------------------------------------------------
332void ValidateTensorDataTypesMatch(const TensorInfo& first,
333 const TensorInfo& second,
334 std::string const& descName,
335 std::string const& firstName,
336 std::string const& secondName)
337{
338 if (first.GetDataType() != second.GetDataType())
339 {
340 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
341 " must have identical data types.");
342 }
343}
344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100345//---------------------------------------------------------------
346void ValidateTensorNumElementsMatch(const TensorInfo& first,
347 const TensorInfo& second,
348 std::string const& descName,
349 std::string const& firstName,
350 std::string const& secondName)
351{
352 if (first.GetNumElements() != second.GetNumElements())
353 {
354 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
355 " must have the same number of elements.");
356 }
357}
358
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000359void ValidateWeightDataType(const TensorInfo& inputInfo,
360 const TensorInfo& weightInfo,
361 const std::string& descName)
362{
363 const DataType inputType = inputInfo.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000364 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000365 {
Derek Lambertid466a542020-01-22 15:37:29 +0000366 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000367 const std::vector<DataType> validTypes =
368 {
Derek Lambertif90c56d2020-01-10 17:14:08 +0000369 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000370 DataType::QAsymmS8,
Derek Lambertid466a542020-01-22 15:37:29 +0000371 DataType::QSymmS8,
372 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000373 };
Derek Lambertid466a542020-01-22 15:37:29 +0000374 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000375
376 ValidateDataTypes(weightInfo, validTypes, descName);
377 }
378 else
379 {
380 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
381 }
382}
383
384void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
385 const std::string& descName,
386 const std::string& tensorName)
387{
388 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
389 if (!quantizationDim.has_value())
390 {
391 throw InvalidArgumentException(boost::str(
392 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
393 % descName % tensorName));
394 }
395
396 if (quantizationDim.value() != 0)
397 {
398 throw InvalidArgumentException(boost::str(
399 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
400 "but got: %3%") % descName % tensorName % quantizationDim.value()));
401 }
402}
403
404void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
405 const std::string& descName,
406 const std::string& tensorName)
407{
408 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
409 if (quantizationOffset != 0)
410 {
411 throw InvalidArgumentException(boost::str(
412 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
413 "but got: %3%") % descName % tensorName % quantizationOffset));
414 }
415}
416
417void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
418 const TensorInfo& outputInfo,
419 const TensorInfo& weightInfo,
420 const Optional<TensorInfo>& optionalBiasInfo,
421 const std::string& descName)
422{
423 if (weightInfo.HasPerAxisQuantization())
424 {
425 const DataType inputDataType = inputInfo.GetDataType();
426 const DataType outputDataType = outputInfo.GetDataType();
427
Keith Davis0c2eeac2020-02-11 16:51:50 +0000428 const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000429
430 if (!canHavePerAxisQuantization)
431 {
432 throw InvalidArgumentException(boost::str(
433 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
434 "but data type does not support per-axis quantization.") % descName % "weight"));
435 }
436
Derek Lambertid466a542020-01-22 15:37:29 +0000437
438 ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000439 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
440 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
441
442 if (optionalBiasInfo.has_value())
443 {
444 const TensorInfo& biasInfo = optionalBiasInfo.value();
445 if (!biasInfo.HasPerAxisQuantization())
446 {
447 throw InvalidArgumentException(boost::str(
448 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
449 "weight tensor.") % descName));
450 }
451
452 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
453 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
454 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
455 }
456 }
457}
458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100459} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000460
461void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
462 unsigned int numExpectedIn, unsigned int numExpectedOut) const
463{
464 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
465 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
466}
467
468//---------------------------------------------------------------
469void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
470{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100471 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000472
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100473 ValidateNumInputs(workloadInfo, descriptorName, 1);
474 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000475
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100476 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
477 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
478
479 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
480 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000481
482 if (m_Inputs.size() != m_Outputs.size())
483 {
484 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100485 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
486 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000487 }
488
489 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
490 {
491 if (!m_Inputs[i])
492 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100493 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
494 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000495 }
496
497 if (!m_Outputs[i])
498 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100499 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
500 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000501 }
502 }
503}
504
Derek Lambertif674aa02019-08-01 15:56:25 +0100505//---------------------------------------------------------------
506void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
507{
508 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
509 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
510
511 if (workloadInfo.m_InputTensorInfos.size() != 1)
512 {
513 throw InvalidArgumentException(boost::str(
514 boost::format("Number of input infos (%1%) is not 1.")
515 % workloadInfo.m_InputTensorInfos.size()));
516
517 }
518
519 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
520 {
521 throw InvalidArgumentException(boost::str(
522 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
523 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
524 }
525
526 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
527 {
528 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
529 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
530 {
531 throw InvalidArgumentException(boost::str(
532 boost::format("Number of elements for tensor input and output %1% does not match")
533 % i ));
534 }
535 }
536
537 if (m_Inputs.size() != 1)
538 {
539 throw InvalidArgumentException(boost::str(
540 boost::format("Number of inputs (%1%) is not 1.")
541 % m_Inputs.size()));
542 }
543
544 if (m_Inputs.size() != m_Outputs.size())
545 {
546 throw InvalidArgumentException(boost::str(
547 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
548 % m_Inputs.size() % m_Outputs.size()));
549 }
550
551 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
552 {
553 if (!m_Inputs[i])
554 {
555 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
556 }
557
558 if (!m_Outputs[i])
559 {
560 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
561 }
562 }
563}
564
565//---------------------------------------------------------------
566void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
567{
568 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
569 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
570
Derek Lambertif674aa02019-08-01 15:56:25 +0100571 if (m_Inputs.size() != 1)
572 {
573 throw InvalidArgumentException(boost::str(
574 boost::format("Number of inputs (%1%) is not 1.")
575 % m_Inputs.size()));
576 }
577
578 if (m_Outputs.size() != 0)
579 {
580 throw InvalidArgumentException(boost::str(
581 boost::format("Number of outputs (%1%) is not 0.")
582 % m_Inputs.size() % m_Outputs.size()));
583 }
584
585 if (!m_Inputs[0])
586 {
587 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
588 }
589}
590
591//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000592void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
593{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100594 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100595
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100596 ValidateNumInputs(workloadInfo, descriptorName, 1);
597 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100599 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
600 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100601
602 std::vector<DataType> supportedTypes =
603 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000604 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100605 DataType::Float16,
606 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000607 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000608 DataType::QAsymmU8,
609 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +0100610 };
611
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100612 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
613 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
614 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000615}
616
Nikhil Rajee391d52019-09-05 17:50:44 +0100617void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
618{
619 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
620
621 ValidateNumInputs(workloadInfo, descriptorName, 1);
622 ValidateNumOutputs(workloadInfo, descriptorName, 1);
623
624 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
625 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
626
Nikhil Raj68c2c902019-09-19 11:21:11 +0100627 if (outputTensorInfo.GetDataType() != DataType::Signed32)
628 {
629 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
630 }
631
James Conroyd47a0642019-09-17 14:22:06 +0100632 std::vector<DataType> supportedInputTypes =
633 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000634 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100635 DataType::Float16,
636 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000637 DataType::QAsymmU8,
638 DataType::QSymmS16,
Francis Murtagh1939df52019-11-13 15:21:09 +0000639 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100640 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100641
James Conroyd47a0642019-09-17 14:22:06 +0100642 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100643
644 auto inputShape = inputTensorInfo.GetShape();
645 auto outputShape = outputTensorInfo.GetShape();
646
647 auto inputNumDimensions = inputShape.GetNumDimensions();
648 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
649
650 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
651
652 // 1D input shape results in scalar output shape
653 if (inputShape.GetNumDimensions() == 1)
654 {
655 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
656 {
657 throw InvalidArgumentException(descriptorName + outputShapeError);
658 }
659 }
660 else
661 {
662 for (unsigned int i = 0; i < unsignedAxis; ++i)
663 {
664 if (outputShape[i] != inputShape[i])
665 {
666 throw InvalidArgumentException(descriptorName + outputShapeError);
667 }
668 }
669
670 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
671 {
672 if (outputShape[i - 1] != inputShape[i])
673 {
674 throw InvalidArgumentException(descriptorName + outputShapeError);
675 }
676 }
677 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100678}
679
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100680void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
681{
682 const std::string descriptorName{"SoftmaxQueueDescriptor"};
683
684 ValidateNumInputs(workloadInfo, descriptorName, 1);
685 ValidateNumOutputs(workloadInfo, descriptorName, 1);
686
687 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
688 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
689
690 std::vector<DataType> supportedTypes =
691 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000692 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100693 DataType::Float16,
694 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000695 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000696 DataType::QAsymmU8,
697 DataType::QSymmS16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100698 };
699
700 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
701 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
702 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
703}
704
telsoa014fcda012018-03-09 14:13:49 +0000705void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
706{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100707 const std::string descriptorName{"SplitterQueueDescriptor"};
708
709 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000710
Ruomei Yan25339c32019-05-28 16:48:20 +0100711 // Check the supported data types
712 std::vector<DataType> supportedTypes =
713 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000714 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100715 DataType::Float32,
716 DataType::Float16,
717 DataType::Boolean,
718 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000719 DataType::QAsymmU8,
720 DataType::QSymmS16
Ruomei Yan25339c32019-05-28 16:48:20 +0100721 };
722
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100723 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
724 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100725 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100726 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
727 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
728
729 const std::string outputName = "output_" + std::to_string(i);
730 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100731 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100732
telsoa014fcda012018-03-09 14:13:49 +0000733 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
734 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100735 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000736 }
737
738 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
739 {
740 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100741 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000742 "has to match number of workloadInfo.m_OutputTensorInfos. "
743 "Number of windows: " +
744 to_string(m_ViewOrigins.size()) +
745 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
746 }
747
telsoa01c577f2c2018-08-31 09:22:23 +0100748 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000749 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
750 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
751 {
telsoa01c577f2c2018-08-31 09:22:23 +0100752 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000753 ViewOrigin const& e = m_ViewOrigins[w];
754 if (e.m_Origin.size() != inputDims)
755 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100756 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000757 "have the same dimensionality as the input tensor. "
758 "Window origin (index: " +
759 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
760 " dimensions, the input "
761 "tensor has " +
762 to_string(inputDims) + " dimensions.");
763 }
764 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
765 {
766 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
767 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100769 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000770 "be smaller or equal than the size of the input in that coord.");
771 }
772 }
773 }
774}
775
Jim Flynne242f2d2019-05-22 14:24:13 +0100776void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000777{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 const std::string descriptorName{"ConcatQueueDescriptor"};
779
780 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000781
782 if (m_Inputs.size() <= 0)
783 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000785 }
786 if (m_Outputs.size() <= 0)
787 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100788 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000789 }
790
791 if (workloadInfo.m_InputTensorInfos.size() <= 0)
792 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100793 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000794 }
795 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
796 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100797 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000798 }
799
Nikhil Raj8599a412018-11-19 14:51:07 +0000800 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
801 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100802 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000803 }
804
805 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
806 {
807 return;
808 }
809
telsoa014fcda012018-03-09 14:13:49 +0000810 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
811 {
812 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000814 "has to match number of workloadInfo.m_InputTensorInfos. "
815 "Number of windows: " +
816 to_string(m_ViewOrigins.size()) +
817 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
818 }
819
telsoa01c577f2c2018-08-31 09:22:23 +0100820 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000821 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
822 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
823 {
telsoa01c577f2c2018-08-31 09:22:23 +0100824 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000825 ViewOrigin const& e = m_ViewOrigins[w];
826 if (e.m_Origin.size() != outputDims)
827 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100828 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000829 "have the same dimensionality as the output tensor. "
830 "Window origin (index: " +
831 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
832 " dimensions, the output "
833 "tensor has " +
834 to_string(outputDims) + " dimensions.");
835 }
telsoa01c577f2c2018-08-31 09:22:23 +0100836 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000837 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
838 {
839 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
840 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
841 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100842 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000843 "be smaller or equal than the size of the output in that coord.");
844 }
845 }
846 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100847
848 // Check the supported data types
849 std::vector<DataType> supportedTypes =
850 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000851 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100852 DataType::Float32,
853 DataType::Float16,
854 DataType::Boolean,
855 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000856 DataType::QAsymmU8,
857 DataType::QSymmS16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100858 };
859
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100860 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
861 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100862 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100863 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
864 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
865
866 const std::string inputName = "input_" + std::to_string(i);
867 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100868 }
telsoa014fcda012018-03-09 14:13:49 +0000869}
870
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100871void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
872{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100873 const std::string descriptorName{"StackQueueDescriptor"};
874
875 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100876
877 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
878 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100879 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100880 }
881
882 // All inputs must have the same shape, which is defined in parameters
883 const TensorShape& inputShape = m_Parameters.m_InputShape;
884 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
885 {
886 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
887 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100888 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100889 }
890 }
891
Matthew Jacksondba634f2019-08-15 15:14:18 +0100892 if (inputShape.GetNumDimensions() > 4)
893 {
894 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
895 }
896
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100897 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
898 // since the output tensor has an additional dimension.
899 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
900 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100901 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100902 "than the number of input dimensions.");
903 }
904
905 // Output shape must be as inferred from the input shape
906 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
907 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
908 {
909 if (outputShape[i] != inputShape[i])
910 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100911 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100912 "match shape inferred from input tensor.");
913 }
914 }
915
916 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
917 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100919 "match shape inferred from input tensor.");
920 }
921
922 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
923 {
924 if (outputShape[i] != inputShape[i-1])
925 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100926 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100927 "match shape inferred from input tensor.");
928 }
929 }
930
Matthew Jacksondba634f2019-08-15 15:14:18 +0100931 if (outputShape.GetNumDimensions() > 5)
932 {
933 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
934 }
935
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100936 // Check the supported data types
937 std::vector<DataType> supportedTypes =
938 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000939 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +0100940 DataType::Float32,
941 DataType::Float16,
942 DataType::Boolean,
943 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000944 DataType::QAsymmU8,
945 DataType::QSymmS16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100946 };
947
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100948 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100949
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100951 {
952 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
953 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100954 descriptorName,
955 "input_0",
956 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100957 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100958
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100959 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
960 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 descriptorName,
962 "input_0",
963 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100964}
965
telsoa014fcda012018-03-09 14:13:49 +0000966void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
967{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100970 ValidateNumInputs(workloadInfo, descriptorName, 1);
971 ValidateNumOutputs(workloadInfo, descriptorName, 1);
972
973 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
974 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
975
976 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
977
978 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000979 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100980 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000981 }
982
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100983 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000984
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100985 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
986 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000987
988 if (m_Parameters.m_BiasEnabled)
989 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100990 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000991
telsoa01c577f2c2018-08-31 09:22:23 +0100992 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100993 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
994 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100996 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
997 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000998 }
999
Francis Murtagh46c09d02019-05-28 08:15:28 +01001000 // Check the supported data types
1001 std::vector<DataType> supportedTypes =
1002 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001003 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01001004 DataType::Float32,
1005 DataType::Float16,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001006 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001007 DataType::QAsymmU8,
1008 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001009 };
1010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001011 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1012 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001013}
1014
telsoa014fcda012018-03-09 14:13:49 +00001015void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1016{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001017 const std::string descriptorName{"NormalizationQueueDescriptor"};
1018
1019 ValidateNumInputs(workloadInfo, descriptorName, 1);
1020 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1021
1022 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1023 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001024
1025 // Check the supported data types
1026 std::vector<DataType> supportedTypes =
1027 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001028 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001029 DataType::Float16,
1030 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001031 DataType::QAsymmU8,
1032 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001033 };
1034
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001035 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001036
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001037 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001038
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001039 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001040}
1041
1042void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1043{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001046 ValidateNumInputs(workloadInfo, descriptorName, 2);
1047 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1048
1049 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1050 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1051 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1052
1053 std::vector<DataType> supportedTypes =
1054 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001055 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001056 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001057 DataType::Float16,
1058 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001059 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001060 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001061 };
1062
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001063 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1064 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1065 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001066
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001067 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1068 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001069
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001070 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1071 inputTensorInfo1,
1072 outputTensorInfo,
1073 descriptorName,
1074 "input_0",
1075 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001076}
1077
telsoa014fcda012018-03-09 14:13:49 +00001078void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1079{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001080 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001081
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001082 ValidateNumInputs(workloadInfo, descriptorName, 2);
1083 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1084
1085 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1086 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1087 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1088
1089 std::vector<DataType> supportedTypes =
1090 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001091 DataType::BFloat16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001092 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001093 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001094 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001095 DataType::QSymmS16,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001096 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001097 };
1098
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1100 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1101 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001102
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001103 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1104 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001105
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001106 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1107 inputTensorInfo1,
1108 outputTensorInfo,
1109 descriptorName,
1110 "input_0",
1111 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001112}
1113
1114void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1115{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001116 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 ValidateNumInputs(workloadInfo, descriptorName, 1);
1119 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1120
1121 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1122 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001123
1124 std::vector<DataType> supportedTypes =
1125 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001126 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001127 DataType::Float16,
1128 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001129 DataType::QAsymmU8,
1130 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001131 };
1132
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001133 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1134 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001137 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001138
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001139 ValidatePointer(m_Mean, descriptorName, "mean");
1140 ValidatePointer(m_Variance, descriptorName, "variance");
1141 ValidatePointer(m_Beta, descriptorName, "beta");
1142 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001143
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001144 const TensorInfo& mean = m_Mean->GetTensorInfo();
1145 const TensorInfo& variance = m_Variance->GetTensorInfo();
1146 const TensorInfo& beta = m_Beta->GetTensorInfo();
1147 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001148
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1150 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1151 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1152 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1155 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1156 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001157}
1158
1159void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1160{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001161 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001163 ValidateNumInputs(workloadInfo, descriptorName, 1);
1164 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001166 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1167 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001168
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1170 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001173
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001174 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1175 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001176
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001177 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001178
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001179 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001180 if (m_Parameters.m_BiasEnabled)
1181 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001182 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001183
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001184 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1185 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001186
1187 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1188 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001189 }
1190
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001191 ValidatePerAxisQuantization(inputTensorInfo,
1192 outputTensorInfo,
1193 weightTensorInfo,
1194 optionalBiasTensorInfo,
1195 descriptorName);
1196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001197 std::vector<DataType> supportedTypes =
1198 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001199 DataType::BFloat16,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001200 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001201 DataType::QAsymmS8,
Francis Murtaghddb1d062020-03-10 13:51:45 +00001202 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001203 DataType::QSymmS16,
Keith Davis5204aa82020-01-27 15:24:59 +00001204 DataType::QSymmS8,
Ruomei Yan88d44b82019-05-23 14:29:06 +01001205 DataType::Float16
1206 };
1207
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001208 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1209 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1210}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001212void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1213{
1214 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1215
1216 ValidateNumInputs(workloadInfo, descriptorName, 1);
1217 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1218
1219 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1220 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1221
1222 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1223 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1224
1225 ValidatePointer(m_Weight, descriptorName, "weight");
1226
1227 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1228 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1229
1230 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1231 {
1232 throw InvalidArgumentException(
1233 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1234 "cannot be smaller than 1.") % descriptorName %
1235 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1236 }
1237
1238 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1239
1240 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1241 // inputChannels * channelMultiplier should be equal to outputChannels.
1242 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1243 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1244 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1245 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1246 {
1247 throw InvalidArgumentException(
1248 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1249 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1250 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1251 numWeightInputChannels % numWeightChannelMultiplier));
1252 }
1253
Teresa Charlind8df0262019-11-11 12:28:15 +00001254 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255
Teresa Charlind8df0262019-11-11 12:28:15 +00001256 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 if (m_Parameters.m_BiasEnabled)
1258 {
1259 ValidatePointer(m_Bias, descriptorName, "bias");
1260
Teresa Charlind8df0262019-11-11 12:28:15 +00001261 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1262 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001263
1264 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1265 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1266 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001267 ValidatePerAxisQuantization(inputTensorInfo,
1268 outputTensorInfo,
1269 weightTensorInfo,
1270 optionalBiasTensorInfo,
1271 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272
1273 std::vector<DataType> supportedTypes =
1274 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001275 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001276 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001277 DataType::QAsymmU8,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001278 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001279 DataType::QSymmS16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001280 DataType::Float16
1281 };
1282
1283 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1284 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001285}
1286
1287void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1288{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001289 const std::string descriptorName{"PermuteQueueDescriptor"};
1290
1291 ValidateNumInputs(workloadInfo, descriptorName, 1);
1292 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001293
1294 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1295
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001296 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1297 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001299 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1300 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001301
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001303 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001304 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001305 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001306 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1307 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1308 "must match dst dimension " + to_string(mapping[i]) +
1309 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001310 }
1311 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001312
1313 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001314}
1315
1316void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1317{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001318 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001320 ValidateNumInputs(workloadInfo, descriptorName, 1);
1321 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1322
1323 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1324 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1325
1326 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1327 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001328
1329 std::vector<DataType> supportedTypes =
1330 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001331 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001332 DataType::Float32,
1333 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001334 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001335 DataType::QAsymmU8,
1336 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001337 };
1338
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001339 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1340 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001341}
1342
1343void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1344{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001345 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001346
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001347 ValidateNumInputs(workloadInfo, descriptorName, 1);
1348 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1349
1350 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1351 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1352
1353 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1354 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001355
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001356 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001357 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001358 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001359 DataType::Float16,
1360 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001361 DataType::QAsymmU8,
1362 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001363 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001364
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001365 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1366 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001367
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001368 // ResizeBilinear only changes width and height: batch and channel count must match.
1369 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1370 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001371 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001372 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001373 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001374 boost::str(boost::format("%1%: Input batch size (%2%) "
1375 "does not match output batch size (%3%)") %
1376 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001377 }
1378
Teresa Charlin970f43b2019-07-01 13:51:07 +01001379 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001380 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1381 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001382 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001383 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001384 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001385 boost::str(boost::format("%1%: Input channel count (%2%) "
1386 "does not match output channel count (%3%)") %
1387 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001388 }
1389}
1390
1391void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1392{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001393 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001394
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001395 ValidateNumInputs(workloadInfo, descriptorName, 1);
1396 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1397
1398 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1399 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1400
1401 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1402 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001403
1404 std::vector<DataType> supportedTypes =
1405 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001406 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001407 DataType::Float16,
1408 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +00001409 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001410 DataType::QAsymmU8,
1411 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001412 };
1413
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001414 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1415 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001416
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001417 // Resize only changes width and height: batch and channel count must match.
1418 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1419 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001420 if (inputBatchSize != outputBatchSize)
1421 {
1422 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001423 boost::str(boost::format("%1%: Input batch size (%2%) "
1424 "does not match output batch size (%3%)") %
1425 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001426 }
1427
1428 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001429 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1430 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001431 if (inputChannelCount != outputChannelCount)
1432 {
1433 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001434 boost::str(boost::format("%1%: Input channel count (%2%) "
1435 "does not match output channel count (%3%)") %
1436 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001437 }
1438}
1439
1440void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1441{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001443
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001444 ValidateNumInputs(workloadInfo, descriptorName, 1);
1445 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1446
1447 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1448 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1449
1450 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1451 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1452
1453 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1454
telsoa014fcda012018-03-09 14:13:49 +00001455 if (m_Parameters.m_Min > m_Parameters.m_Max)
1456 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001458 }
telsoa014fcda012018-03-09 14:13:49 +00001459}
1460
Kevin Mayce5045a2019-10-02 14:07:47 +01001461void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1462{
1463 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1464
1465 ValidateNumInputs(workloadInfo, descriptorName, 1);
1466 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1467
1468 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1469 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1470
1471 if (inputTensorInfo.GetNumDimensions() > 4)
1472 {
1473 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1474 }
1475
1476 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1477
1478 // Check the supported data types
1479 std::vector<DataType> supportedTypes =
1480 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001481 DataType::BFloat16,
Kevin Mayce5045a2019-10-02 14:07:47 +01001482 DataType::Float32,
1483 DataType::Float16
1484 };
1485
1486 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001487 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001488}
1489
telsoa014fcda012018-03-09 14:13:49 +00001490void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1491{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001492 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001493
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001495 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1496
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001497 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1498 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1499
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001500 if (inputTensorInfo.GetNumDimensions() > 4)
1501 {
1502 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1503 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504
1505 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001506
1507 // Check the supported data types
1508 std::vector<DataType> supportedTypes =
1509 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001510 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001511 DataType::Float32,
1512 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001513 DataType::QAsymmU8,
1514 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001515 };
1516
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001517 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001518 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1519}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001520
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001521void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1522{
1523 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1524
1525 ValidateNumInputs(workloadInfo, descriptorName, 1);
1526 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1527
1528 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1529 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1530
1531 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1532
1533 std::vector<DataType> supportedTypes =
1534 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 DataType::BFloat16,
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001536 DataType::Float32,
1537 DataType::Float16,
1538 };
1539
1540 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001541 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001542}
1543
1544void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1545{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001546 const std::string descriptorName{"ConstantQueueDescriptor"};
1547
1548 ValidateNumInputs(workloadInfo, descriptorName, 0);
1549 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001550
1551 if (!m_LayerOutput)
1552 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001553 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001554 }
1555
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001556 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1557 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001558
1559 // Check the supported data types
1560 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001561 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001562 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001563 DataType::Float32,
1564 DataType::Float16,
1565 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001566 DataType::QAsymmU8,
Keith Davis67e6c542020-02-19 10:08:33 +00001567 DataType::QAsymmS8,
Keith Davis5204aa82020-01-27 15:24:59 +00001568 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001569 DataType::QSymmS16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001570 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001572 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001573}
1574
1575void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1576{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001577 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001578
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001579 ValidateNumInputs(workloadInfo, descriptorName, 1);
1580 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1581
1582 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1583 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1584
1585 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001586
1587 // Check the supported data types
1588 std::vector<DataType> supportedTypes =
1589 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001590 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001591 DataType::Float32,
1592 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001593 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001594 DataType::QSymmS16,
1595 DataType::QAsymmS8,
Keith Davis67e6c542020-02-19 10:08:33 +00001596 DataType::QAsymmU8
Nina Drozd2f2778f2019-05-27 10:37:05 +01001597 };
1598
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001599 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1600 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001601}
1602
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001603void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1604{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001605 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001606
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001607 ValidateNumInputs(workloadInfo, descriptorName, 1);
1608 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1609
1610 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1611 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1612
1613 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1614 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001615
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001616 if (m_Parameters.m_BlockShape.size() != 2)
1617 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001618 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001619 }
1620
1621 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1622 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001623 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1624 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001625 }
1626
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001627 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001628
1629 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001630 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001631
Matthew Bentham8800c002018-11-19 13:19:28 +00001632 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001633
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001634 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1635 widthPad.first + widthPad.second;
1636 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1637 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001638
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1640 inputShape[dimensionIndices.GetChannelsIndex()];
1641 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001642
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001643 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001644 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001645 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001646 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001647 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001648 }
1649
1650 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001651 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001652 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1653 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001654 }
nikraj01120522a2019-05-31 11:33:07 +01001655
1656 std::vector<DataType> supportedTypes =
1657 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001658 DataType::BFloat16,
1659 DataType::Float16,
1660 DataType::Float32,
1661 DataType::QAsymmU8,
1662 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001663 };
1664
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001665 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1666 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001667}
1668
Keith Davisa57eccb2019-06-14 17:33:22 +01001669void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1670{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001671 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001672
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001673 ValidateNumInputs(workloadInfo, descriptorName, 1);
1674 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001675
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1677 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1678
1679 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1680 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001681
1682 std::vector<DataType> supportedTypes =
1683 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001684 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001685 DataType::Float32,
1686 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001687 DataType::QAsymmU8,
1688 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001689 };
1690
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001691 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1692 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001693
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001694 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1695
1696 if (m_Parameters.m_BlockSize == 0)
1697 {
1698 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1699 }
1700
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001701 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1702 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1703 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1704 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001705
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001706 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001707 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001708 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1710 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001711 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001712
1713 const TensorShape& outputShape = outputTensorInfo.GetShape();
1714 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1715 {
1716 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1717 "must be divisible by the square of block size." );
1718 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001719}
1720
telsoa014fcda012018-03-09 14:13:49 +00001721void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1722{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001723 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001724
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001725 ValidateNumInputs(workloadInfo, descriptorName, 1);
1726 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1727
1728 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1729 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001730
1731 std::vector<DataType> supportedTypes =
1732 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001733 DataType::BFloat16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001735 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001736 DataType::QSymmS16
James Conroy83735b12019-05-30 16:36:59 +01001737 };
1738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001740
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001741 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001742 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001743 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001744 }
1745}
1746
telsoa01c577f2c2018-08-31 09:22:23 +01001747void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1748{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001749 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1750
1751 const std::string descriptorName{"LstmQueueDescriptor"};
1752
1753 // check dimensions of all inputs and outputs
1754 if (workloadInfo.m_InputTensorInfos.size() != 3)
1755 {
1756 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1757 }
1758 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1759 {
1760 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1761 }
1762
1763 std::vector<DataType> supportedTypes =
1764 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001765 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001766 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001767 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001768 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001769 };
1770
Jan Eilers38e05bd2019-06-26 13:10:09 +01001771 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001772 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1773
Jan Eilers38e05bd2019-06-26 13:10:09 +01001774 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001775 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001776 {
1777 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1778 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 descriptorName,
1780 "input_0",
1781 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001782 }
1783 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001785 {
1786 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1787 workloadInfo.m_OutputTensorInfos[i],
1788 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001789 "input_0",
1790 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001791 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001792
janeil0117d8d852019-11-15 15:00:16 +00001793 // Making sure clipping parameters have valid values.
1794 // == 0 means no clipping
1795 // > 0 means clipping
1796 if (m_Parameters.m_ClippingThresCell < 0.0f)
1797 {
1798 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
1799 }
1800 if (m_Parameters.m_ClippingThresProj < 0.0f)
1801 {
1802 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
1803 }
1804
Jan Eilers38e05bd2019-06-26 13:10:09 +01001805
1806 // Inferring batch size, number of outputs and number of cells from the inputs.
Jan Eilers38e05bd2019-06-26 13:10:09 +01001807 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1808 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1809 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1810 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1811 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1812 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1813
Jan Eilers38e05bd2019-06-26 13:10:09 +01001814 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1816 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001817 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001818 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1819 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001820 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001821 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1822 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001823 // scratchBufferTensor
1824 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001825 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1826 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001827 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001828 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1829 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001830 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001831 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1832 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001833 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001834 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1835 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001836
1837
1838 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1839 if ( m_InputToInputWeights )
1840 {
1841 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1842 (n_cell * n_input), "InputLayerNormWeights");
1843 }
1844
1845 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1846 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1847 (n_cell * n_input), "InputToForgetWeights");
1848
1849 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1850 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1851 (n_cell * n_input), "InputToCellWeights");
1852
1853 if ( m_RecurrentToInputWeights )
1854 {
1855 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1856 (n_cell * n_output), "RecurrentToInputWeights");
1857 }
1858
1859 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1860 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1861 (n_cell * n_output), "RecurrentToForgetWeights");
1862
1863 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1864 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1865 (n_cell * n_output), "RecurrentToCellWeights");
1866
1867 // Make sure the input-gate's parameters are either both present (regular
1868 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1869 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1870 !m_Parameters.m_CifgEnabled) ||
1871 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1872 m_Parameters.m_CifgEnabled));
1873 if (!cifg_weights_all_or_none)
1874 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001875 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1876 "RecurrentToInputWeights must either both be present (regular LSTM) "
1877 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1878 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001879 }
1880
1881 if ( m_CellToInputWeights )
1882 {
1883 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1884 n_cell, "CellToInputWeights");
1885 }
1886 if ( m_CellToForgetWeights )
1887 {
1888 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1889 n_cell, "CellToForgetWeights");
1890 }
1891 if ( m_CellToOutputWeights )
1892 {
1893 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1894 n_cell, "CellToOutputWeights");
1895 }
1896
1897 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1898 bool peephole_weights_all_or_none =
1899 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1900 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1901 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1902 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1903 if (!peephole_weights_all_or_none)
1904 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001905 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001906 }
1907
1908 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1909 if (m_Parameters.m_CifgEnabled)
1910 {
1911 if (m_InputGateBias)
1912 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001913 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001914 }
1915 }
1916 else
1917 {
1918 if (!m_InputGateBias)
1919 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001920 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1921 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001922 }
1923 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1924 n_cell, "InputGateBias");
1925 }
1926
1927 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1928 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1929
1930 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1931 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1932
1933 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1934 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1935
1936 if (m_ProjectionWeights)
1937 {
1938 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1939 (n_cell * n_output), "ProjectionWeights");
1940 }
1941 if (m_ProjectionBias)
1942 {
1943 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1944 }
1945
1946 // Making sure the projection tensors are consistent:
1947 // 1) If projection weight is not present, then projection bias should not be
1948 // present.
1949 // 2) If projection weight is present, then projection bias is optional.
1950 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1951 !m_Parameters.m_ProjectionEnabled)
1952 || (m_ProjectionWeights && !m_ProjectionBias &&
1953 m_Parameters.m_ProjectionEnabled)
1954 || (m_ProjectionWeights && m_ProjectionBias &&
1955 m_Parameters.m_ProjectionEnabled));
1956 if (!projecton_tensors_consistent)
1957 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001958 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001959 }
1960
1961 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1962 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1963 // either all have values or none of them have values. Layer normalization is used when the values of all the
1964 // layer normalization weights are present
1965 if (m_InputLayerNormWeights)
1966 {
1967 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1968 }
1969 if (m_ForgetLayerNormWeights)
1970 {
1971 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1972 }
1973 if (m_CellLayerNormWeights)
1974 {
1975 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1976 }
1977 if (m_OutputLayerNormWeights)
1978 {
1979 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1980 }
1981
Jan Eilers38e05bd2019-06-26 13:10:09 +01001982 if (m_Parameters.m_LayerNormEnabled)
1983 {
1984 if (!m_Parameters.m_CifgEnabled)
1985 {
1986 if (!m_InputLayerNormWeights)
1987 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1989 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001990 }
1991 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1992 1, n_cell, "InputLayerNormWeights");
1993 }
1994 else if (m_InputLayerNormWeights)
1995 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1997 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001998 }
1999
2000 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2001 "ForgetLayerNormWeights");
2002 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2003
2004 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2005 "OutputLayerNormWeights");
2006 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2007
2008 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2009 "CellLayerNormWeights");
2010 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2011 }
2012 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2013 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002014 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2015 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01002016 }
telsoa01c577f2c2018-08-31 09:22:23 +01002017}
2018
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +00002019void ConvertBf16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2020{
2021 const std::string descriptorName{"ConvertBf16ToFp32QueueDescriptor"};
2022
2023 ValidateNumInputs(workloadInfo, descriptorName, 1);
2024 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2025
2026 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2027 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2028
2029 if (inputTensorInfo.GetDataType() != DataType::BFloat16)
2030 {
2031 throw InvalidArgumentException(descriptorName + ": Input tensor type must be BFloat16.");
2032 }
2033
2034 if (outputTensorInfo.GetDataType() != DataType::Float32)
2035 {
2036 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2037 }
2038
2039 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2040}
2041
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00002042void ConvertFp32ToBf16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2043{
2044 const std::string descriptorName{"ConvertFp32ToBf16QueueDescriptor"};
2045
2046 ValidateNumInputs(workloadInfo, descriptorName, 1);
2047 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2048
2049 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2050 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2051
2052 if (inputTensorInfo.GetDataType() != DataType::Float32)
2053 {
2054 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2055 }
2056
2057 if (outputTensorInfo.GetDataType() != DataType::BFloat16)
2058 {
2059 throw InvalidArgumentException(descriptorName + ": Output tensor type must be BFloat16.");
2060 }
2061
2062 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2063}
2064
telsoa01c577f2c2018-08-31 09:22:23 +01002065void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2066{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002067 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002068
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002069 ValidateNumInputs(workloadInfo, descriptorName, 1);
2070 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2071
2072 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2073 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2074
2075 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01002076 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002077 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01002078 }
2079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002080 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002081 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002082 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002083 }
2084
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002085 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002086}
2087
2088void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2089{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01002091
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002092 ValidateNumInputs(workloadInfo, descriptorName, 1);
2093 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2094
2095 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2096 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2097
2098 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01002099 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002100 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01002101 }
2102
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002103 if (outputTensorInfo.GetDataType() != DataType::Float32)
2104 {
2105 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2106 }
2107
2108 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01002109}
2110
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002111void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2112{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002113 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002114
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002115 ValidateNumInputs(workloadInfo, descriptorName, 2);
2116 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2117
2118 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2119 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2120 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2121
2122 std::vector<DataType> supportedTypes =
2123 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002124 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002125 DataType::QAsymmU8,
2126 DataType::QSymmS16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002127 DataType::Float16,
2128 DataType::BFloat16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002129 };
2130
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002131 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2132 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2133 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002134
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002135 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2136 inputTensorInfo1,
2137 outputTensorInfo,
2138 descriptorName,
2139 "input_0",
2140 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002141}
2142
David Beckc2044fe2018-09-05 15:00:38 +01002143void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2144{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002145 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002147 ValidateNumInputs(workloadInfo, descriptorName, 2);
2148 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2149
2150 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2151 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2152 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2153
2154 std::vector<DataType> supportedTypes =
2155 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002156 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002157 DataType::QAsymmU8,
2158 DataType::QSymmS16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002159 DataType::Float16,
2160 DataType::BFloat16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002161 };
2162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002163 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2164 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2165 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002166
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002167 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2168 inputTensorInfo1,
2169 outputTensorInfo,
2170 descriptorName,
2171 "input_0",
2172 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002173}
2174
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002175void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2176{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002177 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002178
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002179 ValidateNumInputs(workloadInfo, descriptorName, 2);
2180 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2181
2182 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2183 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2184 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2185
2186 std::vector<DataType> supportedTypes =
2187 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002188 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002189 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002190 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002191 DataType::Signed32,
Keith Davis67e6c542020-02-19 10:08:33 +00002192 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002193 DataType::QAsymmU8,
2194 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002195 };
2196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002197 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2198 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2199 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002200
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002201 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2202 inputTensorInfo1,
2203 outputTensorInfo,
2204 descriptorName,
2205 "input_0",
2206 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002207}
2208
narpra01a6bf9122018-09-10 09:50:09 +01002209void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2210{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002211 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002212
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002213 ValidateNumInputs(workloadInfo, descriptorName, 1);
2214 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2215
2216 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2217 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002218
2219 std::vector<DataType> supportedTypes =
2220 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002221 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01002222 DataType::Float32,
2223 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002224 DataType::QAsymmU8,
2225 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01002226 };
narpra01eb061912018-09-10 17:35:27 +01002227
James Conroy4d1ff582019-06-10 17:06:39 +01002228 // First check if input tensor data type is supported, then
2229 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002230 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2231 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002232
narpra0132b90462018-09-13 11:07:48 +01002233 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002234 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002235 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002236 }
narpra0132b90462018-09-13 11:07:48 +01002237 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002238 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002239 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002240 }
2241 else
2242 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 unsigned int outputDim =
2244 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2245 ValidateTensorNumDimensions(outputTensorInfo,
2246 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002247 outputDim > 0 ? outputDim : 1,
2248 "output");
2249 }
narpra01a6bf9122018-09-10 09:50:09 +01002250}
2251
jimfly012c9322a2018-09-19 10:59:49 +01002252void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2253{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002254 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002255
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002256 ValidateNumInputs(workloadInfo, descriptorName, 1);
2257 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2258
2259 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2260 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002261
jimfly012c9322a2018-09-19 10:59:49 +01002262 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002263 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2264
jimfly012c9322a2018-09-19 10:59:49 +01002265 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002266 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2267 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2268 "as there are dimensions in the input tensor that is " +
2269 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2270 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002271 }
2272}
2273
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002274void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2275{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002276 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002277
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002278 ValidateNumInputs(workloadInfo, descriptorName, 1);
2279 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002280
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002281 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2282 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2283
Sadik Armagan2208b602019-07-31 16:36:27 +01002284 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002285 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002286 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002287 DataType::Float32,
Keith Davis5e51cd82020-01-29 16:52:59 +00002288 DataType::Float16,
2289 DataType::QSymmS8,
Ryan OShea9add1202020-02-07 10:06:33 +00002290 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002291 DataType::QAsymmU8,
2292 DataType::QSymmS16
Sadik Armagan2208b602019-07-31 16:36:27 +01002293 };
2294
2295 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002296
Keith Davis0c2eeac2020-02-11 16:51:50 +00002297 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002298 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002299 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002300 }
2301}
2302
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002303void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2304{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002305 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002306
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002307 ValidateNumInputs(workloadInfo, descriptorName, 1);
2308 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002309
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002310 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002312
2313 std::vector<DataType> supportedTypes =
2314 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002315 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002316 DataType::Float32,
2317 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002318 DataType::QAsymmU8,
2319 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002320 };
2321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002322 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2323 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002324}
2325
Conor Kennedy430b5d82018-11-14 15:28:28 +00002326void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2327{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002328 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002329
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002330 ValidateNumInputs(workloadInfo, descriptorName, 1);
2331 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2332
2333 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2334 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002335
2336 std::vector<DataType> supportedTypes =
2337 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002338 DataType::BFloat16,
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002339 DataType::Float16,
2340 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002341 DataType::QAsymmU8,
2342 DataType::QSymmS16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002343 };
2344
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002345 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2346 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002347
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002348 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002349
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002350 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002351 if (rank > 4)
2352 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002353 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002354 }
2355
Conor Kennedy430b5d82018-11-14 15:28:28 +00002356 // Begin, End & Stride length must be of rank(input0)
2357 if (m_Parameters.m_Begin.size() != rank)
2358 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002359 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002360 }
2361
2362 if (m_Parameters.m_End.size() != rank)
2363 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002365 }
2366
2367 if (m_Parameters.m_Stride.size() != rank)
2368 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002369 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002370 }
2371
2372 // Stride entries must be non-zero
2373 for (auto& stride : m_Parameters.m_Stride)
2374 {
2375 if (stride == 0)
2376 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002378 }
2379 }
2380}
2381
kevmay0190539692018-11-29 08:40:19 +00002382void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2383{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002384 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002385
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002386 ValidateNumInputs(workloadInfo, descriptorName, 2);
2387 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2388
2389 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2390 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2391 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2392
2393 std::vector<DataType> supportedTypes =
2394 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002395 DataType::BFloat16,
Mike Kelly1da02362019-08-01 08:43:57 +01002396 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002397 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002398 DataType::Signed32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002399 DataType::QAsymmU8,
2400 DataType::QSymmS16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002401 };
2402
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002403 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2404 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2405 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002406
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002407 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2408 inputTensorInfo1,
2409 outputTensorInfo,
2410 descriptorName,
2411 "input_0",
2412 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002413}
2414
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002415void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2416{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002417 const std::string descriptorName{"DebugQueueDescriptor"};
2418
2419 ValidateNumInputs(workloadInfo, descriptorName, 1);
2420 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002421}
2422
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002423void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2424{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002425 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002426
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002427 ValidateNumInputs(workloadInfo, descriptorName, 2);
2428 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002429
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002430 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2431 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2432 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2433
2434 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2435 inputTensorInfo1,
2436 outputTensorInfo,
2437 descriptorName,
2438 "input_0",
2439 "input_1");
2440
2441 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002442 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002443 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002444 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002445}
2446
FrancisMurtagh878f0232018-12-19 10:56:15 +00002447void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2448{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002449 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002450
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002451 ValidateNumInputs(workloadInfo, descriptorName, 2);
2452 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002453
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002454 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2455 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2456 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2457
2458 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2459 inputTensorInfo1,
2460 outputTensorInfo,
2461 descriptorName,
2462 "input_0",
2463 "input_1");
2464
2465 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002466 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002468 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002469}
2470
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002471void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2472{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002473 const std::string descriptorName{"RsqrtQueueDescriptor"};
2474
2475 ValidateNumInputs(workloadInfo, descriptorName, 1);
2476 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2477
2478 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2479 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2480
2481 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002482
2483 std::vector<DataType> supportedTypes =
2484 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002485 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002486 DataType::Float16,
2487 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002488 DataType::QAsymmU8,
2489 DataType::QSymmS16
nikraj010421e7f2019-06-14 09:40:34 +01002490 };
2491
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002492 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2493 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002494}
2495
narpra01b89b05f2019-01-16 09:53:09 +00002496void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2497{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002499
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002500 ValidateNumInputs(workloadInfo, descriptorName, 2);
2501 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002502
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002503 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2504 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002505 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002506 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002507 }
2508
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002509 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2510 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2511
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002512 std::vector<DataType> supportedTypes =
2513 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002514 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002515 DataType::Float16,
2516 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002517 DataType::QAsymmU8,
2518 DataType::QSymmS16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002519 };
2520
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002521 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002522
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002523 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002525 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2526 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002527}
2528
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002529void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2530{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002531 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2532
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002533 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002534
2535 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2536 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002537 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002538 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2539 }
2540
2541 if (m_Anchors == nullptr)
2542 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002543 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002544 }
2545
2546 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002547 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2548 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2549
2550 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002551 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002552 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2553 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002554
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002555 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2556 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2557 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002558
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002559 const std::vector<DataType> supportedInputTypes =
2560 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002561 DataType::BFloat16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002562 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002563 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002564 DataType::QAsymmU8,
2565 DataType::QSymmS16
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002566 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002567
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002568 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2569 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2570 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2571
2572 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2573 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2574 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2575 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2576
2577 // NOTE: Output is always Float32 regardless of input type
2578 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2579 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2580 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2581 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002582
2583 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2584 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002585 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002586 "must be positive and less than or equal to 1.");
2587 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002588
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002589 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2590 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002591 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002592 "should be equal to number of classes + 1.");
2593 }
2594}
2595
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002596void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2597{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002598 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002599
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002600 ValidateNumInputs(workloadInfo, descriptorName, 1);
2601 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2602
2603 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2604 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2605
Aron Virginas-Tare9323ec2019-11-26 12:50:34 +00002606 if (!IsQuantizedType(inputTensorInfo.GetDataType()))
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002607 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002608 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002609 }
2610
Sadik Armagan2208b602019-07-31 16:36:27 +01002611 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002612 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002613 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002614 DataType::Float32,
2615 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002616 };
2617
2618 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002619}
2620
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002621void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2622{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002623 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002624
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002625 ValidateNumInputs(workloadInfo, descriptorName, 2);
2626 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002627
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002628 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2629 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2630 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002631
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002632 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2633 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2634
2635 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2636 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002637}
2638
Sadik Armaganeff363d2019-04-05 15:25:46 +01002639void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2640{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002641 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002642
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002643 ValidateNumInputs(workloadInfo, descriptorName, 2);
2644 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2645
2646 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2647 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2648
2649 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2650 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2651
2652 std::vector<DataType> supportedTypes =
2653 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002654 DataType::BFloat16,
Sadik Armaganeff363d2019-04-05 15:25:46 +01002655 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002656 DataType::QAsymmU8,
2657 DataType::QSymmS16
Sadik Armaganeff363d2019-04-05 15:25:46 +01002658 };
2659
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002660 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2661 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002662
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002663 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2664 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002665
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002666 ValidateTensorShapesMatch(inputTensorInfo0,
2667 outputTensorInfo0,
2668 descriptorName,
2669 "input_0",
2670 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002671
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002672 ValidateTensorShapesMatch(inputTensorInfo0,
2673 outputTensorInfo1,
2674 descriptorName,
2675 "input_0",
2676 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002677}
2678
Derek Lamberti901ea112019-12-10 22:07:09 +00002679void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
Matteo Martincigh49124022019-01-11 13:25:59 +00002680{
2681 // This is internally generated so it should not need validation.
2682}
2683
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002684void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2685{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002686 const std::string& descriptorName{"PreluQueueDescriptor"};
2687
2688 ValidateNumInputs(workloadInfo, descriptorName, 2);
2689 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2690
2691 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2692 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2693 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002694
2695 std::vector<DataType> supportedTypes
2696 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002697 DataType::BFloat16,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002698 DataType::Float16,
2699 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002700 DataType::QAsymmU8,
2701 DataType::QSymmS16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002702 };
2703
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002704 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2705 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002706
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002707 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002708
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002709 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2710 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002711
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002712 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2713 alphaTensorInfo,
2714 outputTensorInfo,
2715 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002716 "input",
2717 "alpha");
2718}
2719
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002720void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2721{
2722 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2723
2724 ValidateNumInputs(workloadInfo, descriptorName, 1);
2725 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2726
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002727 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2728 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2729
2730 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2731 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002732
2733 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002735 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2736 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002737
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002738 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2739
2740 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002741 if (m_Parameters.m_BiasEnabled)
2742 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002743 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002744
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002745 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2746 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002747
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002748 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002749 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002750 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002751
2752 ValidatePerAxisQuantization(inputTensorInfo,
2753 outputTensorInfo,
2754 weightTensorInfo,
2755 optionalBiasTensorInfo,
2756 descriptorName);
2757
2758 std::vector<DataType> supportedTypes =
2759 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002760 DataType::BFloat16,
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002761 DataType::Float32,
2762 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002763 DataType::QAsymmU8,
2764 DataType::QSymmS16
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002765 };
2766
2767 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2768 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002769}
2770
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002771void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2772{
2773 const std::string descriptorName{"TransposeQueueDescriptor"};
2774
2775 ValidateNumInputs(workloadInfo, descriptorName, 1);
2776 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2777
2778 const PermutationVector& mapping = m_Parameters.m_DimMappings;
2779
2780 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2781 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2782
2783 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
2784 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
2785
2786 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
2787 {
2788 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
2789 {
2790 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
2791 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
2792 "must match dst dimension " + to_string(i) +
2793 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
2794 }
2795 }
2796
2797 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2798}
2799
James Conroy9c3cae82019-08-01 16:01:48 +01002800void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2801{
2802 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2803
2804 // Validate number of inputs/outputs
2805 ValidateNumInputs(workloadInfo, descriptorName, 3);
2806 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2807
2808 // Input/output tensor infos
2809 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2810 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2811 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2812
2813 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2814 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2815
2816 std::vector<DataType> inputOutputSupportedTypes =
2817 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002818 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002819 };
2820
2821 std::vector<DataType> cellStateSupportedTypes =
2822 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002823 DataType::QSymmS16
James Conroy9c3cae82019-08-01 16:01:48 +01002824 };
2825
2826 std::vector<DataType> weightsSupportedTypes =
2827 {
Derek Lambertif90c56d2020-01-10 17:14:08 +00002828 DataType::QAsymmU8
James Conroy9c3cae82019-08-01 16:01:48 +01002829 };
2830
2831 std::vector<DataType> biasSupportedTypes =
2832 {
2833 DataType::Signed32
2834 };
2835
2836 // Validate types of input/output tensors
2837 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2838 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2839 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2840
2841 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2842 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2843
2844 // Validate matching types of input/output tensors
2845 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2846 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2847 "outputStateIn", "outputStateOut");
2848 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2849
2850 // Validate matching quantization info for input/output tensors
2851 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2852 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2853 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002854
James Conroy9c3cae82019-08-01 16:01:48 +01002855 // Infer number of batches, input size and output size from tensor dimensions
2856 const uint32_t numBatches = inputInfo.GetShape()[0];
2857 const uint32_t inputSize = inputInfo.GetShape()[1];
2858 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2859
2860 // Validate number of dimensions and number of elements for input/output tensors
2861 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2862 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2863 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2864 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2865 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2866
2867 // Validate number of dimensions and number of elements for weights tensors
2868 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2869 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2870 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2871
2872 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2873 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2874 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2875
2876 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2877 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2878 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2879
2880 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2881 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2882 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2883
2884 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2885 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2886 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2887
2888 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2889 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2890 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2891 " RecurrentToForgetWeights");
2892
2893 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2894 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2895 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2896
2897 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2898 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2899 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2900
2901 // Validate data types for weights tensors (all should match each other)
2902 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2903
2904 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2905 "inputToInputWeights", "inputToForgetWeights");
2906 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2907 "inputToInputWeights", "inputToCellWeights");
2908 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2909 "inputToInputWeights", "inputToOutputWeights");
2910
2911 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2912 "inputToInputWeights", "recurrentToInputWeights");
2913 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2914 "inputToInputWeights", "recurrentToForgeteights");
2915 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2916 "inputToInputWeights", "recurrentToCellWeights");
2917 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2918 "inputToInputWeights", "recurrentToOutputWeights");
2919
2920 // Validate matching quantization info for weight tensors (all should match each other)
2921 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2922 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2923 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2924 descriptorName, "inputToInputWeights", "inputToCellWeights");
2925 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2926 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2927
2928 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2929 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2930 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2931 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2932 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2933 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2934 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2935 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2936
2937 // Validate number of dimensions and number of elements in bias tensors
2938 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2939 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2940 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2941
2942 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2943 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2944 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2945
2946 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2947 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2948 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2949
2950 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2951 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2952 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2953
2954 // Validate data types for bias tensors (all should match each other)
2955 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2956
2957 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2958 "inputGateBias", "forgetGateBias");
2959 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2960 "inputGateBias", "cellBias");
2961 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2962 "inputGateBias", "outputGateBias");
2963
2964 // Validate bias tensor quantization info
2965 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2966 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2967 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2968 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2969}
2970
Kevin May868eb142019-09-04 17:29:31 +01002971void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2972{
2973 const std::string descriptorName{"AbsQueueDescriptor"};
2974
2975 ValidateNumInputs(workloadInfo, descriptorName, 1);
2976 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2977
2978 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2979 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2980
2981 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2982
2983 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002984 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002985 DataType::BFloat16,
James Conroyd47a0642019-09-17 14:22:06 +01002986 DataType::Float16,
2987 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002988 DataType::QAsymmU8,
2989 DataType::QSymmS16
James Conroyd47a0642019-09-17 14:22:06 +01002990 };
Kevin May868eb142019-09-04 17:29:31 +01002991
2992 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2993 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2994}
2995
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002996void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2997{
2998 const std::string descriptorName{"SliceQueueDescriptor"};
2999
3000 ValidateNumInputs(workloadInfo, descriptorName, 1);
3001 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3002
3003 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3004 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3005
3006 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3007
3008 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3009 if (rank > 4)
3010 {
3011 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3012 }
3013
3014 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3015
3016 // Check if m_Begin and m_Size have the expected length
3017 if (m_Parameters.m_Begin.size() != rank)
3018 {
3019 throw InvalidArgumentException(descriptorName +
3020 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3021 }
3022 if (m_Parameters.m_Size.size() != rank)
3023 {
3024 throw InvalidArgumentException(descriptorName +
3025 ": Length of size descriptor must equal rank " + std::to_string(rank));
3026 }
3027
3028 // Check if the shape of the output tensor matches m_Size
3029 const TensorShape& outputShape = outputTensorInfo.GetShape();
3030 for (unsigned int i = 0u; i < rank; ++i)
3031 {
3032 if (m_Parameters.m_Size[i] != outputShape[i])
3033 {
3034 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3035 }
3036 }
3037
3038 // Check if the sum of begin offset and size in a given dimension
3039 // does not exceed the size of corresponding input
3040 const TensorShape& inputShape = inputTensorInfo.GetShape();
3041 for(unsigned int i = 0u; i < rank; ++i)
3042 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01003043 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01003044 {
3045 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3046 std::to_string(i) + " exceeds input size.");
3047 }
3048 }
3049}
3050
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003051void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3052{
3053 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3054
3055 ValidateNumInputs(workloadInfo, descriptorName, 1);
3056 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3057
3058 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3059 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3060
3061 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3062 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3063
3064 std::vector<DataType> supportedTypes =
3065 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003066 DataType::BFloat16,
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003067 DataType::Float32,
3068 DataType::Float16,
Derek Lambertif90c56d2020-01-10 17:14:08 +00003069 DataType::QAsymmU8,
3070 DataType::QSymmS16
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01003071 };
3072
3073 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3074 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3075
3076 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3077
3078 if (m_Parameters.m_BlockSize == 0)
3079 {
3080 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3081 }
3082
3083 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3084 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3085 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3086 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3087
3088 const TensorShape& outputShape = outputInfo.GetShape();
3089 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3090 {
3091 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3092 "must be divisible by block size.");
3093 }
3094
3095 const TensorShape& inputShape = inputInfo.GetShape();
3096 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3097 {
3098 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3099 "must be divisible by the square of block size." );
3100 }
3101}
3102
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01003103void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3104{
3105 const std::string descriptorName{"ComparisonQueueDescriptor"};
3106
3107 ValidateNumInputs(workloadInfo, descriptorName, 2);
3108 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3109
3110 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3111 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3112 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3113
3114 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3115 inputTensorInfo1,
3116 outputTensorInfo,
3117 descriptorName,
3118 "input_0",
3119 "input_1");
3120
3121 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3122 {
3123 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3124 }
3125}
3126
josh minor4a3c6102020-01-06 16:40:46 -06003127void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3128{
3129 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3130
3131 ValidateNumInputs(workloadInfo, descriptorName, 1);
3132 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3133
3134 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3135 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3136
3137 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3138
3139 std::vector<DataType> supportedTypes =
3140 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00003141 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06003142 DataType::Float16,
3143 DataType::Float32,
3144 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00003145 DataType::QSymmS16,
3146 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06003147 };
3148
3149 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3150 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3151}
3152
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01003153} // namespace armnn