blob: 443dc8eae31dcc539081899e7f6a8ee9012e18fc [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
James Conroyc8724c72019-10-08 15:41:34 +010018#include <TensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
34 case DataType::QuantisedAsymm8:
35 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010036 case DataType::QuantisedSymm16:
37 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000038 default:
39 BOOST_ASSERT_MSG(false, "Invalid input data type");
40 return DataType::Float32;
41 }
42}
43
44namespace
45{
46
47//---------------------------------------------------------------
// The Android NDK does not provide std::to_string, so values are converted via a string stream instead.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream converter;
    converter << value;
    std::string text{converter.str()};
    return text;
}
56
57//---------------------------------------------------------------
58void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
59{
60 if (!ptr)
61 {
62 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
63 paramName + " parameter must be set.");
64 }
65}
66
67//---------------------------------------------------------------
68void ValidateTensorShapesMatch(const TensorInfo& first,
69 const TensorInfo& second,
70 std::string const& descName,
71 std::string const& firstName,
72 std::string const& secondName)
73{
74 if (first.GetShape() != second.GetShape())
75 {
76 throw InvalidArgumentException(descName + ": "
77 + firstName + " & " + secondName + " must have identical shapes");
78 }
79}
80
81//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010082void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000083{
Sadik Armaganeff363d2019-04-05 15:25:46 +010084 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000085 {
86 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010087 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000088 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
89 }
90}
91
92//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010093void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000094{
Sadik Armaganeff363d2019-04-05 15:25:46 +010095 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000096 {
97 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010098 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000099 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
100 }
101}
102
103//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100104void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000105 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100106 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000107 std::string const& tensorName)
108{
109 if (tensor.GetNumDimensions() != numDimensions)
110 {
111 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
112 to_string(tensor.GetNumDimensions()) + " dimensions for " +
113 tensorName + " tensor.");
114 }
115}
116
117//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100118void ValidateTensorNumElements(const TensorInfo& tensor,
119 std::string const& descName,
120 unsigned int numElements,
121 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100122{
123 if (tensor.GetNumElements() != numElements)
124 {
125 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100126 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127 tensorName + " tensor.");
128 }
129}
130
131//---------------------------------------------------------------
132void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100133 unsigned int numDimension,
134 unsigned int numElements,
135 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100136{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 const std::string functionName{"ValidateTensorNumDimNumElem"};
138 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
139 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140}
141
142//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000143void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
144 const std::string& descName, std::string const& tensorName)
145{
146 if (tensor.GetDataType() != dataType)
147 {
148 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
149 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
150 }
151}
152
153//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100154void ValidateTensorQuantizationSpace(const TensorInfo& first,
155 const TensorInfo& second,
156 const std::string& descName,
157 std::string const& firstName,
158 std::string const& secondName)
159{
160 if (!first.IsQuantized() ||
161 !second.IsQuantized())
162 {
163 // Not a quantized type, ignore the validation
164 return;
165 }
166
167 DataType firstDataType = first.GetDataType();
168 DataType secondDataType = second.GetDataType();
169
170 if (firstDataType != secondDataType)
171 {
172 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
173 " must be of the same quantized type, " +
174 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
175 secondName + " is " + GetDataTypeName(secondDataType));
176 }
177
178 if (!first.IsTypeSpaceMatch(second))
179 {
180 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
181 " must have the same quantization space, " +
182 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
183 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
184 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
185 " and scale " + to_string(second.GetQuantizationScale()));
186 }
187}
188
189//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100190void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
191 const TensorInfo& inputTensorInfo,
192 const TensorInfo& weightsTensorInfo,
193 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000194{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000195 // Helper lambda function to validate a single bias quantization scale value
196 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
197 {
ricbur013f4d7102019-10-31 16:22:18 +0000198 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000199 if (std::abs(biasScale - expectedScale) > tolerance)
200 {
201 // Print the float values with extra precision to see very small differences
202 std::stringstream msg;
203 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
204 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
205 biasScale;
206 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
207 }
208 };
209
telsoa014fcda012018-03-09 14:13:49 +0000210 if (biasTensor.GetQuantizationOffset() != 0)
211 {
212 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
213 to_string(biasTensor.GetQuantizationOffset()));
214 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000215
216 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000217 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000218 // Validate per-axis quantization scales
219 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
220 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
221
222 if (weightScales.size() != biasScales.size())
223 {
224 std::stringstream msg;
225 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
226 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
227 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
228 }
229
230 for (size_t i = 0ul; i < biasScales.size(); ++i)
231 {
232 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
233 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
234 }
235 }
236 else
237 {
238 // Validate per-tensor quantization scale
239 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
240 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000241 }
242}
243
244//---------------------------------------------------------------
245void ValidateTensors(const std::vector<ITensorHandle*>& vec,
246 unsigned int numExpected,
247 const std::string& descName,
248 const std::string& varName)
249{
250 if (vec.empty() && numExpected > 0)
251 {
252 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
253 }
254
255 for (unsigned int i = 0; i < numExpected; ++i)
256 {
257 if (!vec[i])
258 {
259 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
260 }
261 }
262}
263
264//---------------------------------------------------------------
265void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
266 const TensorInfo& second,
267 const TensorInfo& output,
268 std::string const& descName,
269 std::string const& firstName,
270 std::string const& secondName)
271{
272 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
273 // broadcasted.
274 if (first.GetNumDimensions() != second.GetNumDimensions())
275 {
276 throw InvalidArgumentException(descName + ": Tensors "
277 + firstName + " & " + secondName
278 + " must have the same number of dimensions in order to be broadcasted");
279 }
280 uint32_t numDims = first.GetNumDimensions();
281 std::vector<uint32_t> outputDims(numDims, 0u);
282 for (uint32_t i = 0; i < numDims; i++)
283 {
284 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
285 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
286 if (dimsNotEqual && dimsNotOne)
287 {
288 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
289 }
290 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
291 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100292 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000293 if (broadcastShape != output.GetShape())
294 {
295 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
296 + firstName + " & " + secondName
297 + " does not match the output shape");
298 }
299}
300
301//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100302void ValidateDataTypes(const TensorInfo& info,
303 const std::vector<armnn::DataType>& supportedTypes,
304 std::string const& descName)
305{
306 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
307 if (iterator == supportedTypes.end())
308 {
309 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
310 }
311}
312
James Conroy4d1ff582019-06-10 17:06:39 +0100313//---------------------------------------------------------------
314void ValidateTensorDataTypesMatch(const TensorInfo& first,
315 const TensorInfo& second,
316 std::string const& descName,
317 std::string const& firstName,
318 std::string const& secondName)
319{
320 if (first.GetDataType() != second.GetDataType())
321 {
322 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
323 " must have identical data types.");
324 }
325}
326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100327//---------------------------------------------------------------
328void ValidateTensorNumElementsMatch(const TensorInfo& first,
329 const TensorInfo& second,
330 std::string const& descName,
331 std::string const& firstName,
332 std::string const& secondName)
333{
334 if (first.GetNumElements() != second.GetNumElements())
335 {
336 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
337 " must have the same number of elements.");
338 }
339}
340
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000341void ValidateWeightDataType(const TensorInfo& inputInfo,
342 const TensorInfo& weightInfo,
343 const std::string& descName)
344{
345 const DataType inputType = inputInfo.GetDataType();
346 if (inputType == DataType::QuantisedAsymm8)
347 {
348 const std::vector<DataType> validTypes =
349 {
350 DataType::QuantisedAsymm8,
351 DataType::QuantizedSymm8PerAxis
352 };
353
354 ValidateDataTypes(weightInfo, validTypes, descName);
355 }
356 else
357 {
358 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
359 }
360}
361
362void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
363 const std::string& descName,
364 const std::string& tensorName)
365{
366 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
367 if (!quantizationDim.has_value())
368 {
369 throw InvalidArgumentException(boost::str(
370 boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
371 % descName % tensorName));
372 }
373
374 if (quantizationDim.value() != 0)
375 {
376 throw InvalidArgumentException(boost::str(
377 boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
378 "but got: %3%") % descName % tensorName % quantizationDim.value()));
379 }
380}
381
382void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
383 const std::string& descName,
384 const std::string& tensorName)
385{
386 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
387 if (quantizationOffset != 0)
388 {
389 throw InvalidArgumentException(boost::str(
390 boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
391 "but got: %3%") % descName % tensorName % quantizationOffset));
392 }
393}
394
395void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
396 const TensorInfo& outputInfo,
397 const TensorInfo& weightInfo,
398 const Optional<TensorInfo>& optionalBiasInfo,
399 const std::string& descName)
400{
401 if (weightInfo.HasPerAxisQuantization())
402 {
403 const DataType inputDataType = inputInfo.GetDataType();
404 const DataType outputDataType = outputInfo.GetDataType();
405
406 const bool canHavePerAxisQuantization =
407 inputDataType == DataType::QuantisedAsymm8 && inputDataType == outputDataType;
408
409 if (!canHavePerAxisQuantization)
410 {
411 throw InvalidArgumentException(boost::str(
412 boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
413 "but data type does not support per-axis quantization.") % descName % "weight"));
414 }
415
416 ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight");
417 ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
418 ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
419
420 if (optionalBiasInfo.has_value())
421 {
422 const TensorInfo& biasInfo = optionalBiasInfo.value();
423 if (!biasInfo.HasPerAxisQuantization())
424 {
425 throw InvalidArgumentException(boost::str(
426 boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
427 "weight tensor.") % descName));
428 }
429
430 ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
431 ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
432 ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
433 }
434 }
435}
436
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100437} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000438
439void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
440 unsigned int numExpectedIn, unsigned int numExpectedOut) const
441{
442 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
443 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
444}
445
446//---------------------------------------------------------------
447void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
448{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100449 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000450
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100451 ValidateNumInputs(workloadInfo, descriptorName, 1);
452 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000453
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100454 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
455 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
456
457 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
458 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000459
460 if (m_Inputs.size() != m_Outputs.size())
461 {
462 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100463 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
464 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000465 }
466
467 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
468 {
469 if (!m_Inputs[i])
470 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100471 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
472 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000473 }
474
475 if (!m_Outputs[i])
476 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100477 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
478 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000479 }
480 }
481}
482
Derek Lambertif674aa02019-08-01 15:56:25 +0100483//---------------------------------------------------------------
484void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
485{
486 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
487 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
488
489 if (workloadInfo.m_InputTensorInfos.size() != 1)
490 {
491 throw InvalidArgumentException(boost::str(
492 boost::format("Number of input infos (%1%) is not 1.")
493 % workloadInfo.m_InputTensorInfos.size()));
494
495 }
496
497 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
498 {
499 throw InvalidArgumentException(boost::str(
500 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
501 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
502 }
503
504 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
505 {
506 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
507 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
508 {
509 throw InvalidArgumentException(boost::str(
510 boost::format("Number of elements for tensor input and output %1% does not match")
511 % i ));
512 }
513 }
514
515 if (m_Inputs.size() != 1)
516 {
517 throw InvalidArgumentException(boost::str(
518 boost::format("Number of inputs (%1%) is not 1.")
519 % m_Inputs.size()));
520 }
521
522 if (m_Inputs.size() != m_Outputs.size())
523 {
524 throw InvalidArgumentException(boost::str(
525 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
526 % m_Inputs.size() % m_Outputs.size()));
527 }
528
529 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
530 {
531 if (!m_Inputs[i])
532 {
533 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
534 }
535
536 if (!m_Outputs[i])
537 {
538 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
539 }
540 }
541}
542
543//---------------------------------------------------------------
544void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
545{
546 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
547 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
548
Derek Lambertif674aa02019-08-01 15:56:25 +0100549 if (m_Inputs.size() != 1)
550 {
551 throw InvalidArgumentException(boost::str(
552 boost::format("Number of inputs (%1%) is not 1.")
553 % m_Inputs.size()));
554 }
555
556 if (m_Outputs.size() != 0)
557 {
558 throw InvalidArgumentException(boost::str(
559 boost::format("Number of outputs (%1%) is not 0.")
560 % m_Inputs.size() % m_Outputs.size()));
561 }
562
563 if (!m_Inputs[0])
564 {
565 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
566 }
567}
568
569//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000570void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
571{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100572 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100573
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100574 ValidateNumInputs(workloadInfo, descriptorName, 1);
575 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100577 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
578 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100579
580 std::vector<DataType> supportedTypes =
581 {
James Conroyd47a0642019-09-17 14:22:06 +0100582 DataType::Float16,
583 DataType::Float32,
584 DataType::QuantisedAsymm8,
585 DataType::QuantisedSymm16
nikraj01248683f2019-05-29 16:46:50 +0100586 };
587
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100588 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
589 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
590 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000591}
592
Nikhil Rajee391d52019-09-05 17:50:44 +0100593void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
594{
595 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
596
597 ValidateNumInputs(workloadInfo, descriptorName, 1);
598 ValidateNumOutputs(workloadInfo, descriptorName, 1);
599
600 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
601 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
602
Nikhil Raj68c2c902019-09-19 11:21:11 +0100603 if (outputTensorInfo.GetDataType() != DataType::Signed32)
604 {
605 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
606 }
607
James Conroyd47a0642019-09-17 14:22:06 +0100608 std::vector<DataType> supportedInputTypes =
609 {
610 DataType::Float16,
611 DataType::Float32,
612 DataType::QuantisedAsymm8,
Francis Murtagh1939df52019-11-13 15:21:09 +0000613 DataType::QuantisedSymm16,
614 DataType::Signed32
James Conroyd47a0642019-09-17 14:22:06 +0100615 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100616
James Conroyd47a0642019-09-17 14:22:06 +0100617 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100618
619 auto inputShape = inputTensorInfo.GetShape();
620 auto outputShape = outputTensorInfo.GetShape();
621
622 auto inputNumDimensions = inputShape.GetNumDimensions();
623 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
624
625 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
626
627 // 1D input shape results in scalar output shape
628 if (inputShape.GetNumDimensions() == 1)
629 {
630 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
631 {
632 throw InvalidArgumentException(descriptorName + outputShapeError);
633 }
634 }
635 else
636 {
637 for (unsigned int i = 0; i < unsignedAxis; ++i)
638 {
639 if (outputShape[i] != inputShape[i])
640 {
641 throw InvalidArgumentException(descriptorName + outputShapeError);
642 }
643 }
644
645 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
646 {
647 if (outputShape[i - 1] != inputShape[i])
648 {
649 throw InvalidArgumentException(descriptorName + outputShapeError);
650 }
651 }
652 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100653}
654
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100655void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
656{
657 const std::string descriptorName{"SoftmaxQueueDescriptor"};
658
659 ValidateNumInputs(workloadInfo, descriptorName, 1);
660 ValidateNumOutputs(workloadInfo, descriptorName, 1);
661
662 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
663 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
664
665 std::vector<DataType> supportedTypes =
666 {
James Conroyd47a0642019-09-17 14:22:06 +0100667 DataType::Float16,
668 DataType::Float32,
669 DataType::QuantisedAsymm8,
670 DataType::QuantisedSymm16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100671 };
672
673 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
674 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
675 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
676}
677
telsoa014fcda012018-03-09 14:13:49 +0000678void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
679{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100680 const std::string descriptorName{"SplitterQueueDescriptor"};
681
682 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000683
Ruomei Yan25339c32019-05-28 16:48:20 +0100684 // Check the supported data types
685 std::vector<DataType> supportedTypes =
686 {
James Conroyd47a0642019-09-17 14:22:06 +0100687 DataType::Float32,
688 DataType::Float16,
689 DataType::Boolean,
690 DataType::Signed32,
691 DataType::QuantisedAsymm8,
692 DataType::QuantisedSymm16
Ruomei Yan25339c32019-05-28 16:48:20 +0100693 };
694
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100695 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
696 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100697 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100698 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
699 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
700
701 const std::string outputName = "output_" + std::to_string(i);
702 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100703 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100704
telsoa014fcda012018-03-09 14:13:49 +0000705 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
706 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100707 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000708 }
709
710 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
711 {
712 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100713 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000714 "has to match number of workloadInfo.m_OutputTensorInfos. "
715 "Number of windows: " +
716 to_string(m_ViewOrigins.size()) +
717 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
718 }
719
telsoa01c577f2c2018-08-31 09:22:23 +0100720 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000721 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
722 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
723 {
telsoa01c577f2c2018-08-31 09:22:23 +0100724 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000725 ViewOrigin const& e = m_ViewOrigins[w];
726 if (e.m_Origin.size() != inputDims)
727 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100728 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000729 "have the same dimensionality as the input tensor. "
730 "Window origin (index: " +
731 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
732 " dimensions, the input "
733 "tensor has " +
734 to_string(inputDims) + " dimensions.");
735 }
736 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
737 {
738 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
739 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
740 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100741 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000742 "be smaller or equal than the size of the input in that coord.");
743 }
744 }
745 }
746}
747
Jim Flynne242f2d2019-05-22 14:24:13 +0100748void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000749{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100750 const std::string descriptorName{"ConcatQueueDescriptor"};
751
752 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000753
754 if (m_Inputs.size() <= 0)
755 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100756 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000757 }
758 if (m_Outputs.size() <= 0)
759 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100760 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000761 }
762
763 if (workloadInfo.m_InputTensorInfos.size() <= 0)
764 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100765 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000766 }
767 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100769 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000770 }
771
Nikhil Raj8599a412018-11-19 14:51:07 +0000772 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
773 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100774 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000775 }
776
777 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
778 {
779 return;
780 }
781
telsoa014fcda012018-03-09 14:13:49 +0000782 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
783 {
784 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000786 "has to match number of workloadInfo.m_InputTensorInfos. "
787 "Number of windows: " +
788 to_string(m_ViewOrigins.size()) +
789 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
790 }
791
telsoa01c577f2c2018-08-31 09:22:23 +0100792 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000793 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
794 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
795 {
telsoa01c577f2c2018-08-31 09:22:23 +0100796 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000797 ViewOrigin const& e = m_ViewOrigins[w];
798 if (e.m_Origin.size() != outputDims)
799 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100800 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000801 "have the same dimensionality as the output tensor. "
802 "Window origin (index: " +
803 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
804 " dimensions, the output "
805 "tensor has " +
806 to_string(outputDims) + " dimensions.");
807 }
telsoa01c577f2c2018-08-31 09:22:23 +0100808 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000809 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
810 {
811 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
812 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
813 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100814 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000815 "be smaller or equal than the size of the output in that coord.");
816 }
817 }
818 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100819
820 // Check the supported data types
821 std::vector<DataType> supportedTypes =
822 {
James Conroyd47a0642019-09-17 14:22:06 +0100823 DataType::Float32,
824 DataType::Float16,
825 DataType::Boolean,
826 DataType::Signed32,
827 DataType::QuantisedAsymm8,
828 DataType::QuantisedSymm16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100829 };
830
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100831 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
832 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100833 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100834 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
835 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
836
837 const std::string inputName = "input_" + std::to_string(i);
838 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100839 }
telsoa014fcda012018-03-09 14:13:49 +0000840}
841
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100842void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
843{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100844 const std::string descriptorName{"StackQueueDescriptor"};
845
846 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100847
848 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
849 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100850 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100851 }
852
853 // All inputs must have the same shape, which is defined in parameters
854 const TensorShape& inputShape = m_Parameters.m_InputShape;
855 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
856 {
857 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
858 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100859 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100860 }
861 }
862
Matthew Jacksondba634f2019-08-15 15:14:18 +0100863 if (inputShape.GetNumDimensions() > 4)
864 {
865 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
866 }
867
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100868 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
869 // since the output tensor has an additional dimension.
870 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
871 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100872 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100873 "than the number of input dimensions.");
874 }
875
876 // Output shape must be as inferred from the input shape
877 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
878 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
879 {
880 if (outputShape[i] != inputShape[i])
881 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100882 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100883 "match shape inferred from input tensor.");
884 }
885 }
886
887 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
888 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100889 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100890 "match shape inferred from input tensor.");
891 }
892
893 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
894 {
895 if (outputShape[i] != inputShape[i-1])
896 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100897 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100898 "match shape inferred from input tensor.");
899 }
900 }
901
Matthew Jacksondba634f2019-08-15 15:14:18 +0100902 if (outputShape.GetNumDimensions() > 5)
903 {
904 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
905 }
906
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100907 // Check the supported data types
908 std::vector<DataType> supportedTypes =
909 {
James Conroyd47a0642019-09-17 14:22:06 +0100910 DataType::Float32,
911 DataType::Float16,
912 DataType::Boolean,
913 DataType::Signed32,
914 DataType::QuantisedAsymm8,
915 DataType::QuantisedSymm16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100916 };
917
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100919
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100920 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100921 {
922 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
923 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100924 descriptorName,
925 "input_0",
926 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100927 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100928
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100929 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
930 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100931 descriptorName,
932 "input_0",
933 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100934}
935
telsoa014fcda012018-03-09 14:13:49 +0000936void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
937{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100938 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100940 ValidateNumInputs(workloadInfo, descriptorName, 1);
941 ValidateNumOutputs(workloadInfo, descriptorName, 1);
942
943 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
944 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
945
946 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
947
948 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000949 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000951 }
952
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000954
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100955 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
956 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000957
958 if (m_Parameters.m_BiasEnabled)
959 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100960 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000961
telsoa01c577f2c2018-08-31 09:22:23 +0100962 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100963 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
964 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100966 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
967 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000968 }
969
Francis Murtagh46c09d02019-05-28 08:15:28 +0100970 // Check the supported data types
971 std::vector<DataType> supportedTypes =
972 {
James Conroyd47a0642019-09-17 14:22:06 +0100973 DataType::Float32,
974 DataType::Float16,
975 DataType::QuantisedAsymm8,
976 DataType::QuantisedSymm16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100977 };
978
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100979 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
980 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000981}
982
telsoa014fcda012018-03-09 14:13:49 +0000983void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
984{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100985 const std::string descriptorName{"NormalizationQueueDescriptor"};
986
987 ValidateNumInputs(workloadInfo, descriptorName, 1);
988 ValidateNumOutputs(workloadInfo, descriptorName, 1);
989
990 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
991 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100992
993 // Check the supported data types
994 std::vector<DataType> supportedTypes =
995 {
996 DataType::Float16,
997 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100998 DataType::QuantisedAsymm8,
999 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001000 };
1001
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001002 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001003
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001004 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001005
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001006 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001007}
1008
1009void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1010{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001011 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001012
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001013 ValidateNumInputs(workloadInfo, descriptorName, 2);
1014 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1015
1016 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1017 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1018 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1019
1020 std::vector<DataType> supportedTypes =
1021 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001022 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001023 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001024 DataType::QuantisedSymm16,
1025 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001026 };
1027
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001028 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1029 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1030 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001031
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001032 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1033 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001034
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001035 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1036 inputTensorInfo1,
1037 outputTensorInfo,
1038 descriptorName,
1039 "input_0",
1040 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001041}
1042
telsoa014fcda012018-03-09 14:13:49 +00001043void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1044{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001045 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +01001046
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001047 ValidateNumInputs(workloadInfo, descriptorName, 2);
1048 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1049
1050 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1051 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1052 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1053
1054 std::vector<DataType> supportedTypes =
1055 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001056 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001057 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001058 DataType::QuantisedSymm16,
1059 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001060 };
1061
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001062 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1063 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1064 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001065
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1067 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001068
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001069 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1070 inputTensorInfo1,
1071 outputTensorInfo,
1072 descriptorName,
1073 "input_0",
1074 "input_1");
telsoa014fcda012018-03-09 14:13:49 +00001075}
1076
1077void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1078{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001080
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001081 ValidateNumInputs(workloadInfo, descriptorName, 1);
1082 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1083
1084 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1085 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001086
1087 std::vector<DataType> supportedTypes =
1088 {
1089 DataType::Float16,
1090 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +01001091 DataType::QuantisedAsymm8,
1092 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001093 };
1094
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001095 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1096 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001097
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001098 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1099 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1100 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001101
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001102 ValidatePointer(m_Mean, descriptorName, "mean");
1103 ValidatePointer(m_Variance, descriptorName, "variance");
1104 ValidatePointer(m_Beta, descriptorName, "beta");
1105 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001106
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001107 const TensorInfo& mean = m_Mean->GetTensorInfo();
1108 const TensorInfo& variance = m_Variance->GetTensorInfo();
1109 const TensorInfo& beta = m_Beta->GetTensorInfo();
1110 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001111
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001112 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1113 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1114 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1115 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001116
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001117 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1118 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1119 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001120}
1121
1122void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1123{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001124 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001125
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001126 ValidateNumInputs(workloadInfo, descriptorName, 1);
1127 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001128
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001129 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1130 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001131
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001132 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1133 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001134
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001135 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001136
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001137 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1138 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001139
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001140 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001141
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001142 Optional<TensorInfo> optionalBiasTensorInfo;
telsoa014fcda012018-03-09 14:13:49 +00001143 if (m_Parameters.m_BiasEnabled)
1144 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001145 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001146
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001147 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1148 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001149
1150 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1151 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001152 }
1153
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001154 ValidatePerAxisQuantization(inputTensorInfo,
1155 outputTensorInfo,
1156 weightTensorInfo,
1157 optionalBiasTensorInfo,
1158 descriptorName);
1159
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001160 std::vector<DataType> supportedTypes =
1161 {
Ruomei Yan88d44b82019-05-23 14:29:06 +01001162 DataType::Float32,
1163 DataType::QuantisedAsymm8,
1164 DataType::QuantisedSymm16,
1165 DataType::Float16
1166 };
1167
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001168 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1169 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1170}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1173{
1174 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1175
1176 ValidateNumInputs(workloadInfo, descriptorName, 1);
1177 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1178
1179 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1180 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1181
1182 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1183 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1184
1185 ValidatePointer(m_Weight, descriptorName, "weight");
1186
1187 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1188 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1189
1190 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1191 {
1192 throw InvalidArgumentException(
1193 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1194 "cannot be smaller than 1.") % descriptorName %
1195 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1196 }
1197
1198 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1199
1200 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1201 // inputChannels * channelMultiplier should be equal to outputChannels.
1202 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1203 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1204 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1205 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1206 {
1207 throw InvalidArgumentException(
1208 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1209 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1210 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1211 numWeightInputChannels % numWeightChannelMultiplier));
1212 }
1213
Teresa Charlind8df0262019-11-11 12:28:15 +00001214 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001215
Teresa Charlind8df0262019-11-11 12:28:15 +00001216 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001217 if (m_Parameters.m_BiasEnabled)
1218 {
1219 ValidatePointer(m_Bias, descriptorName, "bias");
1220
Teresa Charlind8df0262019-11-11 12:28:15 +00001221 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
1222 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001223
1224 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1225 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1226 }
Teresa Charlind8df0262019-11-11 12:28:15 +00001227 ValidatePerAxisQuantization(inputTensorInfo,
1228 outputTensorInfo,
1229 weightTensorInfo,
1230 optionalBiasTensorInfo,
1231 descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001232
1233 std::vector<DataType> supportedTypes =
1234 {
1235 DataType::Float32,
1236 DataType::QuantisedAsymm8,
1237 DataType::QuantisedSymm16,
1238 DataType::Float16
1239 };
1240
1241 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1242 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001243}
1244
1245void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1246{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247 const std::string descriptorName{"PermuteQueueDescriptor"};
1248
1249 ValidateNumInputs(workloadInfo, descriptorName, 1);
1250 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001251
1252 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1253
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001254 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1255 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001256
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1258 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001261 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001262 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001263 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001264 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1265 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1266 "must match dst dimension " + to_string(mapping[i]) +
1267 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001268 }
1269 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001270
1271 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001272}
1273
1274void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1275{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001276 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001277
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001278 ValidateNumInputs(workloadInfo, descriptorName, 1);
1279 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1280
1281 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1282 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1283
1284 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1285 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001286
1287 std::vector<DataType> supportedTypes =
1288 {
1289 DataType::Float32,
1290 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001291 DataType::QuantisedAsymm8,
1292 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001293 };
1294
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001295 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1296 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001297}
1298
1299void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1300{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001301 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001303 ValidateNumInputs(workloadInfo, descriptorName, 1);
1304 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1305
1306 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1307 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1308
1309 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1310 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001311
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001312 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001313 {
1314 DataType::Float16,
1315 DataType::Float32,
1316 DataType::QuantisedAsymm8,
1317 DataType::QuantisedSymm16
1318 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001320 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1321 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001322
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001323 // ResizeBilinear only changes width and height: batch and channel count must match.
1324 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1325 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001326 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001327 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001328 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001329 boost::str(boost::format("%1%: Input batch size (%2%) "
1330 "does not match output batch size (%3%)") %
1331 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001332 }
1333
Teresa Charlin970f43b2019-07-01 13:51:07 +01001334 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001335 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1336 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001337 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001338 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001339 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001340 boost::str(boost::format("%1%: Input channel count (%2%) "
1341 "does not match output channel count (%3%)") %
1342 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001343 }
1344}
1345
1346void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1347{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001348 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001349
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001350 ValidateNumInputs(workloadInfo, descriptorName, 1);
1351 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1352
1353 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1354 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1355
1356 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1357 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001358
1359 std::vector<DataType> supportedTypes =
1360 {
1361 DataType::Float16,
1362 DataType::Float32,
1363 DataType::QuantisedAsymm8,
1364 DataType::QuantisedSymm16
1365 };
1366
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001367 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1368 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001369
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001370 // Resize only changes width and height: batch and channel count must match.
1371 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1372 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001373 if (inputBatchSize != outputBatchSize)
1374 {
1375 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001376 boost::str(boost::format("%1%: Input batch size (%2%) "
1377 "does not match output batch size (%3%)") %
1378 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001379 }
1380
1381 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001382 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1383 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001384 if (inputChannelCount != outputChannelCount)
1385 {
1386 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001387 boost::str(boost::format("%1%: Input channel count (%2%) "
1388 "does not match output channel count (%3%)") %
1389 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001390 }
1391}
1392
1393void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1394{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001395 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001396
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001397 ValidateNumInputs(workloadInfo, descriptorName, 1);
1398 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1399
1400 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1401 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1402
1403 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1404 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1405
1406 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1407
telsoa014fcda012018-03-09 14:13:49 +00001408 if (m_Parameters.m_Min > m_Parameters.m_Max)
1409 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001410 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001411 }
telsoa014fcda012018-03-09 14:13:49 +00001412}
1413
Kevin Mayce5045a2019-10-02 14:07:47 +01001414void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1415{
1416 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1417
1418 ValidateNumInputs(workloadInfo, descriptorName, 1);
1419 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1420
1421 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1422 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1423
1424 if (inputTensorInfo.GetNumDimensions() > 4)
1425 {
1426 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1427 }
1428
1429 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1430
1431 // Check the supported data types
1432 std::vector<DataType> supportedTypes =
1433 {
1434 DataType::Float32,
1435 DataType::Float16
1436 };
1437
1438 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001439 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001440}
1441
telsoa014fcda012018-03-09 14:13:49 +00001442void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1443{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001444 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001445
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001446 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001447 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1448
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001449 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1450 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1451
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001452 if (inputTensorInfo.GetNumDimensions() > 4)
1453 {
1454 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1455 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001456
1457 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001458
1459 // Check the supported data types
1460 std::vector<DataType> supportedTypes =
1461 {
1462 DataType::Float32,
1463 DataType::Float16,
1464 DataType::QuantisedAsymm8,
1465 DataType::QuantisedSymm16
1466 };
1467
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001468 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001469 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1470}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001471
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001472void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1473{
1474 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1475
1476 ValidateNumInputs(workloadInfo, descriptorName, 1);
1477 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1478
1479 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1480 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1481
1482 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1483
1484 std::vector<DataType> supportedTypes =
1485 {
1486 DataType::Float32,
1487 DataType::Float16,
1488 };
1489
1490 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001491 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001492}
1493
1494void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1495{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 const std::string descriptorName{"ConstantQueueDescriptor"};
1497
1498 ValidateNumInputs(workloadInfo, descriptorName, 0);
1499 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001500
1501 if (!m_LayerOutput)
1502 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001504 }
1505
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1507 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001508
1509 // Check the supported data types
1510 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001511 {
1512 DataType::Float32,
1513 DataType::Float16,
1514 DataType::Signed32,
1515 DataType::QuantisedAsymm8,
1516 DataType::QuantisedSymm16
1517 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001518
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001519 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001520}
1521
1522void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1523{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001524 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001525
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001526 ValidateNumInputs(workloadInfo, descriptorName, 1);
1527 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1528
1529 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1530 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1531
1532 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001533
1534 // Check the supported data types
1535 std::vector<DataType> supportedTypes =
1536 {
1537 DataType::Float32,
1538 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001539 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001540 DataType::QuantisedAsymm8,
1541 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001542 };
1543
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001544 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1545 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001546}
1547
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001548void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1549{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001550 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001551
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001552 ValidateNumInputs(workloadInfo, descriptorName, 1);
1553 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1554
1555 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1556 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1557
1558 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1559 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001560
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001561 if (m_Parameters.m_BlockShape.size() != 2)
1562 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001563 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001564 }
1565
1566 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1567 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001568 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1569 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001570 }
1571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001572 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001573
1574 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001575 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001576
Matthew Bentham8800c002018-11-19 13:19:28 +00001577 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001578
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001579 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1580 widthPad.first + widthPad.second;
1581 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1582 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001583
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001584 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1585 inputShape[dimensionIndices.GetChannelsIndex()];
1586 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001587
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001588 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001589 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001590 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001591 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001592 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001593 }
1594
1595 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001596 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001597 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1598 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001599 }
nikraj01120522a2019-05-31 11:33:07 +01001600
1601 std::vector<DataType> supportedTypes =
1602 {
1603 DataType::Float16,
1604 DataType::Float32,
1605 DataType::QuantisedAsymm8,
1606 DataType::QuantisedSymm16
1607 };
1608
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001609 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1610 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001611}
1612
Keith Davisa57eccb2019-06-14 17:33:22 +01001613void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1614{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001615 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001616
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001617 ValidateNumInputs(workloadInfo, descriptorName, 1);
1618 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001619
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001620 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1621 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1622
1623 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1624 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001625
1626 std::vector<DataType> supportedTypes =
1627 {
1628 DataType::Float32,
1629 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001630 DataType::QuantisedAsymm8,
1631 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001632 };
1633
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001634 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1635 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001636
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001637 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1638
1639 if (m_Parameters.m_BlockSize == 0)
1640 {
1641 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1642 }
1643
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001644 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1645 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1646 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1647 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001648
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001649 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001650 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001651 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001652 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1653 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001654 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001655
1656 const TensorShape& outputShape = outputTensorInfo.GetShape();
1657 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1658 {
1659 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1660 "must be divisible by the square of block size." );
1661 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001662}
1663
telsoa014fcda012018-03-09 14:13:49 +00001664void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1665{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001666 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001667
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001668 ValidateNumInputs(workloadInfo, descriptorName, 1);
1669 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1670
1671 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1672 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001673
1674 std::vector<DataType> supportedTypes =
1675 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001676 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001677 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001678 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001679 };
1680
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001681 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001682
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001683 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001684 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001685 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001686 }
1687}
1688
telsoa01c577f2c2018-08-31 09:22:23 +01001689void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1690{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001691 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1692
1693 const std::string descriptorName{"LstmQueueDescriptor"};
1694
1695 // check dimensions of all inputs and outputs
1696 if (workloadInfo.m_InputTensorInfos.size() != 3)
1697 {
1698 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1699 }
1700 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1701 {
1702 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1703 }
1704
1705 std::vector<DataType> supportedTypes =
1706 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001707 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001708 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001709 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001710 };
1711
Jan Eilers38e05bd2019-06-26 13:10:09 +01001712 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001713 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1714
Jan Eilers38e05bd2019-06-26 13:10:09 +01001715 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001716 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001717 {
1718 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1719 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001720 descriptorName,
1721 "input_0",
1722 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001723 }
1724 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001725 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001726 {
1727 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1728 workloadInfo.m_OutputTensorInfos[i],
1729 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001730 "input_0",
1731 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001732 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001733
Jan Eilers38e05bd2019-06-26 13:10:09 +01001734 // TODO: check clipping parameter is valid
1735
1736 // Inferring batch size, number of outputs and number of cells from the inputs.
1737 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1738 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1739 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1740 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1741 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1742 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1743 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1744
Jan Eilers38e05bd2019-06-26 13:10:09 +01001745 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001746 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1747 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001748 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001749 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1750 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001751 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001752 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1753 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001754 // scratchBufferTensor
1755 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001756 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1757 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001758 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001759 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1760 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001761 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001762 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1763 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001764 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001765 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1766 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001767
1768
1769 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1770 if ( m_InputToInputWeights )
1771 {
1772 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1773 (n_cell * n_input), "InputLayerNormWeights");
1774 }
1775
1776 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1777 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1778 (n_cell * n_input), "InputToForgetWeights");
1779
1780 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1781 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1782 (n_cell * n_input), "InputToCellWeights");
1783
1784 if ( m_RecurrentToInputWeights )
1785 {
1786 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1787 (n_cell * n_output), "RecurrentToInputWeights");
1788 }
1789
1790 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1791 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1792 (n_cell * n_output), "RecurrentToForgetWeights");
1793
1794 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1795 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1796 (n_cell * n_output), "RecurrentToCellWeights");
1797
1798 // Make sure the input-gate's parameters are either both present (regular
1799 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1800 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1801 !m_Parameters.m_CifgEnabled) ||
1802 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1803 m_Parameters.m_CifgEnabled));
1804 if (!cifg_weights_all_or_none)
1805 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001806 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1807 "RecurrentToInputWeights must either both be present (regular LSTM) "
1808 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1809 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001810 }
1811
1812 if ( m_CellToInputWeights )
1813 {
1814 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1815 n_cell, "CellToInputWeights");
1816 }
1817 if ( m_CellToForgetWeights )
1818 {
1819 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1820 n_cell, "CellToForgetWeights");
1821 }
1822 if ( m_CellToOutputWeights )
1823 {
1824 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1825 n_cell, "CellToOutputWeights");
1826 }
1827
1828 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1829 bool peephole_weights_all_or_none =
1830 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1831 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1832 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1833 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1834 if (!peephole_weights_all_or_none)
1835 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001836 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001837 }
1838
1839 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1840 if (m_Parameters.m_CifgEnabled)
1841 {
1842 if (m_InputGateBias)
1843 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001845 }
1846 }
1847 else
1848 {
1849 if (!m_InputGateBias)
1850 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001851 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1852 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001853 }
1854 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1855 n_cell, "InputGateBias");
1856 }
1857
1858 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1859 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1860
1861 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1862 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1863
1864 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1865 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1866
1867 if (m_ProjectionWeights)
1868 {
1869 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1870 (n_cell * n_output), "ProjectionWeights");
1871 }
1872 if (m_ProjectionBias)
1873 {
1874 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1875 }
1876
1877 // Making sure the projection tensors are consistent:
1878 // 1) If projection weight is not present, then projection bias should not be
1879 // present.
1880 // 2) If projection weight is present, then projection bias is optional.
1881 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1882 !m_Parameters.m_ProjectionEnabled)
1883 || (m_ProjectionWeights && !m_ProjectionBias &&
1884 m_Parameters.m_ProjectionEnabled)
1885 || (m_ProjectionWeights && m_ProjectionBias &&
1886 m_Parameters.m_ProjectionEnabled));
1887 if (!projecton_tensors_consistent)
1888 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001889 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001890 }
1891
1892 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1893 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1894 // either all have values or none of them have values. Layer normalization is used when the values of all the
1895 // layer normalization weights are present
1896 if (m_InputLayerNormWeights)
1897 {
1898 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1899 }
1900 if (m_ForgetLayerNormWeights)
1901 {
1902 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1903 }
1904 if (m_CellLayerNormWeights)
1905 {
1906 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1907 }
1908 if (m_OutputLayerNormWeights)
1909 {
1910 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1911 }
1912
Jan Eilers38e05bd2019-06-26 13:10:09 +01001913 if (m_Parameters.m_LayerNormEnabled)
1914 {
1915 if (!m_Parameters.m_CifgEnabled)
1916 {
1917 if (!m_InputLayerNormWeights)
1918 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001919 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1920 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001921 }
1922 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1923 1, n_cell, "InputLayerNormWeights");
1924 }
1925 else if (m_InputLayerNormWeights)
1926 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001927 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1928 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001929 }
1930
1931 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1932 "ForgetLayerNormWeights");
1933 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1934
1935 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1936 "OutputLayerNormWeights");
1937 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1938
1939 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1940 "CellLayerNormWeights");
1941 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1942 }
1943 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1944 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001945 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1946 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001947 }
telsoa01c577f2c2018-08-31 09:22:23 +01001948}
1949
1950void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1951{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001952 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001953
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001954 ValidateNumInputs(workloadInfo, descriptorName, 1);
1955 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1956
1957 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1958 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1959
1960 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001961 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001962 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001963 }
1964
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001965 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001966 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001967 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001968 }
1969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001970 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001971}
1972
1973void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1974{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001975 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001976
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001977 ValidateNumInputs(workloadInfo, descriptorName, 1);
1978 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1979
1980 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1981 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1982
1983 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001984 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001985 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001986 }
1987
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 if (outputTensorInfo.GetDataType() != DataType::Float32)
1989 {
1990 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1991 }
1992
1993 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001994}
1995
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001996void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1997{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001998 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001999
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002000 ValidateNumInputs(workloadInfo, descriptorName, 2);
2001 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2002
2003 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2004 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2005 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2006
2007 std::vector<DataType> supportedTypes =
2008 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002009 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002010 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002011 DataType::QuantisedSymm16,
2012 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002013 };
2014
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002015 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2016 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2017 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002018
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002019 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2020 inputTensorInfo1,
2021 outputTensorInfo,
2022 descriptorName,
2023 "input_0",
2024 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01002025}
2026
David Beckc2044fe2018-09-05 15:00:38 +01002027void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2028{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002029 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01002030
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002031 ValidateNumInputs(workloadInfo, descriptorName, 2);
2032 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2033
2034 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2035 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2036 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2037
2038 std::vector<DataType> supportedTypes =
2039 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002040 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002041 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01002042 DataType::QuantisedSymm16,
2043 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002044 };
2045
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002046 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2047 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2048 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002049
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002050 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2051 inputTensorInfo1,
2052 outputTensorInfo,
2053 descriptorName,
2054 "input_0",
2055 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01002056}
2057
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002058void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2059{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002060 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002061
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002062 ValidateNumInputs(workloadInfo, descriptorName, 2);
2063 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2064
2065 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2066 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2067 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2068
2069 std::vector<DataType> supportedTypes =
2070 {
Mike Kelly1da02362019-08-01 08:43:57 +01002071 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002072 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002073 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002074 DataType::QuantisedAsymm8,
2075 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002076 };
2077
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002078 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2079 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2080 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002081
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002082 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2083 inputTensorInfo1,
2084 outputTensorInfo,
2085 descriptorName,
2086 "input_0",
2087 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00002088}
2089
narpra01a6bf9122018-09-10 09:50:09 +01002090void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2091{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002092 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01002093
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002094 ValidateNumInputs(workloadInfo, descriptorName, 1);
2095 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2096
2097 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2098 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01002099
2100 std::vector<DataType> supportedTypes =
2101 {
2102 DataType::Float32,
2103 DataType::Float16,
2104 DataType::QuantisedAsymm8,
2105 DataType::QuantisedSymm16
2106 };
narpra01eb061912018-09-10 17:35:27 +01002107
James Conroy4d1ff582019-06-10 17:06:39 +01002108 // First check if input tensor data type is supported, then
2109 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002110 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2111 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002112
narpra0132b90462018-09-13 11:07:48 +01002113 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002114 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002115 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002116 }
narpra0132b90462018-09-13 11:07:48 +01002117 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002118 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002119 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002120 }
2121 else
2122 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002123 unsigned int outputDim =
2124 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2125 ValidateTensorNumDimensions(outputTensorInfo,
2126 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002127 outputDim > 0 ? outputDim : 1,
2128 "output");
2129 }
narpra01a6bf9122018-09-10 09:50:09 +01002130}
2131
jimfly012c9322a2018-09-19 10:59:49 +01002132void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2133{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002134 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002136 ValidateNumInputs(workloadInfo, descriptorName, 1);
2137 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2138
2139 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2140 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002141
jimfly012c9322a2018-09-19 10:59:49 +01002142 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002143 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2144
jimfly012c9322a2018-09-19 10:59:49 +01002145 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2147 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2148 "as there are dimensions in the input tensor that is " +
2149 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2150 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002151 }
2152}
2153
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002154void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2155{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002156 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002157
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002158 ValidateNumInputs(workloadInfo, descriptorName, 1);
2159 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002160
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002161 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2162 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2163
Sadik Armagan2208b602019-07-31 16:36:27 +01002164 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002165 {
James Conroyd47a0642019-09-17 14:22:06 +01002166 DataType::Float32,
2167 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002168 };
2169
2170 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002172 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2173 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002174 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002175 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002176 }
2177}
2178
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002179void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2180{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002181 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002182
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002183 ValidateNumInputs(workloadInfo, descriptorName, 1);
2184 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002185
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002186 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2187 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002188
2189 std::vector<DataType> supportedTypes =
2190 {
James Conroyd47a0642019-09-17 14:22:06 +01002191 DataType::Float32,
2192 DataType::Float16,
2193 DataType::QuantisedAsymm8,
2194 DataType::QuantisedSymm16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002195 };
2196
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002197 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2198 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002199}
2200
Conor Kennedy430b5d82018-11-14 15:28:28 +00002201void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2202{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002203 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002204
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002205 ValidateNumInputs(workloadInfo, descriptorName, 1);
2206 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2207
2208 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2209 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002210
2211 std::vector<DataType> supportedTypes =
2212 {
2213 DataType::Float16,
2214 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01002215 DataType::QuantisedAsymm8,
2216 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002217 };
2218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002219 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2220 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002222 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002223
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002224 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002225 if (rank > 4)
2226 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002227 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002228 }
2229
Conor Kennedy430b5d82018-11-14 15:28:28 +00002230 // Begin, End & Stride length must be of rank(input0)
2231 if (m_Parameters.m_Begin.size() != rank)
2232 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002233 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002234 }
2235
2236 if (m_Parameters.m_End.size() != rank)
2237 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002238 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002239 }
2240
2241 if (m_Parameters.m_Stride.size() != rank)
2242 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002243 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002244 }
2245
2246 // Stride entries must be non-zero
2247 for (auto& stride : m_Parameters.m_Stride)
2248 {
2249 if (stride == 0)
2250 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002251 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002252 }
2253 }
2254}
2255
kevmay0190539692018-11-29 08:40:19 +00002256void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2257{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002258 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 ValidateNumInputs(workloadInfo, descriptorName, 2);
2261 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2262
2263 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2264 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2265 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2266
2267 std::vector<DataType> supportedTypes =
2268 {
Mike Kelly1da02362019-08-01 08:43:57 +01002269 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002270 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002271 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002272 DataType::QuantisedAsymm8,
2273 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002274 };
2275
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002276 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2277 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2278 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002279
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002280 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2281 inputTensorInfo1,
2282 outputTensorInfo,
2283 descriptorName,
2284 "input_0",
2285 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002286}
2287
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002288void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2289{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 const std::string descriptorName{"DebugQueueDescriptor"};
2291
2292 ValidateNumInputs(workloadInfo, descriptorName, 1);
2293 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002294}
2295
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002296void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2297{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002299
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002300 ValidateNumInputs(workloadInfo, descriptorName, 2);
2301 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002303 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2304 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2305 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2306
2307 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2308 inputTensorInfo1,
2309 outputTensorInfo,
2310 descriptorName,
2311 "input_0",
2312 "input_1");
2313
2314 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002315 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002316 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002317 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002318}
2319
FrancisMurtagh878f0232018-12-19 10:56:15 +00002320void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2321{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002322 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002323
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002324 ValidateNumInputs(workloadInfo, descriptorName, 2);
2325 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002327 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2328 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2329 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2330
2331 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2332 inputTensorInfo1,
2333 outputTensorInfo,
2334 descriptorName,
2335 "input_0",
2336 "input_1");
2337
2338 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002339 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002340 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002341 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002342}
2343
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002344void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2345{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 const std::string descriptorName{"RsqrtQueueDescriptor"};
2347
2348 ValidateNumInputs(workloadInfo, descriptorName, 1);
2349 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2350
2351 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2352 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2353
2354 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002355
2356 std::vector<DataType> supportedTypes =
2357 {
James Conroyd47a0642019-09-17 14:22:06 +01002358 DataType::Float16,
2359 DataType::Float32,
2360 DataType::QuantisedAsymm8,
2361 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002362 };
2363
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2365 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002366}
2367
narpra01b89b05f2019-01-16 09:53:09 +00002368void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2369{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002370 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002371
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002372 ValidateNumInputs(workloadInfo, descriptorName, 2);
2373 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002374
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002375 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2376 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002377 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002378 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002379 }
2380
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002381 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2382 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2383
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002384 std::vector<DataType> supportedTypes =
2385 {
James Conroyd47a0642019-09-17 14:22:06 +01002386 DataType::Float16,
2387 DataType::Float32,
2388 DataType::QuantisedAsymm8,
2389 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002390 };
2391
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002392 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002393
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002394 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002395
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002396 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2397 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002398}
2399
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002400void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2401{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002402 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2403
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002404 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002405
2406 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2407 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002408 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002409 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2410 }
2411
2412 if (m_Anchors == nullptr)
2413 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002414 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002415 }
2416
2417 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002418 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2419 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2420
2421 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002422 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002423 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2424 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002425
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002426 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2427 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2428 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002429
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002430 const std::vector<DataType> supportedInputTypes =
2431 {
2432 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002433 DataType::Float16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002434 DataType::QuantisedAsymm8,
2435 DataType::QuantisedSymm16
2436 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002437
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002438 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2439 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2440 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2441
2442 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2443 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2444 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2445 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2446
2447 // NOTE: Output is always Float32 regardless of input type
2448 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2449 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2450 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2451 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002452
2453 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2454 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002455 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002456 "must be positive and less than or equal to 1.");
2457 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002458
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002459 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2460 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002461 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002462 "should be equal to number of classes + 1.");
2463 }
2464}
2465
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002466void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2467{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002468 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002469
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002470 ValidateNumInputs(workloadInfo, descriptorName, 1);
2471 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2472
2473 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2474 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2475
2476 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2477 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002478 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002479 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002480 }
2481
Sadik Armagan2208b602019-07-31 16:36:27 +01002482 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002483 {
James Conroyd47a0642019-09-17 14:22:06 +01002484 DataType::Float32,
2485 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002486 };
2487
2488 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002489}
2490
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002491void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2492{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002493 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002494
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002495 ValidateNumInputs(workloadInfo, descriptorName, 2);
2496 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002497
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002498 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2499 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2500 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002501
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002502 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2503 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2504
2505 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2506 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002507}
2508
Sadik Armaganeff363d2019-04-05 15:25:46 +01002509void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2510{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002511 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002512
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002513 ValidateNumInputs(workloadInfo, descriptorName, 2);
2514 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2515
2516 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2517 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2518
2519 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2520 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2521
2522 std::vector<DataType> supportedTypes =
2523 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002524 DataType::Float32,
2525 DataType::QuantisedAsymm8,
2526 DataType::QuantisedSymm16
2527 };
2528
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002529 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2530 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002531
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002532 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2533 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002534
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002535 ValidateTensorShapesMatch(inputTensorInfo0,
2536 outputTensorInfo0,
2537 descriptorName,
2538 "input_0",
2539 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002540
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002541 ValidateTensorShapesMatch(inputTensorInfo0,
2542 outputTensorInfo1,
2543 descriptorName,
2544 "input_0",
2545 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002546}
2547
Matteo Martincigh49124022019-01-11 13:25:59 +00002548void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2549{
2550 // This is internally generated so it should not need validation.
2551}
2552
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002553void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2554{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002555 const std::string& descriptorName{"PreluQueueDescriptor"};
2556
2557 ValidateNumInputs(workloadInfo, descriptorName, 2);
2558 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2559
2560 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2561 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2562 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002563
2564 std::vector<DataType> supportedTypes
2565 {
2566 DataType::Float16,
2567 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002568 DataType::QuantisedAsymm8,
2569 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002570 };
2571
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002572 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2573 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002574
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002575 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002576
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002577 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2578 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002579
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002580 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2581 alphaTensorInfo,
2582 outputTensorInfo,
2583 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002584 "input",
2585 "alpha");
2586}
2587
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002588void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2589{
2590 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2591
2592 ValidateNumInputs(workloadInfo, descriptorName, 1);
2593 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2594
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002595 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2596 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2597
2598 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2599 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002600
2601 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002602
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002603 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2604 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002605
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002606 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2607
2608 Optional<TensorInfo> optionalBiasTensorInfo;
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002609 if (m_Parameters.m_BiasEnabled)
2610 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002611 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002612
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002613 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2614 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002615
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002616 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002617 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002618 }
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002619
2620 ValidatePerAxisQuantization(inputTensorInfo,
2621 outputTensorInfo,
2622 weightTensorInfo,
2623 optionalBiasTensorInfo,
2624 descriptorName);
2625
2626 std::vector<DataType> supportedTypes =
2627 {
2628 DataType::Float32,
2629 DataType::Float16,
2630 DataType::QuantisedAsymm8,
2631 DataType::QuantisedSymm16
2632 };
2633
2634 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2635 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002636}
2637
James Conroy9c3cae82019-08-01 16:01:48 +01002638void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2639{
2640 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2641
2642 // Validate number of inputs/outputs
2643 ValidateNumInputs(workloadInfo, descriptorName, 3);
2644 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2645
2646 // Input/output tensor infos
2647 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2648 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2649 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2650
2651 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2652 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2653
2654 std::vector<DataType> inputOutputSupportedTypes =
2655 {
2656 DataType::QuantisedAsymm8
2657 };
2658
2659 std::vector<DataType> cellStateSupportedTypes =
2660 {
2661 DataType::QuantisedSymm16
2662 };
2663
2664 std::vector<DataType> weightsSupportedTypes =
2665 {
2666 DataType::QuantisedAsymm8
2667 };
2668
2669 std::vector<DataType> biasSupportedTypes =
2670 {
2671 DataType::Signed32
2672 };
2673
2674 // Validate types of input/output tensors
2675 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2676 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2677 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2678
2679 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2680 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2681
2682 // Validate matching types of input/output tensors
2683 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2684 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2685 "outputStateIn", "outputStateOut");
2686 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2687
2688 // Validate matching quantization info for input/output tensors
2689 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2690 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2691 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002692
James Conroy9c3cae82019-08-01 16:01:48 +01002693 // Infer number of batches, input size and output size from tensor dimensions
2694 const uint32_t numBatches = inputInfo.GetShape()[0];
2695 const uint32_t inputSize = inputInfo.GetShape()[1];
2696 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2697
2698 // Validate number of dimensions and number of elements for input/output tensors
2699 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2700 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2701 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2702 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2703 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2704
2705 // Validate number of dimensions and number of elements for weights tensors
2706 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2707 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2708 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2709
2710 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2711 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2712 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2713
2714 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2715 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2716 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2717
2718 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2719 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2720 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2721
2722 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2723 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2724 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2725
2726 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2727 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2728 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2729 " RecurrentToForgetWeights");
2730
2731 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2732 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2733 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2734
2735 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2736 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2737 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2738
2739 // Validate data types for weights tensors (all should match each other)
2740 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2741
2742 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2743 "inputToInputWeights", "inputToForgetWeights");
2744 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2745 "inputToInputWeights", "inputToCellWeights");
2746 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2747 "inputToInputWeights", "inputToOutputWeights");
2748
2749 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2750 "inputToInputWeights", "recurrentToInputWeights");
2751 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2752 "inputToInputWeights", "recurrentToForgeteights");
2753 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2754 "inputToInputWeights", "recurrentToCellWeights");
2755 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2756 "inputToInputWeights", "recurrentToOutputWeights");
2757
2758 // Validate matching quantization info for weight tensors (all should match each other)
2759 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2760 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2761 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2762 descriptorName, "inputToInputWeights", "inputToCellWeights");
2763 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2764 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2765
2766 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2767 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2768 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2769 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2770 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2771 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2772 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2773 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2774
2775 // Validate number of dimensions and number of elements in bias tensors
2776 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2777 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2778 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2779
2780 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2781 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2782 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2783
2784 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2785 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2786 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2787
2788 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2789 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2790 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2791
2792 // Validate data types for bias tensors (all should match each other)
2793 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2794
2795 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2796 "inputGateBias", "forgetGateBias");
2797 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2798 "inputGateBias", "cellBias");
2799 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2800 "inputGateBias", "outputGateBias");
2801
2802 // Validate bias tensor quantization info
2803 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2804 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2805 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2806 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2807}
2808
Kevin May868eb142019-09-04 17:29:31 +01002809void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2810{
2811 const std::string descriptorName{"AbsQueueDescriptor"};
2812
2813 ValidateNumInputs(workloadInfo, descriptorName, 1);
2814 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2815
2816 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2817 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2818
2819 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2820
2821 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002822 {
2823 DataType::Float16,
2824 DataType::Float32,
2825 DataType::QuantisedAsymm8,
2826 DataType::QuantisedSymm16
2827 };
Kevin May868eb142019-09-04 17:29:31 +01002828
2829 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2830 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2831}
2832
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002833void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2834{
2835 const std::string descriptorName{"SliceQueueDescriptor"};
2836
2837 ValidateNumInputs(workloadInfo, descriptorName, 1);
2838 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2839
2840 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2841 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2842
2843 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2844
2845 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2846 if (rank > 4)
2847 {
2848 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2849 }
2850
2851 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2852
2853 // Check if m_Begin and m_Size have the expected length
2854 if (m_Parameters.m_Begin.size() != rank)
2855 {
2856 throw InvalidArgumentException(descriptorName +
2857 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2858 }
2859 if (m_Parameters.m_Size.size() != rank)
2860 {
2861 throw InvalidArgumentException(descriptorName +
2862 ": Length of size descriptor must equal rank " + std::to_string(rank));
2863 }
2864
2865 // Check if the shape of the output tensor matches m_Size
2866 const TensorShape& outputShape = outputTensorInfo.GetShape();
2867 for (unsigned int i = 0u; i < rank; ++i)
2868 {
2869 if (m_Parameters.m_Size[i] != outputShape[i])
2870 {
2871 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2872 }
2873 }
2874
2875 // Check if the sum of begin offset and size in a given dimension
2876 // does not exceed the size of corresponding input
2877 const TensorShape& inputShape = inputTensorInfo.GetShape();
2878 for(unsigned int i = 0u; i < rank; ++i)
2879 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002880 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002881 {
2882 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2883 std::to_string(i) + " exceeds input size.");
2884 }
2885 }
2886}
2887
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002888void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2889{
2890 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2891
2892 ValidateNumInputs(workloadInfo, descriptorName, 1);
2893 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2894
2895 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2896 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2897
2898 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2899 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2900
2901 std::vector<DataType> supportedTypes =
2902 {
2903 DataType::Float32,
2904 DataType::Float16,
2905 DataType::QuantisedAsymm8,
2906 DataType::QuantisedSymm16
2907 };
2908
2909 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2910 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2911
2912 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2913
2914 if (m_Parameters.m_BlockSize == 0)
2915 {
2916 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2917 }
2918
2919 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2920 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2921 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2922 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2923
2924 const TensorShape& outputShape = outputInfo.GetShape();
2925 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
2926 {
2927 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
2928 "must be divisible by block size.");
2929 }
2930
2931 const TensorShape& inputShape = inputInfo.GetShape();
2932 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
2933 {
2934 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
2935 "must be divisible by the square of block size." );
2936 }
2937}
2938
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01002939void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2940{
2941 const std::string descriptorName{"ComparisonQueueDescriptor"};
2942
2943 ValidateNumInputs(workloadInfo, descriptorName, 2);
2944 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2945
2946 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2947 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2948 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2949
2950 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2951 inputTensorInfo1,
2952 outputTensorInfo,
2953 descriptorName,
2954 "input_0",
2955 "input_1");
2956
2957 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2958 {
2959 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2960 }
2961}
2962
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002963} // namespace armnn