blob: 6222ba4800b95d0f719f25e9ad5c4b3972c47b74 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
James Conroyc8724c72019-10-08 15:41:34 +010018#include <TensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
34 case DataType::QuantisedAsymm8:
35 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010036 case DataType::QuantisedSymm16:
37 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000038 default:
39 BOOST_ASSERT_MSG(false, "Invalid input data type");
40 return DataType::Float32;
41 }
42}
43
44namespace
45{
46
47//---------------------------------------------------------------
// File-local replacement for std::to_string, which the Android NDK does not provide.
// Formats any streamable value through an output string stream and returns the text.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatter;
    formatter << value;
    const std::string text = formatter.str();
    return text;
}
56
57//---------------------------------------------------------------
58void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
59{
60 if (!ptr)
61 {
62 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
63 paramName + " parameter must be set.");
64 }
65}
66
67//---------------------------------------------------------------
68void ValidateTensorShapesMatch(const TensorInfo& first,
69 const TensorInfo& second,
70 std::string const& descName,
71 std::string const& firstName,
72 std::string const& secondName)
73{
74 if (first.GetShape() != second.GetShape())
75 {
76 throw InvalidArgumentException(descName + ": "
77 + firstName + " & " + secondName + " must have identical shapes");
78 }
79}
80
81//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010082void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000083{
Sadik Armaganeff363d2019-04-05 15:25:46 +010084 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000085 {
86 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010087 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000088 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
89 }
90}
91
92//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010093void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000094{
Sadik Armaganeff363d2019-04-05 15:25:46 +010095 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000096 {
97 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010098 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000099 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
100 }
101}
102
103//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100104void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000105 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100106 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000107 std::string const& tensorName)
108{
109 if (tensor.GetNumDimensions() != numDimensions)
110 {
111 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
112 to_string(tensor.GetNumDimensions()) + " dimensions for " +
113 tensorName + " tensor.");
114 }
115}
116
117//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100118void ValidateTensorNumElements(const TensorInfo& tensor,
119 std::string const& descName,
120 unsigned int numElements,
121 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100122{
123 if (tensor.GetNumElements() != numElements)
124 {
125 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100126 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127 tensorName + " tensor.");
128 }
129}
130
131//---------------------------------------------------------------
132void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100133 unsigned int numDimension,
134 unsigned int numElements,
135 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100136{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 const std::string functionName{"ValidateTensorNumDimNumElem"};
138 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
139 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140}
141
142//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000143void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
144 const std::string& descName, std::string const& tensorName)
145{
146 if (tensor.GetDataType() != dataType)
147 {
148 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
149 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
150 }
151}
152
153//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100154void ValidateTensorQuantizationSpace(const TensorInfo& first,
155 const TensorInfo& second,
156 const std::string& descName,
157 std::string const& firstName,
158 std::string const& secondName)
159{
160 if (!first.IsQuantized() ||
161 !second.IsQuantized())
162 {
163 // Not a quantized type, ignore the validation
164 return;
165 }
166
167 DataType firstDataType = first.GetDataType();
168 DataType secondDataType = second.GetDataType();
169
170 if (firstDataType != secondDataType)
171 {
172 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
173 " must be of the same quantized type, " +
174 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
175 secondName + " is " + GetDataTypeName(secondDataType));
176 }
177
178 if (!first.IsTypeSpaceMatch(second))
179 {
180 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
181 " must have the same quantization space, " +
182 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
183 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
184 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
185 " and scale " + to_string(second.GetQuantizationScale()));
186 }
187}
188
189//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100190void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
191 const TensorInfo& inputTensorInfo,
192 const TensorInfo& weightsTensorInfo,
193 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000194{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000195 // Helper lambda function to validate a single bias quantization scale value
196 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
197 {
ricbur013f4d7102019-10-31 16:22:18 +0000198 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000199 if (std::abs(biasScale - expectedScale) > tolerance)
200 {
201 // Print the float values with extra precision to see very small differences
202 std::stringstream msg;
203 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
204 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
205 biasScale;
206 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
207 }
208 };
209
telsoa014fcda012018-03-09 14:13:49 +0000210 if (biasTensor.GetQuantizationOffset() != 0)
211 {
212 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
213 to_string(biasTensor.GetQuantizationOffset()));
214 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000215
216 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000217 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000218 // Validate per-axis quantization scales
219 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
220 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
221
222 if (weightScales.size() != biasScales.size())
223 {
224 std::stringstream msg;
225 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
226 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
227 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
228 }
229
230 for (size_t i = 0ul; i < biasScales.size(); ++i)
231 {
232 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
233 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
234 }
235 }
236 else
237 {
238 // Validate per-tensor quantization scale
239 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
240 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000241 }
242}
243
244//---------------------------------------------------------------
245void ValidateTensors(const std::vector<ITensorHandle*>& vec,
246 unsigned int numExpected,
247 const std::string& descName,
248 const std::string& varName)
249{
250 if (vec.empty() && numExpected > 0)
251 {
252 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
253 }
254
255 for (unsigned int i = 0; i < numExpected; ++i)
256 {
257 if (!vec[i])
258 {
259 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
260 }
261 }
262}
263
264//---------------------------------------------------------------
265void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
266 const TensorInfo& second,
267 const TensorInfo& output,
268 std::string const& descName,
269 std::string const& firstName,
270 std::string const& secondName)
271{
272 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
273 // broadcasted.
274 if (first.GetNumDimensions() != second.GetNumDimensions())
275 {
276 throw InvalidArgumentException(descName + ": Tensors "
277 + firstName + " & " + secondName
278 + " must have the same number of dimensions in order to be broadcasted");
279 }
280 uint32_t numDims = first.GetNumDimensions();
281 std::vector<uint32_t> outputDims(numDims, 0u);
282 for (uint32_t i = 0; i < numDims; i++)
283 {
284 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
285 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
286 if (dimsNotEqual && dimsNotOne)
287 {
288 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
289 }
290 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
291 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100292 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000293 if (broadcastShape != output.GetShape())
294 {
295 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
296 + firstName + " & " + secondName
297 + " does not match the output shape");
298 }
299}
300
301//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100302void ValidateDataTypes(const TensorInfo& info,
303 const std::vector<armnn::DataType>& supportedTypes,
304 std::string const& descName)
305{
306 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
307 if (iterator == supportedTypes.end())
308 {
309 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
310 }
311}
312
James Conroy4d1ff582019-06-10 17:06:39 +0100313//---------------------------------------------------------------
314void ValidateTensorDataTypesMatch(const TensorInfo& first,
315 const TensorInfo& second,
316 std::string const& descName,
317 std::string const& firstName,
318 std::string const& secondName)
319{
320 if (first.GetDataType() != second.GetDataType())
321 {
322 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
323 " must have identical data types.");
324 }
325}
326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100327//---------------------------------------------------------------
328void ValidateTensorNumElementsMatch(const TensorInfo& first,
329 const TensorInfo& second,
330 std::string const& descName,
331 std::string const& firstName,
332 std::string const& secondName)
333{
334 if (first.GetNumElements() != second.GetNumElements())
335 {
336 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
337 " must have the same number of elements.");
338 }
339}
340
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000341void ValidateWeightDataType(const TensorInfo& inputInfo,
342 const TensorInfo& weightInfo,
343 const std::string& descName)
344{
345 const DataType inputType = inputInfo.GetDataType();
346 if (inputType == DataType::QuantisedAsymm8)
347 {
348 const std::vector<DataType> validTypes =
349 {
350 DataType::QuantisedAsymm8,
351 DataType::QuantizedSymm8PerAxis
352 };
353
354 ValidateDataTypes(weightInfo, validTypes, descName);
355 }
356 else
357 {
358 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
359 }
360}
361
362void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
363 const std::string& descName,
364 const std::string& tensorName)
365{
366 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
367 if (!quantizationDim.has_value())
368 {
369 throw InvalidArgumentException(boost::str(
370 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
371 % descName % tensorName));
372 }
373
374 if (quantizationDim.value() != 0)
375 {
376 throw InvalidArgumentException(boost::str(
377 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
378 "but got: %3%") % descName % tensorName % quantizationDim.value()));
379 }
380}
381
382void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
383 const std::string& descName,
384 const std::string& tensorName)
385{
386 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
387 if (quantizationOffset != 0)
388 {
389 throw InvalidArgumentException(boost::str(
390 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
391 "but got: %3%") % descName % tensorName % quantizationOffset));
392 }
393}
394
395void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
396 const TensorInfo& outputInfo,
397 const TensorInfo& weightInfo,
398 const Optional<TensorInfo>& optionalBiasInfo,
399 const std::string& descName)
400{
401 if (weightInfo.HasPerAxisQuantization())
402 {
403 const DataType inputDataType = inputInfo.GetDataType();
404 const DataType outputDataType = outputInfo.GetDataType();
405
406 const bool canHavePerAxisQuantization =
407 inputDataType == DataType::QuantisedAsymm8 && inputDataType == outputDataType;
408
409 if (!canHavePerAxisQuantization)
410 {
411 throw InvalidArgumentException(boost::str(
412 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
413 "but data type does not support per-axis quantization.") % descName % "weight"));
414 }
415
416 ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight");
417 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
418 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
419
420 if (optionalBiasInfo.has_value())
421 {
422 const TensorInfo& biasInfo = optionalBiasInfo.value();
423 if (!biasInfo.HasPerAxisQuantization())
424 {
425 throw InvalidArgumentException(boost::str(
426 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
427 "weight tensor.") % descName));
428 }
429
430 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
431 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
432 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
433 }
434 }
435}
436
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100437} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000438
439void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
440 unsigned int numExpectedIn, unsigned int numExpectedOut) const
441{
442 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
443 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
444}
445
446//---------------------------------------------------------------
447void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
448{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100449 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000450
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100451 ValidateNumInputs(workloadInfo, descriptorName, 1);
452 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000453
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100454 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
455 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
456
457 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
458 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000459
460 if (m_Inputs.size() != m_Outputs.size())
461 {
462 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100463 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
464 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000465 }
466
467 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
468 {
469 if (!m_Inputs[i])
470 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100471 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
472 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000473 }
474
475 if (!m_Outputs[i])
476 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100477 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
478 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000479 }
480 }
481}
482
Derek Lambertif674aa02019-08-01 15:56:25 +0100483//---------------------------------------------------------------
484void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
485{
486 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
487 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
488
489 if (workloadInfo.m_InputTensorInfos.size() != 1)
490 {
491 throw InvalidArgumentException(boost::str(
492 boost::format("Number of input infos (%1%) is not 1.")
493 % workloadInfo.m_InputTensorInfos.size()));
494
495 }
496
497 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
498 {
499 throw InvalidArgumentException(boost::str(
500 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
501 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
502 }
503
504 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
505 {
506 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
507 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
508 {
509 throw InvalidArgumentException(boost::str(
510 boost::format("Number of elements for tensor input and output %1% does not match")
511 % i ));
512 }
513 }
514
515 if (m_Inputs.size() != 1)
516 {
517 throw InvalidArgumentException(boost::str(
518 boost::format("Number of inputs (%1%) is not 1.")
519 % m_Inputs.size()));
520 }
521
522 if (m_Inputs.size() != m_Outputs.size())
523 {
524 throw InvalidArgumentException(boost::str(
525 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
526 % m_Inputs.size() % m_Outputs.size()));
527 }
528
529 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
530 {
531 if (!m_Inputs[i])
532 {
533 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
534 }
535
536 if (!m_Outputs[i])
537 {
538 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
539 }
540 }
541}
542
543//---------------------------------------------------------------
544void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
545{
546 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
547 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
548
Derek Lambertif674aa02019-08-01 15:56:25 +0100549 if (m_Inputs.size() != 1)
550 {
551 throw InvalidArgumentException(boost::str(
552 boost::format("Number of inputs (%1%) is not 1.")
553 % m_Inputs.size()));
554 }
555
556 if (m_Outputs.size() != 0)
557 {
558 throw InvalidArgumentException(boost::str(
559 boost::format("Number of outputs (%1%) is not 0.")
560 % m_Inputs.size() % m_Outputs.size()));
561 }
562
563 if (!m_Inputs[0])
564 {
565 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
566 }
567}
568
569//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000570void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
571{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100572 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100574 ValidateNumInputs(workloadInfo, descriptorName, 1);
575 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100577 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
578 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100579
580 std::vector<DataType> supportedTypes =
581 {
James Conroyd47a0642019-09-17 14:22:06 +0100582 DataType::Float16,
583 DataType::Float32,
584 DataType::QuantisedAsymm8,
585 DataType::QuantisedSymm16
nikraj01248683f2019-05-29 16:46:50 +0100586 };
587
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100588 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
589 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
590 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000591}
592
Nikhil Rajee391d52019-09-05 17:50:44 +0100593void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
594{
595 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
596
597 ValidateNumInputs(workloadInfo, descriptorName, 1);
598 ValidateNumOutputs(workloadInfo, descriptorName, 1);
599
600 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
601 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
602
Nikhil Raj68c2c902019-09-19 11:21:11 +0100603 if (outputTensorInfo.GetDataType() != DataType::Signed32)
604 {
605 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
606 }
607
James Conroyd47a0642019-09-17 14:22:06 +0100608 std::vector<DataType> supportedInputTypes =
609 {
610 DataType::Float16,
611 DataType::Float32,
612 DataType::QuantisedAsymm8,
613 DataType::QuantisedSymm16
614 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100615
James Conroyd47a0642019-09-17 14:22:06 +0100616 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100617
618 auto inputShape = inputTensorInfo.GetShape();
619 auto outputShape = outputTensorInfo.GetShape();
620
621 auto inputNumDimensions = inputShape.GetNumDimensions();
622 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
623
624 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
625
626 // 1D input shape results in scalar output shape
627 if (inputShape.GetNumDimensions() == 1)
628 {
629 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
630 {
631 throw InvalidArgumentException(descriptorName + outputShapeError);
632 }
633 }
634 else
635 {
636 for (unsigned int i = 0; i < unsignedAxis; ++i)
637 {
638 if (outputShape[i] != inputShape[i])
639 {
640 throw InvalidArgumentException(descriptorName + outputShapeError);
641 }
642 }
643
644 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
645 {
646 if (outputShape[i - 1] != inputShape[i])
647 {
648 throw InvalidArgumentException(descriptorName + outputShapeError);
649 }
650 }
651 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100652}
653
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100654void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
655{
656 const std::string descriptorName{"SoftmaxQueueDescriptor"};
657
658 ValidateNumInputs(workloadInfo, descriptorName, 1);
659 ValidateNumOutputs(workloadInfo, descriptorName, 1);
660
661 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
662 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
663
664 std::vector<DataType> supportedTypes =
665 {
James Conroyd47a0642019-09-17 14:22:06 +0100666 DataType::Float16,
667 DataType::Float32,
668 DataType::QuantisedAsymm8,
669 DataType::QuantisedSymm16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100670 };
671
672 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
673 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
674 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
675}
676
telsoa014fcda012018-03-09 14:13:49 +0000677void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
678{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100679 const std::string descriptorName{"SplitterQueueDescriptor"};
680
681 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000682
Ruomei Yan25339c32019-05-28 16:48:20 +0100683 // Check the supported data types
684 std::vector<DataType> supportedTypes =
685 {
James Conroyd47a0642019-09-17 14:22:06 +0100686 DataType::Float32,
687 DataType::Float16,
688 DataType::Boolean,
689 DataType::Signed32,
690 DataType::QuantisedAsymm8,
691 DataType::QuantisedSymm16
Ruomei Yan25339c32019-05-28 16:48:20 +0100692 };
693
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100694 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
695 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100696 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100697 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
698 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
699
700 const std::string outputName = "output_" + std::to_string(i);
701 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100702 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100703
telsoa014fcda012018-03-09 14:13:49 +0000704 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
705 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100706 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000707 }
708
709 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
710 {
711 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100712 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000713 "has to match number of workloadInfo.m_OutputTensorInfos. "
714 "Number of windows: " +
715 to_string(m_ViewOrigins.size()) +
716 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
717 }
718
telsoa01c577f2c2018-08-31 09:22:23 +0100719 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000720 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
721 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
722 {
telsoa01c577f2c2018-08-31 09:22:23 +0100723 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000724 ViewOrigin const& e = m_ViewOrigins[w];
725 if (e.m_Origin.size() != inputDims)
726 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100727 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000728 "have the same dimensionality as the input tensor. "
729 "Window origin (index: " +
730 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
731 " dimensions, the input "
732 "tensor has " +
733 to_string(inputDims) + " dimensions.");
734 }
735 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
736 {
737 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
738 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
739 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100740 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000741 "be smaller or equal than the size of the input in that coord.");
742 }
743 }
744 }
745}
746
Jim Flynne242f2d2019-05-22 14:24:13 +0100747void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000748{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100749 const std::string descriptorName{"ConcatQueueDescriptor"};
750
751 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000752
753 if (m_Inputs.size() <= 0)
754 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000756 }
757 if (m_Outputs.size() <= 0)
758 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100759 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000760 }
761
762 if (workloadInfo.m_InputTensorInfos.size() <= 0)
763 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100764 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000765 }
766 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000769 }
770
Nikhil Raj8599a412018-11-19 14:51:07 +0000771 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
772 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100773 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000774 }
775
776 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
777 {
778 return;
779 }
780
telsoa014fcda012018-03-09 14:13:49 +0000781 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
782 {
783 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100784 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000785 "has to match number of workloadInfo.m_InputTensorInfos. "
786 "Number of windows: " +
787 to_string(m_ViewOrigins.size()) +
788 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
789 }
790
telsoa01c577f2c2018-08-31 09:22:23 +0100791 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000792 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
793 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
794 {
telsoa01c577f2c2018-08-31 09:22:23 +0100795 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000796 ViewOrigin const& e = m_ViewOrigins[w];
797 if (e.m_Origin.size() != outputDims)
798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100799 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000800 "have the same dimensionality as the output tensor. "
801 "Window origin (index: " +
802 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
803 " dimensions, the output "
804 "tensor has " +
805 to_string(outputDims) + " dimensions.");
806 }
telsoa01c577f2c2018-08-31 09:22:23 +0100807 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000808 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
809 {
810 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
811 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
812 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000814 "be smaller or equal than the size of the output in that coord.");
815 }
816 }
817 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100818
819 // Check the supported data types
820 std::vector<DataType> supportedTypes =
821 {
James Conroyd47a0642019-09-17 14:22:06 +0100822 DataType::Float32,
823 DataType::Float16,
824 DataType::Boolean,
825 DataType::Signed32,
826 DataType::QuantisedAsymm8,
827 DataType::QuantisedSymm16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100828 };
829
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100830 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
831 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100832 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100833 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
834 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
835
836 const std::string inputName = "input_" + std::to_string(i);
837 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100838 }
telsoa014fcda012018-03-09 14:13:49 +0000839}
840
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100841void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
842{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 const std::string descriptorName{"StackQueueDescriptor"};
844
845 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100846
847 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
848 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100849 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100850 }
851
852 // All inputs must have the same shape, which is defined in parameters
853 const TensorShape& inputShape = m_Parameters.m_InputShape;
854 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
855 {
856 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
857 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100858 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100859 }
860 }
861
Matthew Jacksondba634f2019-08-15 15:14:18 +0100862 if (inputShape.GetNumDimensions() > 4)
863 {
864 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
865 }
866
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100867 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
868 // since the output tensor has an additional dimension.
869 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
870 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100871 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100872 "than the number of input dimensions.");
873 }
874
875 // Output shape must be as inferred from the input shape
876 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
877 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
878 {
879 if (outputShape[i] != inputShape[i])
880 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100881 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100882 "match shape inferred from input tensor.");
883 }
884 }
885
886 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
887 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100888 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100889 "match shape inferred from input tensor.");
890 }
891
892 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
893 {
894 if (outputShape[i] != inputShape[i-1])
895 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100896 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100897 "match shape inferred from input tensor.");
898 }
899 }
900
Matthew Jacksondba634f2019-08-15 15:14:18 +0100901 if (outputShape.GetNumDimensions() > 5)
902 {
903 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
904 }
905
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100906 // Check the supported data types
907 std::vector<DataType> supportedTypes =
908 {
James Conroyd47a0642019-09-17 14:22:06 +0100909 DataType::Float32,
910 DataType::Float16,
911 DataType::Boolean,
912 DataType::Signed32,
913 DataType::QuantisedAsymm8,
914 DataType::QuantisedSymm16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100915 };
916
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100917 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100918
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100919 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100920 {
921 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
922 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100923 descriptorName,
924 "input_0",
925 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100926 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100927
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100928 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
929 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 descriptorName,
931 "input_0",
932 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100933}
934
telsoa014fcda012018-03-09 14:13:49 +0000935void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
936{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100937 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000938
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100939 ValidateNumInputs(workloadInfo, descriptorName, 1);
940 ValidateNumOutputs(workloadInfo, descriptorName, 1);
941
942 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
943 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
944
945 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
946
947 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000948 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100949 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000950 }
951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100952 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000953
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100954 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
955 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000956
957 if (m_Parameters.m_BiasEnabled)
958 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000960
telsoa01c577f2c2018-08-31 09:22:23 +0100961 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100962 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
963 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000964
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100965 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
966 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000967 }
968
Francis Murtagh46c09d02019-05-28 08:15:28 +0100969 // Check the supported data types
970 std::vector<DataType> supportedTypes =
971 {
James Conroyd47a0642019-09-17 14:22:06 +0100972 DataType::Float32,
973 DataType::Float16,
974 DataType::QuantisedAsymm8,
975 DataType::QuantisedSymm16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100976 };
977
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100978 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
979 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000980}
981
telsoa014fcda012018-03-09 14:13:49 +0000982void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
983{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100984 const std::string descriptorName{"NormalizationQueueDescriptor"};
985
986 ValidateNumInputs(workloadInfo, descriptorName, 1);
987 ValidateNumOutputs(workloadInfo, descriptorName, 1);
988
989 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
990 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100991
992 // Check the supported data types
993 std::vector<DataType> supportedTypes =
994 {
995 DataType::Float16,
996 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100997 DataType::QuantisedAsymm8,
998 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100999 };
1000
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001001 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001002
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001003 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001004
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001005 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001006}
1007
1008void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1009{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001010 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001011
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001012 ValidateNumInputs(workloadInfo, descriptorName, 2);
1013 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1014
1015 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1016 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1017 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1018
1019 std::vector<DataType> supportedTypes =
1020 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001021 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001022 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001023 DataType::QuantisedSymm16,
1024 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001025 };
1026
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001027 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1028 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1029 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001030
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001031 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1032 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001033
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001034 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1035 inputTensorInfo1,
1036 outputTensorInfo,
1037 descriptorName,
1038 "input_0",
1039 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001040}
1041
telsoa014fcda012018-03-09 14:13:49 +00001042void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1043{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001044 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001046 ValidateNumInputs(workloadInfo, descriptorName, 2);
1047 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1048
1049 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1050 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1051 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1052
1053 std::vector<DataType> supportedTypes =
1054 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001055 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001056 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001057 DataType::QuantisedSymm16,
1058 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001059 };
1060
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001061 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1062 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1063 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001064
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001065 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1066 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001067
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001068 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1069 inputTensorInfo1,
1070 outputTensorInfo,
1071 descriptorName,
1072 "input_0",
1073 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001074}
1075
1076void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1077{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001078 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001080 ValidateNumInputs(workloadInfo, descriptorName, 1);
1081 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1082
1083 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1084 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001085
1086 std::vector<DataType> supportedTypes =
1087 {
1088 DataType::Float16,
1089 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +01001090 DataType::QuantisedAsymm8,
1091 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001092 };
1093
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001094 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1095 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001096
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001097 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1098 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1099 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001100
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001101 ValidatePointer(m_Mean, descriptorName, "mean");
1102 ValidatePointer(m_Variance, descriptorName, "variance");
1103 ValidatePointer(m_Beta, descriptorName, "beta");
1104 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001105
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001106 const TensorInfo& mean = m_Mean->GetTensorInfo();
1107 const TensorInfo& variance = m_Variance->GetTensorInfo();
1108 const TensorInfo& beta = m_Beta->GetTensorInfo();
1109 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001110
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001111 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1112 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1113 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1114 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001115
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001116 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1117 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1118 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001119}
1120
1121void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1122{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001125 ValidateNumInputs(workloadInfo, descriptorName, 1);
1126 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001127
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1129 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001130
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001131 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1132 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001133
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001134 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1137 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001138
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001139 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001140
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001141 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001142 if (m_Parameters.m_BiasEnabled)
1143 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001144 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001145
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001146 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1147 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148
1149 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1150 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001151 }
1152
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001153 ValidatePerAxisQuantization(inputTensorInfo,
1154 outputTensorInfo,
1155 weightTensorInfo,
1156 optionalBiasTensorInfo,
1157 descriptorName);
1158
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001159 std::vector<DataType> supportedTypes =
1160 {
Ruomei Yan88d44b82019-05-23 14:29:06 +01001161 DataType::Float32,
1162 DataType::QuantisedAsymm8,
1163 DataType::QuantisedSymm16,
1164 DataType::Float16
1165 };
1166
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001167 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1168 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1169}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001170
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001171void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1172{
1173 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1174
1175 ValidateNumInputs(workloadInfo, descriptorName, 1);
1176 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1177
1178 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1179 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1180
1181 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1182 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1183
1184 ValidatePointer(m_Weight, descriptorName, "weight");
1185
1186 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1187 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1188
1189 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1190 {
1191 throw InvalidArgumentException(
1192 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1193 "cannot be smaller than 1.") % descriptorName %
1194 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1195 }
1196
1197 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1198
1199 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1200 // inputChannels * channelMultiplier should be equal to outputChannels.
1201 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1202 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1203 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1204 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1205 {
1206 throw InvalidArgumentException(
1207 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1208 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1209 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1210 numWeightInputChannels % numWeightChannelMultiplier));
1211 }
1212
1213 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
1214
1215 if (m_Parameters.m_BiasEnabled)
1216 {
1217 ValidatePointer(m_Bias, descriptorName, "bias");
1218
1219 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1220 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1221
1222 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1223 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1224 }
1225
1226 std::vector<DataType> supportedTypes =
1227 {
1228 DataType::Float32,
1229 DataType::QuantisedAsymm8,
1230 DataType::QuantisedSymm16,
1231 DataType::Float16
1232 };
1233
1234 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1235 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001236}
1237
1238void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1239{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 const std::string descriptorName{"PermuteQueueDescriptor"};
1241
1242 ValidateNumInputs(workloadInfo, descriptorName, 1);
1243 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001244
1245 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1246
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1248 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001249
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1251 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001254 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001256 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1258 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1259 "must match dst dimension " + to_string(mapping[i]) +
1260 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001261 }
1262 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001263
1264 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001265}
1266
1267void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1268{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001269 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001270
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001271 ValidateNumInputs(workloadInfo, descriptorName, 1);
1272 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1273
1274 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1275 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1276
1277 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1278 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001279
1280 std::vector<DataType> supportedTypes =
1281 {
1282 DataType::Float32,
1283 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001284 DataType::QuantisedAsymm8,
1285 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001286 };
1287
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001288 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1289 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001290}
1291
1292void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1293{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001294 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001295
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001296 ValidateNumInputs(workloadInfo, descriptorName, 1);
1297 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1298
1299 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1300 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1301
1302 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1303 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001304
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001305 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001306 {
1307 DataType::Float16,
1308 DataType::Float32,
1309 DataType::QuantisedAsymm8,
1310 DataType::QuantisedSymm16
1311 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001312
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001313 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1314 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001315
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001316 // ResizeBilinear only changes width and height: batch and channel count must match.
1317 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1318 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001319 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001320 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001321 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001322 boost::str(boost::format("%1%: Input batch size (%2%) "
1323 "does not match output batch size (%3%)") %
1324 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001325 }
1326
Teresa Charlin970f43b2019-07-01 13:51:07 +01001327 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001328 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1329 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001330 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001331 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001332 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001333 boost::str(boost::format("%1%: Input channel count (%2%) "
1334 "does not match output channel count (%3%)") %
1335 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001336 }
1337}
1338
1339void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1340{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001341 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001343 ValidateNumInputs(workloadInfo, descriptorName, 1);
1344 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1345
1346 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1347 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1348
1349 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1350 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001351
1352 std::vector<DataType> supportedTypes =
1353 {
1354 DataType::Float16,
1355 DataType::Float32,
1356 DataType::QuantisedAsymm8,
1357 DataType::QuantisedSymm16
1358 };
1359
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001360 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1361 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001362
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001363 // Resize only changes width and height: batch and channel count must match.
1364 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1365 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001366 if (inputBatchSize != outputBatchSize)
1367 {
1368 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001369 boost::str(boost::format("%1%: Input batch size (%2%) "
1370 "does not match output batch size (%3%)") %
1371 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001372 }
1373
1374 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001375 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1376 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001377 if (inputChannelCount != outputChannelCount)
1378 {
1379 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001380 boost::str(boost::format("%1%: Input channel count (%2%) "
1381 "does not match output channel count (%3%)") %
1382 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001383 }
1384}
1385
1386void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1387{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001388 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001390 ValidateNumInputs(workloadInfo, descriptorName, 1);
1391 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1392
1393 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1394 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1395
1396 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1397 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1398
1399 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1400
telsoa014fcda012018-03-09 14:13:49 +00001401 if (m_Parameters.m_Min > m_Parameters.m_Max)
1402 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001404 }
telsoa014fcda012018-03-09 14:13:49 +00001405}
1406
Kevin Mayce5045a2019-10-02 14:07:47 +01001407void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1408{
1409 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1410
1411 ValidateNumInputs(workloadInfo, descriptorName, 1);
1412 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1413
1414 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1415 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1416
1417 if (inputTensorInfo.GetNumDimensions() > 4)
1418 {
1419 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1420 }
1421
1422 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1423
1424 // Check the supported data types
1425 std::vector<DataType> supportedTypes =
1426 {
1427 DataType::Float32,
1428 DataType::Float16
1429 };
1430
1431 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001432 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001433}
1434
telsoa014fcda012018-03-09 14:13:49 +00001435void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1436{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001438
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001439 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001440 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1441
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1443 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1444
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001445 if (inputTensorInfo.GetNumDimensions() > 4)
1446 {
1447 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1448 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001449
1450 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001451
1452 // Check the supported data types
1453 std::vector<DataType> supportedTypes =
1454 {
1455 DataType::Float32,
1456 DataType::Float16,
1457 DataType::QuantisedAsymm8,
1458 DataType::QuantisedSymm16
1459 };
1460
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001461 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001462 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1463}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001464
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001465void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1466{
1467 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1468
1469 ValidateNumInputs(workloadInfo, descriptorName, 1);
1470 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1471
1472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1474
1475 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1476
1477 std::vector<DataType> supportedTypes =
1478 {
1479 DataType::Float32,
1480 DataType::Float16,
1481 };
1482
1483 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001484 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001485}
1486
1487void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1488{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001489 const std::string descriptorName{"ConstantQueueDescriptor"};
1490
1491 ValidateNumInputs(workloadInfo, descriptorName, 0);
1492 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001493
1494 if (!m_LayerOutput)
1495 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001497 }
1498
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1500 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001501
1502 // Check the supported data types
1503 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001504 {
1505 DataType::Float32,
1506 DataType::Float16,
1507 DataType::Signed32,
1508 DataType::QuantisedAsymm8,
1509 DataType::QuantisedSymm16
1510 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001511
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001512 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001513}
1514
1515void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1516{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001517 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001518
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001519 ValidateNumInputs(workloadInfo, descriptorName, 1);
1520 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1521
1522 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1523 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1524
1525 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001526
1527 // Check the supported data types
1528 std::vector<DataType> supportedTypes =
1529 {
1530 DataType::Float32,
1531 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001532 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001533 DataType::QuantisedAsymm8,
1534 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001535 };
1536
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001537 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1538 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001539}
1540
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001541void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1542{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001543 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001544
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001545 ValidateNumInputs(workloadInfo, descriptorName, 1);
1546 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1547
1548 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1549 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1550
1551 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1552 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001553
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001554 if (m_Parameters.m_BlockShape.size() != 2)
1555 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001556 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001557 }
1558
1559 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1560 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001561 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1562 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001563 }
1564
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001565 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001566
1567 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001568 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001569
Matthew Bentham8800c002018-11-19 13:19:28 +00001570 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001572 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1573 widthPad.first + widthPad.second;
1574 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1575 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001577 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1578 inputShape[dimensionIndices.GetChannelsIndex()];
1579 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001580
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001581 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001582 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001583 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001584 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001585 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001586 }
1587
1588 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001589 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001590 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1591 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001592 }
nikraj01120522a2019-05-31 11:33:07 +01001593
1594 std::vector<DataType> supportedTypes =
1595 {
1596 DataType::Float16,
1597 DataType::Float32,
1598 DataType::QuantisedAsymm8,
1599 DataType::QuantisedSymm16
1600 };
1601
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001602 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1603 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001604}
1605
Keith Davisa57eccb2019-06-14 17:33:22 +01001606void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1607{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001608 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001609
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001610 ValidateNumInputs(workloadInfo, descriptorName, 1);
1611 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001612
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001613 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1614 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1615
1616 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1617 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001618
1619 std::vector<DataType> supportedTypes =
1620 {
1621 DataType::Float32,
1622 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001623 DataType::QuantisedAsymm8,
1624 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001625 };
1626
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001627 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1628 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001629
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001630 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1631
1632 if (m_Parameters.m_BlockSize == 0)
1633 {
1634 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1635 }
1636
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001637 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1638 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1639 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1640 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001641
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001642 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001643 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001644 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001645 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1646 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001647 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001648
1649 const TensorShape& outputShape = outputTensorInfo.GetShape();
1650 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1651 {
1652 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1653 "must be divisible by the square of block size." );
1654 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001655}
1656
telsoa014fcda012018-03-09 14:13:49 +00001657void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1658{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001659 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001660
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001661 ValidateNumInputs(workloadInfo, descriptorName, 1);
1662 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1663
1664 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1665 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001666
1667 std::vector<DataType> supportedTypes =
1668 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001669 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001670 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001671 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001672 };
1673
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001674 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001675
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001677 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001678 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001679 }
1680}
1681
telsoa01c577f2c2018-08-31 09:22:23 +01001682void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1683{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001684 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1685
1686 const std::string descriptorName{"LstmQueueDescriptor"};
1687
1688 // check dimensions of all inputs and outputs
1689 if (workloadInfo.m_InputTensorInfos.size() != 3)
1690 {
1691 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1692 }
1693 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1694 {
1695 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1696 }
1697
1698 std::vector<DataType> supportedTypes =
1699 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001700 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001701 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001702 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001703 };
1704
Jan Eilers38e05bd2019-06-26 13:10:09 +01001705 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001706 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1707
Jan Eilers38e05bd2019-06-26 13:10:09 +01001708 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001710 {
1711 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1712 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001713 descriptorName,
1714 "input_0",
1715 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001716 }
1717 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001718 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001719 {
1720 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1721 workloadInfo.m_OutputTensorInfos[i],
1722 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001723 "input_0",
1724 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001725 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001726
Jan Eilers38e05bd2019-06-26 13:10:09 +01001727 // TODO: check clipping parameter is valid
1728
1729 // Inferring batch size, number of outputs and number of cells from the inputs.
1730 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1731 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1732 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1733 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1734 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1735 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1736 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1737
Jan Eilers38e05bd2019-06-26 13:10:09 +01001738 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1740 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001741 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001742 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1743 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001744 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001745 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1746 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001747 // scratchBufferTensor
1748 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001749 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1750 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001751 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001752 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1753 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001754 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001755 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1756 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001757 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001758 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1759 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001760
1761
1762 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1763 if ( m_InputToInputWeights )
1764 {
1765 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1766 (n_cell * n_input), "InputLayerNormWeights");
1767 }
1768
1769 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1770 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1771 (n_cell * n_input), "InputToForgetWeights");
1772
1773 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1774 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1775 (n_cell * n_input), "InputToCellWeights");
1776
1777 if ( m_RecurrentToInputWeights )
1778 {
1779 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1780 (n_cell * n_output), "RecurrentToInputWeights");
1781 }
1782
1783 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1784 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1785 (n_cell * n_output), "RecurrentToForgetWeights");
1786
1787 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1788 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1789 (n_cell * n_output), "RecurrentToCellWeights");
1790
1791 // Make sure the input-gate's parameters are either both present (regular
1792 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1793 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1794 !m_Parameters.m_CifgEnabled) ||
1795 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1796 m_Parameters.m_CifgEnabled));
1797 if (!cifg_weights_all_or_none)
1798 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001799 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1800 "RecurrentToInputWeights must either both be present (regular LSTM) "
1801 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1802 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001803 }
1804
1805 if ( m_CellToInputWeights )
1806 {
1807 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1808 n_cell, "CellToInputWeights");
1809 }
1810 if ( m_CellToForgetWeights )
1811 {
1812 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1813 n_cell, "CellToForgetWeights");
1814 }
1815 if ( m_CellToOutputWeights )
1816 {
1817 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1818 n_cell, "CellToOutputWeights");
1819 }
1820
1821 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1822 bool peephole_weights_all_or_none =
1823 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1824 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1825 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1826 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1827 if (!peephole_weights_all_or_none)
1828 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001829 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001830 }
1831
1832 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1833 if (m_Parameters.m_CifgEnabled)
1834 {
1835 if (m_InputGateBias)
1836 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001837 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001838 }
1839 }
1840 else
1841 {
1842 if (!m_InputGateBias)
1843 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1845 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001846 }
1847 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1848 n_cell, "InputGateBias");
1849 }
1850
1851 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1852 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1853
1854 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1855 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1856
1857 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1858 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1859
1860 if (m_ProjectionWeights)
1861 {
1862 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1863 (n_cell * n_output), "ProjectionWeights");
1864 }
1865 if (m_ProjectionBias)
1866 {
1867 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1868 }
1869
1870 // Making sure the projection tensors are consistent:
1871 // 1) If projection weight is not present, then projection bias should not be
1872 // present.
1873 // 2) If projection weight is present, then projection bias is optional.
1874 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1875 !m_Parameters.m_ProjectionEnabled)
1876 || (m_ProjectionWeights && !m_ProjectionBias &&
1877 m_Parameters.m_ProjectionEnabled)
1878 || (m_ProjectionWeights && m_ProjectionBias &&
1879 m_Parameters.m_ProjectionEnabled));
1880 if (!projecton_tensors_consistent)
1881 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001882 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001883 }
1884
1885 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1886 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1887 // either all have values or none of them have values. Layer normalization is used when the values of all the
1888 // layer normalization weights are present
1889 if (m_InputLayerNormWeights)
1890 {
1891 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1892 }
1893 if (m_ForgetLayerNormWeights)
1894 {
1895 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1896 }
1897 if (m_CellLayerNormWeights)
1898 {
1899 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1900 }
1901 if (m_OutputLayerNormWeights)
1902 {
1903 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1904 }
1905
Jan Eilers38e05bd2019-06-26 13:10:09 +01001906 if (m_Parameters.m_LayerNormEnabled)
1907 {
1908 if (!m_Parameters.m_CifgEnabled)
1909 {
1910 if (!m_InputLayerNormWeights)
1911 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001912 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1913 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001914 }
1915 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1916 1, n_cell, "InputLayerNormWeights");
1917 }
1918 else if (m_InputLayerNormWeights)
1919 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001920 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1921 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001922 }
1923
1924 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1925 "ForgetLayerNormWeights");
1926 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1927
1928 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1929 "OutputLayerNormWeights");
1930 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1931
1932 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1933 "CellLayerNormWeights");
1934 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1935 }
1936 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1937 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001938 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1939 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001940 }
telsoa01c577f2c2018-08-31 09:22:23 +01001941}
1942
1943void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1944{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001945 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001946
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001947 ValidateNumInputs(workloadInfo, descriptorName, 1);
1948 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1949
1950 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1951 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1952
1953 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001954 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001955 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001956 }
1957
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001958 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001959 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001960 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001961 }
1962
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001963 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001964}
1965
1966void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1967{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001968 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001970 ValidateNumInputs(workloadInfo, descriptorName, 1);
1971 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1972
1973 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1974 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1975
1976 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001977 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001978 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001979 }
1980
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001981 if (outputTensorInfo.GetDataType() != DataType::Float32)
1982 {
1983 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1984 }
1985
1986 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001987}
1988
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001989void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1990{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001991 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001993 ValidateNumInputs(workloadInfo, descriptorName, 2);
1994 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1995
1996 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1997 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1998 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1999
2000 std::vector<DataType> supportedTypes =
2001 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002002 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002003 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002004 DataType::QuantisedSymm16,
2005 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002006 };
2007
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002008 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2009 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2010 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002011
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002012 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2013 inputTensorInfo1,
2014 outputTensorInfo,
2015 descriptorName,
2016 "input_0",
2017 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002018}
2019
David Beckc2044fe2018-09-05 15:00:38 +01002020void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2021{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002022 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002023
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002024 ValidateNumInputs(workloadInfo, descriptorName, 2);
2025 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2026
2027 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2028 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2029 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2030
2031 std::vector<DataType> supportedTypes =
2032 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002033 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002034 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002035 DataType::QuantisedSymm16,
2036 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002037 };
2038
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002039 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2040 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2041 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002042
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002043 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2044 inputTensorInfo1,
2045 outputTensorInfo,
2046 descriptorName,
2047 "input_0",
2048 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002049}
2050
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002051void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2052{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002053 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002054
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002055 ValidateNumInputs(workloadInfo, descriptorName, 2);
2056 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2057
2058 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2059 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2060 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2061
2062 std::vector<DataType> supportedTypes =
2063 {
Mike Kelly1da02362019-08-01 08:43:57 +01002064 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002065 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002066 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002067 DataType::QuantisedAsymm8,
2068 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002069 };
2070
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002071 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2072 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2073 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002074
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002075 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2076 inputTensorInfo1,
2077 outputTensorInfo,
2078 descriptorName,
2079 "input_0",
2080 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002081}
2082
narpra01a6bf9122018-09-10 09:50:09 +01002083void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2084{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002085 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002086
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002087 ValidateNumInputs(workloadInfo, descriptorName, 1);
2088 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2089
2090 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2091 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002092
2093 std::vector<DataType> supportedTypes =
2094 {
2095 DataType::Float32,
2096 DataType::Float16,
2097 DataType::QuantisedAsymm8,
2098 DataType::QuantisedSymm16
2099 };
narpra01eb061912018-09-10 17:35:27 +01002100
James Conroy4d1ff582019-06-10 17:06:39 +01002101 // First check if input tensor data type is supported, then
2102 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002103 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2104 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002105
narpra0132b90462018-09-13 11:07:48 +01002106 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002107 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002108 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002109 }
narpra0132b90462018-09-13 11:07:48 +01002110 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002111 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002112 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002113 }
2114 else
2115 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002116 unsigned int outputDim =
2117 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2118 ValidateTensorNumDimensions(outputTensorInfo,
2119 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002120 outputDim > 0 ? outputDim : 1,
2121 "output");
2122 }
narpra01a6bf9122018-09-10 09:50:09 +01002123}
2124
jimfly012c9322a2018-09-19 10:59:49 +01002125void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2126{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002127 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002128
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002129 ValidateNumInputs(workloadInfo, descriptorName, 1);
2130 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2131
2132 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2133 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002134
jimfly012c9322a2018-09-19 10:59:49 +01002135 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002136 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2137
jimfly012c9322a2018-09-19 10:59:49 +01002138 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002139 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2140 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2141 "as there are dimensions in the input tensor that is " +
2142 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2143 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002144 }
2145}
2146
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002147void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2148{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002149 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002150
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002151 ValidateNumInputs(workloadInfo, descriptorName, 1);
2152 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002153
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002154 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2155 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2156
Sadik Armagan2208b602019-07-31 16:36:27 +01002157 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002158 {
James Conroyd47a0642019-09-17 14:22:06 +01002159 DataType::Float32,
2160 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002161 };
2162
2163 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002164
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002165 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2166 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002167 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002168 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002169 }
2170}
2171
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002172void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2173{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002174 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002175
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002176 ValidateNumInputs(workloadInfo, descriptorName, 1);
2177 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002178
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002179 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2180 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002181
2182 std::vector<DataType> supportedTypes =
2183 {
James Conroyd47a0642019-09-17 14:22:06 +01002184 DataType::Float32,
2185 DataType::Float16,
2186 DataType::QuantisedAsymm8,
2187 DataType::QuantisedSymm16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002188 };
2189
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002190 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2191 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002192}
2193
Conor Kennedy430b5d82018-11-14 15:28:28 +00002194void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2195{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002196 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002197
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002198 ValidateNumInputs(workloadInfo, descriptorName, 1);
2199 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2200
2201 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2202 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002203
2204 std::vector<DataType> supportedTypes =
2205 {
2206 DataType::Float16,
2207 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01002208 DataType::QuantisedAsymm8,
2209 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002210 };
2211
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002212 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2213 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002214
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002215 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002216
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002217 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002218 if (rank > 4)
2219 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002220 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002221 }
2222
Conor Kennedy430b5d82018-11-14 15:28:28 +00002223 // Begin, End & Stride length must be of rank(input0)
2224 if (m_Parameters.m_Begin.size() != rank)
2225 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002226 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002227 }
2228
2229 if (m_Parameters.m_End.size() != rank)
2230 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002231 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002232 }
2233
2234 if (m_Parameters.m_Stride.size() != rank)
2235 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002237 }
2238
2239 // Stride entries must be non-zero
2240 for (auto& stride : m_Parameters.m_Stride)
2241 {
2242 if (stride == 0)
2243 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002244 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002245 }
2246 }
2247}
2248
kevmay0190539692018-11-29 08:40:19 +00002249void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2250{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002251 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002252
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002253 ValidateNumInputs(workloadInfo, descriptorName, 2);
2254 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2255
2256 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2257 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2258 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2259
2260 std::vector<DataType> supportedTypes =
2261 {
Mike Kelly1da02362019-08-01 08:43:57 +01002262 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002263 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002264 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002265 DataType::QuantisedAsymm8,
2266 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002267 };
2268
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002269 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2270 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2271 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002272
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002273 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2274 inputTensorInfo1,
2275 outputTensorInfo,
2276 descriptorName,
2277 "input_0",
2278 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002279}
2280
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002281void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2282{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002283 const std::string descriptorName{"DebugQueueDescriptor"};
2284
2285 ValidateNumInputs(workloadInfo, descriptorName, 1);
2286 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002287}
2288
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002289void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2290{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002291 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002292
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002293 ValidateNumInputs(workloadInfo, descriptorName, 2);
2294 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002295
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002296 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2297 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2298 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2299
2300 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2301 inputTensorInfo1,
2302 outputTensorInfo,
2303 descriptorName,
2304 "input_0",
2305 "input_1");
2306
2307 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002308 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002309 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002310 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002311}
2312
FrancisMurtagh878f0232018-12-19 10:56:15 +00002313void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2314{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002316
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002317 ValidateNumInputs(workloadInfo, descriptorName, 2);
2318 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002320 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2321 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2322 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2323
2324 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2325 inputTensorInfo1,
2326 outputTensorInfo,
2327 descriptorName,
2328 "input_0",
2329 "input_1");
2330
2331 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002332 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002333 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002334 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002335}
2336
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002337void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2338{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002339 const std::string descriptorName{"RsqrtQueueDescriptor"};
2340
2341 ValidateNumInputs(workloadInfo, descriptorName, 1);
2342 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2343
2344 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2345 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2346
2347 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002348
2349 std::vector<DataType> supportedTypes =
2350 {
James Conroyd47a0642019-09-17 14:22:06 +01002351 DataType::Float16,
2352 DataType::Float32,
2353 DataType::QuantisedAsymm8,
2354 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002355 };
2356
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002357 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2358 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002359}
2360
narpra01b89b05f2019-01-16 09:53:09 +00002361void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2362{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002363 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002364
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002365 ValidateNumInputs(workloadInfo, descriptorName, 2);
2366 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002367
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002368 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2369 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002370 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002371 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002372 }
2373
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002374 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2375 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2376
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002377 std::vector<DataType> supportedTypes =
2378 {
James Conroyd47a0642019-09-17 14:22:06 +01002379 DataType::Float16,
2380 DataType::Float32,
2381 DataType::QuantisedAsymm8,
2382 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002383 };
2384
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002385 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002386
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002387 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002388
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002389 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2390 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002391}
2392
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002393void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2394{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002395 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2396
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002397 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002398
2399 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2400 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002401 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002402 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2403 }
2404
2405 if (m_Anchors == nullptr)
2406 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002407 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002408 }
2409
2410 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002411 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2412 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2413
2414 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002415 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002416 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2417 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002418
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002419 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2420 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2421 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002422
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002423 const std::vector<DataType> supportedInputTypes =
2424 {
2425 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002426 DataType::Float16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002427 DataType::QuantisedAsymm8,
2428 DataType::QuantisedSymm16
2429 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002430
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002431 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2432 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2433 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2434
2435 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2436 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2437 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2438 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2439
2440 // NOTE: Output is always Float32 regardless of input type
2441 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2442 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2443 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2444 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002445
2446 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2447 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002448 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002449 "must be positive and less than or equal to 1.");
2450 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002451
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002452 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2453 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002454 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002455 "should be equal to number of classes + 1.");
2456 }
2457}
2458
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002459void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2460{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002461 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002462
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002463 ValidateNumInputs(workloadInfo, descriptorName, 1);
2464 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2465
2466 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2467 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2468
2469 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2470 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002471 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002472 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002473 }
2474
Sadik Armagan2208b602019-07-31 16:36:27 +01002475 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002476 {
James Conroyd47a0642019-09-17 14:22:06 +01002477 DataType::Float32,
2478 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002479 };
2480
2481 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002482}
2483
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002484void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2485{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002486 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002487
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002488 ValidateNumInputs(workloadInfo, descriptorName, 2);
2489 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002490
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002491 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2492 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2493 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002495 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2496 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2497
2498 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2499 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002500}
2501
Sadik Armaganeff363d2019-04-05 15:25:46 +01002502void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2503{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002504 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002505
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002506 ValidateNumInputs(workloadInfo, descriptorName, 2);
2507 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2508
2509 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2510 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2511
2512 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2513 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2514
2515 std::vector<DataType> supportedTypes =
2516 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002517 DataType::Float32,
2518 DataType::QuantisedAsymm8,
2519 DataType::QuantisedSymm16
2520 };
2521
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002522 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2523 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002524
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002525 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2526 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002527
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002528 ValidateTensorShapesMatch(inputTensorInfo0,
2529 outputTensorInfo0,
2530 descriptorName,
2531 "input_0",
2532 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002533
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002534 ValidateTensorShapesMatch(inputTensorInfo0,
2535 outputTensorInfo1,
2536 descriptorName,
2537 "input_0",
2538 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002539}
2540
Matteo Martincigh49124022019-01-11 13:25:59 +00002541void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2542{
2543 // This is internally generated so it should not need validation.
2544}
2545
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002546void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2547{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002548 const std::string& descriptorName{"PreluQueueDescriptor"};
2549
2550 ValidateNumInputs(workloadInfo, descriptorName, 2);
2551 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2552
2553 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2554 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2555 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002556
2557 std::vector<DataType> supportedTypes
2558 {
2559 DataType::Float16,
2560 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002561 DataType::QuantisedAsymm8,
2562 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002563 };
2564
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002565 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2566 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002567
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002568 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002569
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002570 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2571 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002572
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002573 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2574 alphaTensorInfo,
2575 outputTensorInfo,
2576 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002577 "input",
2578 "alpha");
2579}
2580
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002581void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2582{
2583 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2584
2585 ValidateNumInputs(workloadInfo, descriptorName, 1);
2586 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2587
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002588 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2589 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2590
2591 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2592 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002593
2594 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002595
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002596 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2597 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002598
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002599 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2600
2601 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002602 if (m_Parameters.m_BiasEnabled)
2603 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002604 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002605
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002606 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2607 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002608
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002609 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002610 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002611 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002612
2613 ValidatePerAxisQuantization(inputTensorInfo,
2614 outputTensorInfo,
2615 weightTensorInfo,
2616 optionalBiasTensorInfo,
2617 descriptorName);
2618
2619 std::vector<DataType> supportedTypes =
2620 {
2621 DataType::Float32,
2622 DataType::Float16,
2623 DataType::QuantisedAsymm8,
2624 DataType::QuantisedSymm16
2625 };
2626
2627 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2628 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002629}
2630
James Conroy9c3cae82019-08-01 16:01:48 +01002631void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2632{
2633 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2634
2635 // Validate number of inputs/outputs
2636 ValidateNumInputs(workloadInfo, descriptorName, 3);
2637 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2638
2639 // Input/output tensor infos
2640 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2641 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2642 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2643
2644 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2645 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2646
2647 std::vector<DataType> inputOutputSupportedTypes =
2648 {
2649 DataType::QuantisedAsymm8
2650 };
2651
2652 std::vector<DataType> cellStateSupportedTypes =
2653 {
2654 DataType::QuantisedSymm16
2655 };
2656
2657 std::vector<DataType> weightsSupportedTypes =
2658 {
2659 DataType::QuantisedAsymm8
2660 };
2661
2662 std::vector<DataType> biasSupportedTypes =
2663 {
2664 DataType::Signed32
2665 };
2666
2667 // Validate types of input/output tensors
2668 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2669 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2670 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2671
2672 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2673 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2674
2675 // Validate matching types of input/output tensors
2676 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2677 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2678 "outputStateIn", "outputStateOut");
2679 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2680
2681 // Validate matching quantization info for input/output tensors
2682 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2683 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2684 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002685
James Conroy9c3cae82019-08-01 16:01:48 +01002686 // Infer number of batches, input size and output size from tensor dimensions
2687 const uint32_t numBatches = inputInfo.GetShape()[0];
2688 const uint32_t inputSize = inputInfo.GetShape()[1];
2689 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2690
2691 // Validate number of dimensions and number of elements for input/output tensors
2692 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2693 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2694 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2695 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2696 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2697
2698 // Validate number of dimensions and number of elements for weights tensors
2699 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2700 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2701 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2702
2703 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2704 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2705 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2706
2707 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2708 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2709 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2710
2711 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2712 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2713 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2714
2715 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2716 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2717 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2718
2719 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2720 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2721 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2722 " RecurrentToForgetWeights");
2723
2724 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2725 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2726 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2727
2728 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2729 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2730 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2731
2732 // Validate data types for weights tensors (all should match each other)
2733 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2734
2735 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2736 "inputToInputWeights", "inputToForgetWeights");
2737 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2738 "inputToInputWeights", "inputToCellWeights");
2739 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2740 "inputToInputWeights", "inputToOutputWeights");
2741
2742 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2743 "inputToInputWeights", "recurrentToInputWeights");
2744 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2745 "inputToInputWeights", "recurrentToForgeteights");
2746 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2747 "inputToInputWeights", "recurrentToCellWeights");
2748 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2749 "inputToInputWeights", "recurrentToOutputWeights");
2750
2751 // Validate matching quantization info for weight tensors (all should match each other)
2752 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2753 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2754 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2755 descriptorName, "inputToInputWeights", "inputToCellWeights");
2756 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2757 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2758
2759 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2760 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2761 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2762 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2763 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2764 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2765 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2766 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2767
2768 // Validate number of dimensions and number of elements in bias tensors
2769 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2770 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2771 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2772
2773 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2774 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2775 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2776
2777 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2778 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2779 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2780
2781 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2782 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2783 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2784
2785 // Validate data types for bias tensors (all should match each other)
2786 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2787
2788 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2789 "inputGateBias", "forgetGateBias");
2790 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2791 "inputGateBias", "cellBias");
2792 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2793 "inputGateBias", "outputGateBias");
2794
2795 // Validate bias tensor quantization info
2796 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2797 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2798 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2799 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2800}
2801
Kevin May868eb142019-09-04 17:29:31 +01002802void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2803{
2804 const std::string descriptorName{"AbsQueueDescriptor"};
2805
2806 ValidateNumInputs(workloadInfo, descriptorName, 1);
2807 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2808
2809 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2810 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2811
2812 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2813
2814 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002815 {
2816 DataType::Float16,
2817 DataType::Float32,
2818 DataType::QuantisedAsymm8,
2819 DataType::QuantisedSymm16
2820 };
Kevin May868eb142019-09-04 17:29:31 +01002821
2822 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2823 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2824}
2825
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002826void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2827{
2828 const std::string descriptorName{"SliceQueueDescriptor"};
2829
2830 ValidateNumInputs(workloadInfo, descriptorName, 1);
2831 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2832
2833 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2834 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2835
2836 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2837
2838 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2839 if (rank > 4)
2840 {
2841 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2842 }
2843
2844 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2845
2846 // Check if m_Begin and m_Size have the expected length
2847 if (m_Parameters.m_Begin.size() != rank)
2848 {
2849 throw InvalidArgumentException(descriptorName +
2850 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2851 }
2852 if (m_Parameters.m_Size.size() != rank)
2853 {
2854 throw InvalidArgumentException(descriptorName +
2855 ": Length of size descriptor must equal rank " + std::to_string(rank));
2856 }
2857
2858 // Check if the shape of the output tensor matches m_Size
2859 const TensorShape& outputShape = outputTensorInfo.GetShape();
2860 for (unsigned int i = 0u; i < rank; ++i)
2861 {
2862 if (m_Parameters.m_Size[i] != outputShape[i])
2863 {
2864 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2865 }
2866 }
2867
2868 // Check if the sum of begin offset and size in a given dimension
2869 // does not exceed the size of corresponding input
2870 const TensorShape& inputShape = inputTensorInfo.GetShape();
2871 for(unsigned int i = 0u; i < rank; ++i)
2872 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002873 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002874 {
2875 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2876 std::to_string(i) + " exceeds input size.");
2877 }
2878 }
2879}
2880
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002881void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2882{
2883 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2884
2885 ValidateNumInputs(workloadInfo, descriptorName, 1);
2886 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2887
2888 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2889 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2890
2891 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2892 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2893
2894 std::vector<DataType> supportedTypes =
2895 {
2896 DataType::Float32,
2897 DataType::Float16,
2898 DataType::QuantisedAsymm8,
2899 DataType::QuantisedSymm16
2900 };
2901
2902 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2903 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2904
2905 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2906
2907 if (m_Parameters.m_BlockSize == 0)
2908 {
2909 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2910 }
2911
2912 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2913 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2914 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2915 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2916
2917 const TensorShape& outputShape = outputInfo.GetShape();
2918 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
2919 {
2920 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
2921 "must be divisible by block size.");
2922 }
2923
2924 const TensorShape& inputShape = inputInfo.GetShape();
2925 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
2926 {
2927 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
2928 "must be divisible by the square of block size." );
2929 }
2930}
2931
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01002932void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2933{
2934 const std::string descriptorName{"ComparisonQueueDescriptor"};
2935
2936 ValidateNumInputs(workloadInfo, descriptorName, 2);
2937 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2938
2939 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2940 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2941 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2942
2943 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2944 inputTensorInfo1,
2945 outputTensorInfo,
2946 descriptorName,
2947 "input_0",
2948 "input_1");
2949
2950 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2951 {
2952 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2953 }
2954}
2955
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002956} // namespace armnn