blob: 136c196e1bf371db72292719d5f80e2f45c9b0d6 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
33 case DataType::QuantisedAsymm8:
34 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010035 case DataType::QuantisedSymm16:
36 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000037 default:
38 BOOST_ASSERT_MSG(false, "Invalid input data type");
39 return DataType::Float32;
40 }
41}
42
43namespace
44{
45
46//---------------------------------------------------------------
47//android ndk does not support std::to_string function.
/// Formats any stream-insertable value as text by writing it into a std::ostringstream.
/// @param value Value to format.
/// @return The value's textual representation, as produced by operator<<.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream formatted;
    formatted << value;
    const std::string text = formatted.str();
    return text;
}
55
56//---------------------------------------------------------------
57void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
58{
59 if (!ptr)
60 {
61 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
62 paramName + " parameter must be set.");
63 }
64}
65
66//---------------------------------------------------------------
67void ValidateTensorShapesMatch(const TensorInfo& first,
68 const TensorInfo& second,
69 std::string const& descName,
70 std::string const& firstName,
71 std::string const& secondName)
72{
73 if (first.GetShape() != second.GetShape())
74 {
75 throw InvalidArgumentException(descName + ": "
76 + firstName + " & " + secondName + " must have identical shapes");
77 }
78}
79
80//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010081void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000082{
Sadik Armaganeff363d2019-04-05 15:25:46 +010083 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000084 {
85 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010086 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000087 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
88 }
89}
90
91//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010092void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000093{
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000095 {
96 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010097 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000098 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
99 }
100}
101
102//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100103void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000104 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100105 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000106 std::string const& tensorName)
107{
108 if (tensor.GetNumDimensions() != numDimensions)
109 {
110 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
111 to_string(tensor.GetNumDimensions()) + " dimensions for " +
112 tensorName + " tensor.");
113 }
114}
115
116//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100117void ValidateTensorNumElements(const TensorInfo& tensor,
118 std::string const& descName,
119 unsigned int numElements,
120 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100121{
122 if (tensor.GetNumElements() != numElements)
123 {
124 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100125 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126 tensorName + " tensor.");
127 }
128}
129
130//---------------------------------------------------------------
131void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100132 unsigned int numDimension,
133 unsigned int numElements,
134 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100135{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100136 const std::string functionName{"ValidateTensorNumDimNumElem"};
137 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
138 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100139}
140
141//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000142void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
143 const std::string& descName, std::string const& tensorName)
144{
145 if (tensor.GetDataType() != dataType)
146 {
147 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
148 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
149 }
150}
151
152//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100153void ValidateTensorQuantizationSpace(const TensorInfo& first,
154 const TensorInfo& second,
155 const std::string& descName,
156 std::string const& firstName,
157 std::string const& secondName)
158{
159 if (!first.IsQuantized() ||
160 !second.IsQuantized())
161 {
162 // Not a quantized type, ignore the validation
163 return;
164 }
165
166 DataType firstDataType = first.GetDataType();
167 DataType secondDataType = second.GetDataType();
168
169 if (firstDataType != secondDataType)
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must be of the same quantized type, " +
173 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
174 secondName + " is " + GetDataTypeName(secondDataType));
175 }
176
177 if (!first.IsTypeSpaceMatch(second))
178 {
179 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
180 " must have the same quantization space, " +
181 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
182 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
183 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
184 " and scale " + to_string(second.GetQuantizationScale()));
185 }
186}
187
188//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100189void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
190 const TensorInfo& inputTensorInfo,
191 const TensorInfo& weightsTensorInfo,
192 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000193{
194 if (biasTensor.GetQuantizationOffset() != 0)
195 {
196 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
197 to_string(biasTensor.GetQuantizationOffset()));
198 }
199 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
kevmay016c46dd32018-12-17 15:32:45 +0000200 if (std::abs(biasTensor.GetQuantizationScale() - expectedScale) > 0.00000001f)
telsoa014fcda012018-03-09 14:13:49 +0000201 {
202 // Print the float values with extra precision to see very small differences
203 std::stringstream msg;
204 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
205 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
206 biasTensor.GetQuantizationScale();
207 throw InvalidArgumentException(msg.str());
208 }
209}
210
211//---------------------------------------------------------------
212void ValidateTensors(const std::vector<ITensorHandle*>& vec,
213 unsigned int numExpected,
214 const std::string& descName,
215 const std::string& varName)
216{
217 if (vec.empty() && numExpected > 0)
218 {
219 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
220 }
221
222 for (unsigned int i = 0; i < numExpected; ++i)
223 {
224 if (!vec[i])
225 {
226 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
227 }
228 }
229}
230
231//---------------------------------------------------------------
232void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
233 const TensorInfo& second,
234 const TensorInfo& output,
235 std::string const& descName,
236 std::string const& firstName,
237 std::string const& secondName)
238{
239 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
240 // broadcasted.
241 if (first.GetNumDimensions() != second.GetNumDimensions())
242 {
243 throw InvalidArgumentException(descName + ": Tensors "
244 + firstName + " & " + secondName
245 + " must have the same number of dimensions in order to be broadcasted");
246 }
247 uint32_t numDims = first.GetNumDimensions();
248 std::vector<uint32_t> outputDims(numDims, 0u);
249 for (uint32_t i = 0; i < numDims; i++)
250 {
251 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
252 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
253 if (dimsNotEqual && dimsNotOne)
254 {
255 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
256 }
257 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
258 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100259 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000260 if (broadcastShape != output.GetShape())
261 {
262 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
263 + firstName + " & " + secondName
264 + " does not match the output shape");
265 }
266}
267
268//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100269void ValidateDataTypes(const TensorInfo& info,
270 const std::vector<armnn::DataType>& supportedTypes,
271 std::string const& descName)
272{
273 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
274 if (iterator == supportedTypes.end())
275 {
276 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
277 }
278}
279
James Conroy4d1ff582019-06-10 17:06:39 +0100280//---------------------------------------------------------------
281void ValidateTensorDataTypesMatch(const TensorInfo& first,
282 const TensorInfo& second,
283 std::string const& descName,
284 std::string const& firstName,
285 std::string const& secondName)
286{
287 if (first.GetDataType() != second.GetDataType())
288 {
289 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
290 " must have identical data types.");
291 }
292}
293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100294//---------------------------------------------------------------
295void ValidateTensorNumElementsMatch(const TensorInfo& first,
296 const TensorInfo& second,
297 std::string const& descName,
298 std::string const& firstName,
299 std::string const& secondName)
300{
301 if (first.GetNumElements() != second.GetNumElements())
302 {
303 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
304 " must have the same number of elements.");
305 }
306}
307
308} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000309
310void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
311 unsigned int numExpectedIn, unsigned int numExpectedOut) const
312{
313 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
314 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
315}
316
317//---------------------------------------------------------------
318void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
319{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100320 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100322 ValidateNumInputs(workloadInfo, descriptorName, 1);
323 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100325 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
326 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
327
328 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
329 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000330
331 if (m_Inputs.size() != m_Outputs.size())
332 {
333 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100334 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
335 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000336 }
337
338 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
339 {
340 if (!m_Inputs[i])
341 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100342 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
343 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000344 }
345
346 if (!m_Outputs[i])
347 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100348 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
349 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000350 }
351 }
352}
353
Derek Lambertif674aa02019-08-01 15:56:25 +0100354//---------------------------------------------------------------
355void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
356{
357 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
358 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
359
360 if (workloadInfo.m_InputTensorInfos.size() != 1)
361 {
362 throw InvalidArgumentException(boost::str(
363 boost::format("Number of input infos (%1%) is not 1.")
364 % workloadInfo.m_InputTensorInfos.size()));
365
366 }
367
368 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
369 {
370 throw InvalidArgumentException(boost::str(
371 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
372 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
373 }
374
375 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
376 {
377 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
378 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
379 {
380 throw InvalidArgumentException(boost::str(
381 boost::format("Number of elements for tensor input and output %1% does not match")
382 % i ));
383 }
384 }
385
386 if (m_Inputs.size() != 1)
387 {
388 throw InvalidArgumentException(boost::str(
389 boost::format("Number of inputs (%1%) is not 1.")
390 % m_Inputs.size()));
391 }
392
393 if (m_Inputs.size() != m_Outputs.size())
394 {
395 throw InvalidArgumentException(boost::str(
396 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
397 % m_Inputs.size() % m_Outputs.size()));
398 }
399
400 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
401 {
402 if (!m_Inputs[i])
403 {
404 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
405 }
406
407 if (!m_Outputs[i])
408 {
409 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
410 }
411 }
412}
413
414//---------------------------------------------------------------
415void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
416{
417 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
418 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
419
Derek Lambertif674aa02019-08-01 15:56:25 +0100420 if (m_Inputs.size() != 1)
421 {
422 throw InvalidArgumentException(boost::str(
423 boost::format("Number of inputs (%1%) is not 1.")
424 % m_Inputs.size()));
425 }
426
427 if (m_Outputs.size() != 0)
428 {
429 throw InvalidArgumentException(boost::str(
430 boost::format("Number of outputs (%1%) is not 0.")
431 % m_Inputs.size() % m_Outputs.size()));
432 }
433
434 if (!m_Inputs[0])
435 {
436 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
437 }
438}
439
440//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000441void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
442{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100443 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100444
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100445 ValidateNumInputs(workloadInfo, descriptorName, 1);
446 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100447
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100448 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
449 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100450
451 std::vector<DataType> supportedTypes =
452 {
453 DataType::Float16,
454 DataType::Float32,
455 DataType::QuantisedAsymm8,
456 DataType::QuantisedSymm16
457 };
458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100459 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
460 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
461 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000462}
463
Nikhil Rajee391d52019-09-05 17:50:44 +0100464void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
465{
466 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
467
468 ValidateNumInputs(workloadInfo, descriptorName, 1);
469 ValidateNumOutputs(workloadInfo, descriptorName, 1);
470
471 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
472 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
473
474 std::vector<DataType> supportedTypes =
475 {
476 DataType::Float16,
477 DataType::Float32,
478 DataType::QuantisedAsymm8,
479 DataType::QuantisedSymm16
480 };
481
482 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
483 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
484 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
485}
486
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100487void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
488{
489 const std::string descriptorName{"SoftmaxQueueDescriptor"};
490
491 ValidateNumInputs(workloadInfo, descriptorName, 1);
492 ValidateNumOutputs(workloadInfo, descriptorName, 1);
493
494 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
495 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
496
497 std::vector<DataType> supportedTypes =
498 {
499 DataType::Float16,
500 DataType::Float32,
501 DataType::QuantisedAsymm8,
502 DataType::QuantisedSymm16
503 };
504
505 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
506 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
507 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
508}
509
telsoa014fcda012018-03-09 14:13:49 +0000510void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
511{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100512 const std::string descriptorName{"SplitterQueueDescriptor"};
513
514 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000515
Ruomei Yan25339c32019-05-28 16:48:20 +0100516 // Check the supported data types
517 std::vector<DataType> supportedTypes =
518 {
519 DataType::Float32,
520 DataType::Float16,
521 DataType::Boolean,
522 DataType::Signed32,
523 DataType::QuantisedAsymm8,
524 DataType::QuantisedSymm16
525 };
526
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100527 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
528 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100529 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100530 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
531 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
532
533 const std::string outputName = "output_" + std::to_string(i);
534 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100535 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100536
telsoa014fcda012018-03-09 14:13:49 +0000537 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
538 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100539 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000540 }
541
542 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
543 {
544 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100545 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000546 "has to match number of workloadInfo.m_OutputTensorInfos. "
547 "Number of windows: " +
548 to_string(m_ViewOrigins.size()) +
549 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
550 }
551
telsoa01c577f2c2018-08-31 09:22:23 +0100552 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000553 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
554 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
555 {
telsoa01c577f2c2018-08-31 09:22:23 +0100556 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000557 ViewOrigin const& e = m_ViewOrigins[w];
558 if (e.m_Origin.size() != inputDims)
559 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100560 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000561 "have the same dimensionality as the input tensor. "
562 "Window origin (index: " +
563 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
564 " dimensions, the input "
565 "tensor has " +
566 to_string(inputDims) + " dimensions.");
567 }
568 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
569 {
570 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
571 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
572 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100573 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000574 "be smaller or equal than the size of the input in that coord.");
575 }
576 }
577 }
578}
579
Jim Flynne242f2d2019-05-22 14:24:13 +0100580void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000581{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100582 const std::string descriptorName{"ConcatQueueDescriptor"};
583
584 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000585
586 if (m_Inputs.size() <= 0)
587 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100588 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000589 }
590 if (m_Outputs.size() <= 0)
591 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100592 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000593 }
594
595 if (workloadInfo.m_InputTensorInfos.size() <= 0)
596 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100597 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000598 }
599 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
600 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100601 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000602 }
603
Nikhil Raj8599a412018-11-19 14:51:07 +0000604 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
605 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100606 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000607 }
608
609 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
610 {
611 return;
612 }
613
telsoa014fcda012018-03-09 14:13:49 +0000614 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
615 {
616 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100617 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000618 "has to match number of workloadInfo.m_InputTensorInfos. "
619 "Number of windows: " +
620 to_string(m_ViewOrigins.size()) +
621 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
622 }
623
telsoa01c577f2c2018-08-31 09:22:23 +0100624 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000625 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
626 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
627 {
telsoa01c577f2c2018-08-31 09:22:23 +0100628 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000629 ViewOrigin const& e = m_ViewOrigins[w];
630 if (e.m_Origin.size() != outputDims)
631 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100632 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000633 "have the same dimensionality as the output tensor. "
634 "Window origin (index: " +
635 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
636 " dimensions, the output "
637 "tensor has " +
638 to_string(outputDims) + " dimensions.");
639 }
telsoa01c577f2c2018-08-31 09:22:23 +0100640 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000641 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
642 {
643 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
644 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
645 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100646 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000647 "be smaller or equal than the size of the output in that coord.");
648 }
649 }
650 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100651
652 // Check the supported data types
653 std::vector<DataType> supportedTypes =
654 {
655 DataType::Float32,
656 DataType::Float16,
657 DataType::Boolean,
658 DataType::Signed32,
659 DataType::QuantisedAsymm8,
660 DataType::QuantisedSymm16
661 };
662
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100663 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
664 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100665 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100666 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
667 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
668
669 const std::string inputName = "input_" + std::to_string(i);
670 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100671 }
telsoa014fcda012018-03-09 14:13:49 +0000672}
673
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100674void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
675{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100676 const std::string descriptorName{"StackQueueDescriptor"};
677
678 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100679
680 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
681 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100682 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100683 }
684
685 // All inputs must have the same shape, which is defined in parameters
686 const TensorShape& inputShape = m_Parameters.m_InputShape;
687 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
688 {
689 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
690 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100691 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100692 }
693 }
694
Matthew Jacksondba634f2019-08-15 15:14:18 +0100695 if (inputShape.GetNumDimensions() > 4)
696 {
697 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
698 }
699
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100700 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
701 // since the output tensor has an additional dimension.
702 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
703 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100704 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100705 "than the number of input dimensions.");
706 }
707
708 // Output shape must be as inferred from the input shape
709 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
710 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
711 {
712 if (outputShape[i] != inputShape[i])
713 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100714 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100715 "match shape inferred from input tensor.");
716 }
717 }
718
719 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
720 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100721 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100722 "match shape inferred from input tensor.");
723 }
724
725 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
726 {
727 if (outputShape[i] != inputShape[i-1])
728 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100729 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100730 "match shape inferred from input tensor.");
731 }
732 }
733
Matthew Jacksondba634f2019-08-15 15:14:18 +0100734 if (outputShape.GetNumDimensions() > 5)
735 {
736 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
737 }
738
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100739 // Check the supported data types
740 std::vector<DataType> supportedTypes =
741 {
742 DataType::Float32,
743 DataType::Float16,
744 DataType::Boolean,
745 DataType::Signed32,
746 DataType::QuantisedAsymm8,
747 DataType::QuantisedSymm16
748 };
749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100750 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100751
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100752 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100753 {
754 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
755 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100756 descriptorName,
757 "input_0",
758 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100759 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100760
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100761 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
762 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100763 descriptorName,
764 "input_0",
765 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100766}
767
telsoa014fcda012018-03-09 14:13:49 +0000768void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
769{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100770 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000771
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100772 ValidateNumInputs(workloadInfo, descriptorName, 1);
773 ValidateNumOutputs(workloadInfo, descriptorName, 1);
774
775 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
776 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
777
778 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
779
780 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000781 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100782 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000783 }
784
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000786
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100787 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
788 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000789
790 if (m_Parameters.m_BiasEnabled)
791 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100792 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000793
telsoa01c577f2c2018-08-31 09:22:23 +0100794 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100795 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
796 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000797
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100798 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
799 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000800 }
801
Francis Murtagh46c09d02019-05-28 08:15:28 +0100802 // Check the supported data types
803 std::vector<DataType> supportedTypes =
804 {
805 DataType::Float32,
806 DataType::Float16,
807 DataType::QuantisedAsymm8,
808 DataType::QuantisedSymm16
809 };
810
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100811 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
812 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000813}
814
telsoa014fcda012018-03-09 14:13:49 +0000815void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
816{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100817 const std::string descriptorName{"NormalizationQueueDescriptor"};
818
819 ValidateNumInputs(workloadInfo, descriptorName, 1);
820 ValidateNumOutputs(workloadInfo, descriptorName, 1);
821
822 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
823 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100824
825 // Check the supported data types
826 std::vector<DataType> supportedTypes =
827 {
828 DataType::Float16,
829 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100830 DataType::QuantisedAsymm8,
831 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100832 };
833
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100834 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100835
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100836 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100837
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100838 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000839}
840
841void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
842{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000844
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100845 ValidateNumInputs(workloadInfo, descriptorName, 2);
846 ValidateNumOutputs(workloadInfo, descriptorName, 1);
847
848 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
849 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
850 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
851
852 std::vector<DataType> supportedTypes =
853 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100854 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100855 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100856 DataType::QuantisedSymm16,
857 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100858 };
859
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100860 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
861 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
862 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100863
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100864 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
865 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100866
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100867 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
868 inputTensorInfo1,
869 outputTensorInfo,
870 descriptorName,
871 "input_0",
872 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000873}
874
telsoa014fcda012018-03-09 14:13:49 +0000875void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
876{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +0100878
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100879 ValidateNumInputs(workloadInfo, descriptorName, 2);
880 ValidateNumOutputs(workloadInfo, descriptorName, 1);
881
882 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
883 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
884 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
885
886 std::vector<DataType> supportedTypes =
887 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100888 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100889 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100890 DataType::QuantisedSymm16,
891 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100892 };
893
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100894 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
895 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
896 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100897
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100898 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
899 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100900
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100901 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
902 inputTensorInfo1,
903 outputTensorInfo,
904 descriptorName,
905 "input_0",
906 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000907}
908
909void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
910{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100911 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100912
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 ValidateNumInputs(workloadInfo, descriptorName, 1);
914 ValidateNumOutputs(workloadInfo, descriptorName, 1);
915
916 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
917 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100918
919 std::vector<DataType> supportedTypes =
920 {
921 DataType::Float16,
922 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100923 DataType::QuantisedAsymm8,
924 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100925 };
926
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100927 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
928 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100929
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
931 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
932 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100933
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100934 ValidatePointer(m_Mean, descriptorName, "mean");
935 ValidatePointer(m_Variance, descriptorName, "variance");
936 ValidatePointer(m_Beta, descriptorName, "beta");
937 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000938
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100939 const TensorInfo& mean = m_Mean->GetTensorInfo();
940 const TensorInfo& variance = m_Variance->GetTensorInfo();
941 const TensorInfo& beta = m_Beta->GetTensorInfo();
942 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000943
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100944 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
945 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
946 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
947 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000948
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100949 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
950 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
951 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000952}
953
954void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
955{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100956 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000957
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100958 ValidateNumInputs(workloadInfo, descriptorName, 1);
959 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000960
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
962 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +0000963
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100964 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
965 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +0000966
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100967 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000968
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
970 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100972 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
telsoa014fcda012018-03-09 14:13:49 +0000973
974 if (m_Parameters.m_BiasEnabled)
975 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100976 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000977
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100978 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
979 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
980
981 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
982 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000983 }
984
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100985 std::vector<DataType> supportedTypes =
986 {
Ruomei Yan88d44b82019-05-23 14:29:06 +0100987 DataType::Float32,
988 DataType::QuantisedAsymm8,
989 DataType::QuantisedSymm16,
990 DataType::Float16
991 };
992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100993 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
994 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
995}
Ruomei Yan88d44b82019-05-23 14:29:06 +0100996
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100997void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
998{
999 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1000
1001 ValidateNumInputs(workloadInfo, descriptorName, 1);
1002 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1003
1004 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1005 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1006
1007 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1008 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1009
1010 ValidatePointer(m_Weight, descriptorName, "weight");
1011
1012 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1013 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1014
1015 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1016 {
1017 throw InvalidArgumentException(
1018 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1019 "cannot be smaller than 1.") % descriptorName %
1020 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1021 }
1022
1023 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1024
1025 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1026 // inputChannels * channelMultiplier should be equal to outputChannels.
1027 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1028 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1029 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1030 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1031 {
1032 throw InvalidArgumentException(
1033 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1034 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1035 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1036 numWeightInputChannels % numWeightChannelMultiplier));
1037 }
1038
1039 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
1040
1041 if (m_Parameters.m_BiasEnabled)
1042 {
1043 ValidatePointer(m_Bias, descriptorName, "bias");
1044
1045 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1046 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1047
1048 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1049 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1050 }
1051
1052 std::vector<DataType> supportedTypes =
1053 {
1054 DataType::Float32,
1055 DataType::QuantisedAsymm8,
1056 DataType::QuantisedSymm16,
1057 DataType::Float16
1058 };
1059
1060 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1061 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001062}
1063
1064void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1065{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 const std::string descriptorName{"PermuteQueueDescriptor"};
1067
1068 ValidateNumInputs(workloadInfo, descriptorName, 1);
1069 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001070
1071 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1072
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001073 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1074 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001075
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001076 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1077 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001078
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001080 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001081 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001082 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001083 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1084 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1085 "must match dst dimension " + to_string(mapping[i]) +
1086 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001087 }
1088 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001089
1090 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001091}
1092
1093void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1094{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001095 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001096
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001097 ValidateNumInputs(workloadInfo, descriptorName, 1);
1098 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1099
1100 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1101 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1102
1103 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1104 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001105
1106 std::vector<DataType> supportedTypes =
1107 {
1108 DataType::Float32,
1109 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001110 DataType::QuantisedAsymm8,
1111 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001112 };
1113
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001114 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1115 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001116}
1117
1118void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1119{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001120 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001121
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001122 ValidateNumInputs(workloadInfo, descriptorName, 1);
1123 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1124
1125 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1126 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1127
1128 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1129 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001130
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001131 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001132 {
1133 DataType::Float16,
1134 DataType::Float32,
1135 DataType::QuantisedAsymm8,
1136 DataType::QuantisedSymm16
1137 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001138
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001139 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1140 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001141
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001142 // ResizeBilinear only changes width and height: batch and channel count must match.
1143 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1144 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001145 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001146 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001147 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001148 boost::str(boost::format("%1%: Input batch size (%2%) "
1149 "does not match output batch size (%3%)") %
1150 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001151 }
1152
Teresa Charlin970f43b2019-07-01 13:51:07 +01001153 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1155 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001156 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001157 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001158 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001159 boost::str(boost::format("%1%: Input channel count (%2%) "
1160 "does not match output channel count (%3%)") %
1161 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001162 }
1163}
1164
1165void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1166{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001167 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001168
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169 ValidateNumInputs(workloadInfo, descriptorName, 1);
1170 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1171
1172 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1173 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1174
1175 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1176 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001177
1178 std::vector<DataType> supportedTypes =
1179 {
1180 DataType::Float16,
1181 DataType::Float32,
1182 DataType::QuantisedAsymm8,
1183 DataType::QuantisedSymm16
1184 };
1185
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001186 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1187 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001188
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001189 // Resize only changes width and height: batch and channel count must match.
1190 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1191 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001192 if (inputBatchSize != outputBatchSize)
1193 {
1194 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001195 boost::str(boost::format("%1%: Input batch size (%2%) "
1196 "does not match output batch size (%3%)") %
1197 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001198 }
1199
1200 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001201 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1202 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001203 if (inputChannelCount != outputChannelCount)
1204 {
1205 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001206 boost::str(boost::format("%1%: Input channel count (%2%) "
1207 "does not match output channel count (%3%)") %
1208 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001209 }
1210}
1211
1212void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1213{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001214 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001215
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001216 ValidateNumInputs(workloadInfo, descriptorName, 1);
1217 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1218
1219 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1220 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1221
1222 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1223 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1224
1225 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1226
telsoa014fcda012018-03-09 14:13:49 +00001227 if (m_Parameters.m_Min > m_Parameters.m_Max)
1228 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001229 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001230 }
telsoa014fcda012018-03-09 14:13:49 +00001231}
1232
1233void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1234{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001235 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001236
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001237 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001238 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1241 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1242
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001243 if (inputTensorInfo.GetNumDimensions() > 4)
1244 {
1245 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1246 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001247
1248 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001249
1250 // Check the supported data types
1251 std::vector<DataType> supportedTypes =
1252 {
1253 DataType::Float32,
1254 DataType::Float16,
1255 DataType::QuantisedAsymm8,
1256 DataType::QuantisedSymm16
1257 };
1258
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001259 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1260 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1261
1262 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001263}
1264
1265void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1266{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001267 const std::string descriptorName{"ConstantQueueDescriptor"};
1268
1269 ValidateNumInputs(workloadInfo, descriptorName, 0);
1270 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001271
1272 if (!m_LayerOutput)
1273 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001275 }
1276
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001277 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1278 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001279
1280 // Check the supported data types
1281 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001282 {
1283 DataType::Float32,
1284 DataType::Float16,
1285 DataType::Signed32,
1286 DataType::QuantisedAsymm8,
1287 DataType::QuantisedSymm16
1288 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001290 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001291}
1292
1293void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1294{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001295 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001296
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297 ValidateNumInputs(workloadInfo, descriptorName, 1);
1298 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1299
1300 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1301 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1302
1303 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001304
1305 // Check the supported data types
1306 std::vector<DataType> supportedTypes =
1307 {
1308 DataType::Float32,
1309 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001310 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001311 DataType::QuantisedAsymm8,
1312 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001313 };
1314
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001315 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1316 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001317}
1318
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001319void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1320{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001321 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001322
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001323 ValidateNumInputs(workloadInfo, descriptorName, 1);
1324 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1325
1326 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1327 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1328
1329 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1330 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001331
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001332 if (m_Parameters.m_BlockShape.size() != 2)
1333 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001334 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001335 }
1336
1337 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1338 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001339 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1340 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001341 }
1342
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001343 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001344
1345 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001346 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001347
Matthew Bentham8800c002018-11-19 13:19:28 +00001348 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001349
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001350 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1351 widthPad.first + widthPad.second;
1352 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1353 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001355 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1356 inputShape[dimensionIndices.GetChannelsIndex()];
1357 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001358
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001359 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001360 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001361 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001362 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001363 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001364 }
1365
1366 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001367 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001368 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1369 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001370 }
nikraj01120522a2019-05-31 11:33:07 +01001371
1372 std::vector<DataType> supportedTypes =
1373 {
1374 DataType::Float16,
1375 DataType::Float32,
1376 DataType::QuantisedAsymm8,
1377 DataType::QuantisedSymm16
1378 };
1379
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001380 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1381 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001382}
1383
Keith Davisa57eccb2019-06-14 17:33:22 +01001384void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1385{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001386 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001387
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001388 ValidateNumInputs(workloadInfo, descriptorName, 1);
1389 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001390
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001391 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1392 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1393
1394 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1395 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001396
1397 std::vector<DataType> supportedTypes =
1398 {
1399 DataType::Float32,
1400 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001401 DataType::QuantisedAsymm8,
1402 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001403 };
1404
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001405 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1406 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001407
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001408 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1409 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1410 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1411 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001412
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001413 const TensorShape& inputShape = inputTensorInfo.GetShape();
Keith Davisa57eccb2019-06-14 17:33:22 +01001414
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001415 const unsigned int numInputElements =
1416 inputShape[0] * inputShape[wIndex] * inputShape[hIndex] * inputShape[cIndex];
1417 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
1418
1419 if (numOutputElements != numInputElements)
Keith Davisa57eccb2019-06-14 17:33:22 +01001420 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001421 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1422 std::to_string(numInputElements) + " but output tensor has " +
1423 std::to_string(numOutputElements) + " elements.");
Keith Davisa57eccb2019-06-14 17:33:22 +01001424 }
1425
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001426 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001427 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001428 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1429 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001430 }
1431}
1432
telsoa014fcda012018-03-09 14:13:49 +00001433void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1434{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001435 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001436
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 ValidateNumInputs(workloadInfo, descriptorName, 1);
1438 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1439
1440 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1441 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001442
1443 std::vector<DataType> supportedTypes =
1444 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001445 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001446 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001447 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001448 };
1449
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001450 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001451
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001452 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001453 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001454 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001455 }
1456}
1457
telsoa01c577f2c2018-08-31 09:22:23 +01001458void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1459{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001460 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1461
1462 const std::string descriptorName{"LstmQueueDescriptor"};
1463
1464 // check dimensions of all inputs and outputs
1465 if (workloadInfo.m_InputTensorInfos.size() != 3)
1466 {
1467 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1468 }
1469 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1470 {
1471 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1472 }
1473
1474 std::vector<DataType> supportedTypes =
1475 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001476 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001477 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001478 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001479 };
1480
Jan Eilers38e05bd2019-06-26 13:10:09 +01001481 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001482 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1483
Jan Eilers38e05bd2019-06-26 13:10:09 +01001484 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001485 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001486 {
1487 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1488 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001489 descriptorName,
1490 "input_0",
1491 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001492 }
1493 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001495 {
1496 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1497 workloadInfo.m_OutputTensorInfos[i],
1498 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 "input_0",
1500 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001501 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001502
Jan Eilers38e05bd2019-06-26 13:10:09 +01001503 // TODO: check clipping parameter is valid
1504
1505 // Inferring batch size, number of outputs and number of cells from the inputs.
1506 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1507 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1508 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1509 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1510 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1511 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1512 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1513
Jan Eilers38e05bd2019-06-26 13:10:09 +01001514 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001515 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1516 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001517 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001518 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1519 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001520 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001521 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1522 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001523 // scratchBufferTensor
1524 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001525 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1526 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001527 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001528 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1529 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001530 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001531 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1532 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001533 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001534 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1535 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001536
1537
1538 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1539 if ( m_InputToInputWeights )
1540 {
1541 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1542 (n_cell * n_input), "InputLayerNormWeights");
1543 }
1544
1545 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1546 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1547 (n_cell * n_input), "InputToForgetWeights");
1548
1549 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1550 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1551 (n_cell * n_input), "InputToCellWeights");
1552
1553 if ( m_RecurrentToInputWeights )
1554 {
1555 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1556 (n_cell * n_output), "RecurrentToInputWeights");
1557 }
1558
1559 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1560 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1561 (n_cell * n_output), "RecurrentToForgetWeights");
1562
1563 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1564 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1565 (n_cell * n_output), "RecurrentToCellWeights");
1566
1567 // Make sure the input-gate's parameters are either both present (regular
1568 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1569 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1570 !m_Parameters.m_CifgEnabled) ||
1571 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1572 m_Parameters.m_CifgEnabled));
1573 if (!cifg_weights_all_or_none)
1574 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001575 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1576 "RecurrentToInputWeights must either both be present (regular LSTM) "
1577 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1578 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001579 }
1580
1581 if ( m_CellToInputWeights )
1582 {
1583 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1584 n_cell, "CellToInputWeights");
1585 }
1586 if ( m_CellToForgetWeights )
1587 {
1588 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1589 n_cell, "CellToForgetWeights");
1590 }
1591 if ( m_CellToOutputWeights )
1592 {
1593 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1594 n_cell, "CellToOutputWeights");
1595 }
1596
1597 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1598 bool peephole_weights_all_or_none =
1599 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1600 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1601 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1602 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1603 if (!peephole_weights_all_or_none)
1604 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001605 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001606 }
1607
1608 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1609 if (m_Parameters.m_CifgEnabled)
1610 {
1611 if (m_InputGateBias)
1612 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001613 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001614 }
1615 }
1616 else
1617 {
1618 if (!m_InputGateBias)
1619 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001620 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1621 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001622 }
1623 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1624 n_cell, "InputGateBias");
1625 }
1626
1627 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1628 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1629
1630 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1631 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1632
1633 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1634 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1635
1636 if (m_ProjectionWeights)
1637 {
1638 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1639 (n_cell * n_output), "ProjectionWeights");
1640 }
1641 if (m_ProjectionBias)
1642 {
1643 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1644 }
1645
1646 // Making sure the projection tensors are consistent:
1647 // 1) If projection weight is not present, then projection bias should not be
1648 // present.
1649 // 2) If projection weight is present, then projection bias is optional.
1650 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1651 !m_Parameters.m_ProjectionEnabled)
1652 || (m_ProjectionWeights && !m_ProjectionBias &&
1653 m_Parameters.m_ProjectionEnabled)
1654 || (m_ProjectionWeights && m_ProjectionBias &&
1655 m_Parameters.m_ProjectionEnabled));
1656 if (!projecton_tensors_consistent)
1657 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001658 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001659 }
1660
1661 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1662 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1663 // either all have values or none of them have values. Layer normalization is used when the values of all the
1664 // layer normalization weights are present
1665 if (m_InputLayerNormWeights)
1666 {
1667 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1668 }
1669 if (m_ForgetLayerNormWeights)
1670 {
1671 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1672 }
1673 if (m_CellLayerNormWeights)
1674 {
1675 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1676 }
1677 if (m_OutputLayerNormWeights)
1678 {
1679 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1680 }
1681
Jan Eilers38e05bd2019-06-26 13:10:09 +01001682 if (m_Parameters.m_LayerNormEnabled)
1683 {
1684 if (!m_Parameters.m_CifgEnabled)
1685 {
1686 if (!m_InputLayerNormWeights)
1687 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001688 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1689 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001690 }
1691 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1692 1, n_cell, "InputLayerNormWeights");
1693 }
1694 else if (m_InputLayerNormWeights)
1695 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001696 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1697 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001698 }
1699
1700 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1701 "ForgetLayerNormWeights");
1702 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1703
1704 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1705 "OutputLayerNormWeights");
1706 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1707
1708 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1709 "CellLayerNormWeights");
1710 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1711 }
1712 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1713 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001714 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1715 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001716 }
telsoa01c577f2c2018-08-31 09:22:23 +01001717}
1718
1719void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1720{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001721 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001722
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001723 ValidateNumInputs(workloadInfo, descriptorName, 1);
1724 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1725
1726 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1727 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1728
1729 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001730 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001731 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001732 }
1733
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001735 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001736 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001737 }
1738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001740}
1741
1742void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1743{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001744 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001745
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001746 ValidateNumInputs(workloadInfo, descriptorName, 1);
1747 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1748
1749 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1750 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1751
1752 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001753 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001754 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001755 }
1756
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001757 if (outputTensorInfo.GetDataType() != DataType::Float32)
1758 {
1759 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1760 }
1761
1762 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001763}
1764
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001765void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1766{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001767 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001768
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001769 ValidateNumInputs(workloadInfo, descriptorName, 2);
1770 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1771
1772 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1773 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1774 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1775
1776 std::vector<DataType> supportedTypes =
1777 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001778 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001779 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001780 DataType::QuantisedSymm16,
1781 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001782 };
1783
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001784 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1785 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1786 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001787
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001788 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1789 inputTensorInfo1,
1790 outputTensorInfo,
1791 descriptorName,
1792 "input_0",
1793 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001794}
1795
David Beckc2044fe2018-09-05 15:00:38 +01001796void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1797{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001798 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01001799
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 ValidateNumInputs(workloadInfo, descriptorName, 2);
1801 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1802
1803 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1804 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1805 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1806
1807 std::vector<DataType> supportedTypes =
1808 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001809 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001810 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001811 DataType::QuantisedSymm16,
1812 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001813 };
1814
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001815 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1816 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1817 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001818
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001819 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1820 inputTensorInfo1,
1821 outputTensorInfo,
1822 descriptorName,
1823 "input_0",
1824 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01001825}
1826
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001827void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1828{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001829 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001830
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001831 ValidateNumInputs(workloadInfo, descriptorName, 2);
1832 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1833
1834 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1835 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1836 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1837
1838 std::vector<DataType> supportedTypes =
1839 {
Mike Kelly1da02362019-08-01 08:43:57 +01001840 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001841 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001842 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001843 DataType::QuantisedAsymm8,
1844 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001845 };
1846
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001847 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1848 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1849 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001850
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001851 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1852 inputTensorInfo1,
1853 outputTensorInfo,
1854 descriptorName,
1855 "input_0",
1856 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001857}
1858
narpra01a6bf9122018-09-10 09:50:09 +01001859void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1860{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001861 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01001862
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001863 ValidateNumInputs(workloadInfo, descriptorName, 1);
1864 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1865
1866 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1867 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01001868
1869 std::vector<DataType> supportedTypes =
1870 {
1871 DataType::Float32,
1872 DataType::Float16,
1873 DataType::QuantisedAsymm8,
1874 DataType::QuantisedSymm16
1875 };
narpra01eb061912018-09-10 17:35:27 +01001876
James Conroy4d1ff582019-06-10 17:06:39 +01001877 // First check if input tensor data type is supported, then
1878 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001879 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1880 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01001881
narpra0132b90462018-09-13 11:07:48 +01001882 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01001883 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001884 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01001885 }
narpra0132b90462018-09-13 11:07:48 +01001886 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01001887 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001888 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01001889 }
1890 else
1891 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001892 unsigned int outputDim =
1893 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
1894 ValidateTensorNumDimensions(outputTensorInfo,
1895 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01001896 outputDim > 0 ? outputDim : 1,
1897 "output");
1898 }
narpra01a6bf9122018-09-10 09:50:09 +01001899}
1900
jimfly012c9322a2018-09-19 10:59:49 +01001901void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1902{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001903 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01001904
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001905 ValidateNumInputs(workloadInfo, descriptorName, 1);
1906 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1907
1908 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1909 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01001910
jimfly012c9322a2018-09-19 10:59:49 +01001911 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001912 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
1913
jimfly012c9322a2018-09-19 10:59:49 +01001914 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001915 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
1916 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
1917 "as there are dimensions in the input tensor that is " +
1918 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
1919 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01001920 }
1921}
1922
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001923void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1924{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001925 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001926
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001927 ValidateNumInputs(workloadInfo, descriptorName, 1);
1928 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001929
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001930 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1931 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1932
Sadik Armagan2208b602019-07-31 16:36:27 +01001933 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001934 {
Sadik Armagan2208b602019-07-31 16:36:27 +01001935 DataType::Float32,
1936 DataType::Float16
1937 };
1938
1939 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001940
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001941 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
1942 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001943 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001944 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001945 }
1946}
1947
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001948void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1949{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001950 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001952 ValidateNumInputs(workloadInfo, descriptorName, 1);
1953 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001954
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001955 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1956 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001957
1958 std::vector<DataType> supportedTypes =
1959 {
1960 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001961 DataType::Float16,
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001962 DataType::QuantisedAsymm8,
1963 DataType::QuantisedSymm16
1964 };
1965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001966 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1967 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001968}
1969
Conor Kennedy430b5d82018-11-14 15:28:28 +00001970void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1971{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001972 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00001973
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 ValidateNumInputs(workloadInfo, descriptorName, 1);
1975 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1976
1977 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1978 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001979
1980 std::vector<DataType> supportedTypes =
1981 {
1982 DataType::Float16,
1983 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01001984 DataType::QuantisedAsymm8,
1985 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001986 };
1987
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1989 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001990
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001991 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001993 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001994 if (rank > 4)
1995 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001997 }
1998
Conor Kennedy430b5d82018-11-14 15:28:28 +00001999 // Begin, End & Stride length must be of rank(input0)
2000 if (m_Parameters.m_Begin.size() != rank)
2001 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002002 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002003 }
2004
2005 if (m_Parameters.m_End.size() != rank)
2006 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002008 }
2009
2010 if (m_Parameters.m_Stride.size() != rank)
2011 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002012 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002013 }
2014
2015 // Stride entries must be non-zero
2016 for (auto& stride : m_Parameters.m_Stride)
2017 {
2018 if (stride == 0)
2019 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002020 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002021 }
2022 }
2023}
2024
kevmay0190539692018-11-29 08:40:19 +00002025void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2026{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002027 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002028
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002029 ValidateNumInputs(workloadInfo, descriptorName, 2);
2030 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2031
2032 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2033 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2034 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2035
2036 std::vector<DataType> supportedTypes =
2037 {
Mike Kelly1da02362019-08-01 08:43:57 +01002038 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002039 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002040 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002041 DataType::QuantisedAsymm8,
2042 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002043 };
2044
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002045 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2046 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2047 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002048
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002049 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2050 inputTensorInfo1,
2051 outputTensorInfo,
2052 descriptorName,
2053 "input_0",
2054 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002055}
2056
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002057void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2058{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002059 const std::string descriptorName{"DebugQueueDescriptor"};
2060
2061 ValidateNumInputs(workloadInfo, descriptorName, 1);
2062 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002063}
2064
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002065void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2066{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002067 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002068
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002069 ValidateNumInputs(workloadInfo, descriptorName, 2);
2070 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002071
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002072 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2073 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2074 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2075
2076 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2077 inputTensorInfo1,
2078 outputTensorInfo,
2079 descriptorName,
2080 "input_0",
2081 "input_1");
2082
2083 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002084 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002085 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002086 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002087}
2088
FrancisMurtagh878f0232018-12-19 10:56:15 +00002089void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2090{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002091 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002092
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002093 ValidateNumInputs(workloadInfo, descriptorName, 2);
2094 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002095
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2097 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2098 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2099
2100 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2101 inputTensorInfo1,
2102 outputTensorInfo,
2103 descriptorName,
2104 "input_0",
2105 "input_1");
2106
2107 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002108 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002109 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002110 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002111}
2112
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002113void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2114{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002115 const std::string descriptorName{"RsqrtQueueDescriptor"};
2116
2117 ValidateNumInputs(workloadInfo, descriptorName, 1);
2118 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2119
2120 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2121 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2122
2123 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002124
2125 std::vector<DataType> supportedTypes =
2126 {
2127 DataType::Float16,
2128 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01002129 DataType::QuantisedAsymm8,
2130 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002131 };
2132
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002133 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2134 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002135}
2136
narpra01b89b05f2019-01-16 09:53:09 +00002137void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2138{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002139 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002140
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002141 ValidateNumInputs(workloadInfo, descriptorName, 2);
2142 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002143
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002144 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2145 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002146 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002147 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002148 }
2149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002150 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2151 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2152
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002153 std::vector<DataType> supportedTypes =
2154 {
2155 DataType::Float16,
2156 DataType::Float32,
2157 DataType::QuantisedAsymm8,
2158 DataType::QuantisedSymm16
2159 };
2160
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002161 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002163 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002164
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002165 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2166 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002167}
2168
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002169void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2170{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002171 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2172
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002173 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002174
2175 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2176 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002177 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002178 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2179 }
2180
2181 if (m_Anchors == nullptr)
2182 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002183 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002184 }
2185
2186 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002187 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2188 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2189
2190 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002191 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002192 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2193 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002194
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002195 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2196 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2197 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002198
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002199 const std::vector<DataType> supportedInputTypes =
2200 {
2201 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002202 DataType::Float16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002203 DataType::QuantisedAsymm8,
2204 DataType::QuantisedSymm16
2205 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002206
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002207 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2208 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2209 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2210
2211 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2212 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2213 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2214 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2215
2216 // NOTE: Output is always Float32 regardless of input type
2217 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2218 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2219 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2220 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002221
2222 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2223 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002224 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002225 "must be positive and less than or equal to 1.");
2226 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002227
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002228 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2229 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002230 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002231 "should be equal to number of classes + 1.");
2232 }
2233}
2234
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002235void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2236{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002237 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002238
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002239 ValidateNumInputs(workloadInfo, descriptorName, 1);
2240 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2241
2242 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2243 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2244
2245 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2246 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002247 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002248 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002249 }
2250
Sadik Armagan2208b602019-07-31 16:36:27 +01002251 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002252 {
Sadik Armagan2208b602019-07-31 16:36:27 +01002253 DataType::Float32,
2254 DataType::Float16
2255 };
2256
2257 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002258}
2259
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002260void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2261{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002262 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002263
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002264 ValidateNumInputs(workloadInfo, descriptorName, 2);
2265 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002266
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002267 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2268 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2269 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002270
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002271 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2272 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2273
2274 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2275 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002276}
2277
Sadik Armaganeff363d2019-04-05 15:25:46 +01002278void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2279{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002280 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002281
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002282 ValidateNumInputs(workloadInfo, descriptorName, 2);
2283 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2284
2285 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2286 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2287
2288 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2289 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2290
2291 std::vector<DataType> supportedTypes =
2292 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002293 DataType::Float32,
2294 DataType::QuantisedAsymm8,
2295 DataType::QuantisedSymm16
2296 };
2297
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2299 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002300
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002301 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2302 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002303
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002304 ValidateTensorShapesMatch(inputTensorInfo0,
2305 outputTensorInfo0,
2306 descriptorName,
2307 "input_0",
2308 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002309
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002310 ValidateTensorShapesMatch(inputTensorInfo0,
2311 outputTensorInfo1,
2312 descriptorName,
2313 "input_0",
2314 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002315}
2316
Matteo Martincigh49124022019-01-11 13:25:59 +00002317void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2318{
2319 // This is internally generated so it should not need validation.
2320}
2321
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002322void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2323{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002324 const std::string& descriptorName{"PreluQueueDescriptor"};
2325
2326 ValidateNumInputs(workloadInfo, descriptorName, 2);
2327 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2328
2329 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2330 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2331 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002332
2333 std::vector<DataType> supportedTypes
2334 {
2335 DataType::Float16,
2336 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002337 DataType::QuantisedAsymm8,
2338 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002339 };
2340
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002341 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2342 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002344 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002345
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2347 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002348
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002349 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2350 alphaTensorInfo,
2351 outputTensorInfo,
2352 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002353 "input",
2354 "alpha");
2355}
2356
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002357void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2358{
2359 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2360
2361 ValidateNumInputs(workloadInfo, descriptorName, 1);
2362 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2363
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002364 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2365 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2366
2367 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2368 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002369
2370 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002371
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002372 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2373 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2374 ValidateTensorDataType(weightTensorInfo, inputTensorInfo.GetDataType(), descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002375
2376 if (m_Parameters.m_BiasEnabled)
2377 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002378 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002379
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002380 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
2381 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002382
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002383 ValidateTensorDataType(biasTensorInfo,
2384 GetBiasDataType(inputTensorInfo.GetDataType()),
2385 descriptorName,
2386 "bias");
2387
2388 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002389 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002390}
2391
James Conroy9c3cae82019-08-01 16:01:48 +01002392void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2393{
2394 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2395
2396 // Validate number of inputs/outputs
2397 ValidateNumInputs(workloadInfo, descriptorName, 3);
2398 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2399
2400 // Input/output tensor infos
2401 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2402 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2403 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2404
2405 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2406 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2407
2408 std::vector<DataType> inputOutputSupportedTypes =
2409 {
2410 DataType::QuantisedAsymm8
2411 };
2412
2413 std::vector<DataType> cellStateSupportedTypes =
2414 {
2415 DataType::QuantisedSymm16
2416 };
2417
2418 std::vector<DataType> weightsSupportedTypes =
2419 {
2420 DataType::QuantisedAsymm8
2421 };
2422
2423 std::vector<DataType> biasSupportedTypes =
2424 {
2425 DataType::Signed32
2426 };
2427
2428 // Validate types of input/output tensors
2429 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2430 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2431 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2432
2433 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2434 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2435
2436 // Validate matching types of input/output tensors
2437 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2438 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2439 "outputStateIn", "outputStateOut");
2440 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2441
2442 // Validate matching quantization info for input/output tensors
2443 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2444 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2445 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002446
James Conroy9c3cae82019-08-01 16:01:48 +01002447 // Infer number of batches, input size and output size from tensor dimensions
2448 const uint32_t numBatches = inputInfo.GetShape()[0];
2449 const uint32_t inputSize = inputInfo.GetShape()[1];
2450 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2451
2452 // Validate number of dimensions and number of elements for input/output tensors
2453 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2454 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2455 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2456 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2457 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2458
2459 // Validate number of dimensions and number of elements for weights tensors
2460 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2461 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2462 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2463
2464 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2465 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2466 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2467
2468 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2469 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2470 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2471
2472 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2473 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2474 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2475
2476 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2477 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2478 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2479
2480 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2481 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2482 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2483 " RecurrentToForgetWeights");
2484
2485 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2486 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2487 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2488
2489 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2490 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2491 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2492
2493 // Validate data types for weights tensors (all should match each other)
2494 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2495
2496 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2497 "inputToInputWeights", "inputToForgetWeights");
2498 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2499 "inputToInputWeights", "inputToCellWeights");
2500 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2501 "inputToInputWeights", "inputToOutputWeights");
2502
2503 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2504 "inputToInputWeights", "recurrentToInputWeights");
2505 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2506 "inputToInputWeights", "recurrentToForgeteights");
2507 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2508 "inputToInputWeights", "recurrentToCellWeights");
2509 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2510 "inputToInputWeights", "recurrentToOutputWeights");
2511
2512 // Validate matching quantization info for weight tensors (all should match each other)
2513 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2514 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2515 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2516 descriptorName, "inputToInputWeights", "inputToCellWeights");
2517 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2518 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2519
2520 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2521 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2522 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2523 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2524 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2525 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2526 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2527 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2528
2529 // Validate number of dimensions and number of elements in bias tensors
2530 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2531 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2532 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2533
2534 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2535 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2536 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2537
2538 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2539 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2540 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2541
2542 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2543 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2544 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2545
2546 // Validate data types for bias tensors (all should match each other)
2547 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2548
2549 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2550 "inputGateBias", "forgetGateBias");
2551 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2552 "inputGateBias", "cellBias");
2553 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2554 "inputGateBias", "outputGateBias");
2555
2556 // Validate bias tensor quantization info
2557 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2558 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2559 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2560 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2561}
2562
Kevin May868eb142019-09-04 17:29:31 +01002563void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2564{
2565 const std::string descriptorName{"AbsQueueDescriptor"};
2566
2567 ValidateNumInputs(workloadInfo, descriptorName, 1);
2568 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2569
2570 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2571 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2572
2573 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2574
2575 std::vector<DataType> supportedTypes =
2576 {
2577 DataType::Float16,
2578 DataType::Float32,
2579 DataType::QuantisedAsymm8,
2580 DataType::QuantisedSymm16
2581 };
2582
2583 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2584 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2585}
2586
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002587void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2588{
2589 const std::string descriptorName{"SliceQueueDescriptor"};
2590
2591 ValidateNumInputs(workloadInfo, descriptorName, 1);
2592 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2593
2594 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2595 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2596
2597 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2598
2599 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2600 if (rank > 4)
2601 {
2602 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2603 }
2604
2605 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2606
2607 // Check if m_Begin and m_Size have the expected length
2608 if (m_Parameters.m_Begin.size() != rank)
2609 {
2610 throw InvalidArgumentException(descriptorName +
2611 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2612 }
2613 if (m_Parameters.m_Size.size() != rank)
2614 {
2615 throw InvalidArgumentException(descriptorName +
2616 ": Length of size descriptor must equal rank " + std::to_string(rank));
2617 }
2618
2619 // Check if the shape of the output tensor matches m_Size
2620 const TensorShape& outputShape = outputTensorInfo.GetShape();
2621 for (unsigned int i = 0u; i < rank; ++i)
2622 {
2623 if (m_Parameters.m_Size[i] != outputShape[i])
2624 {
2625 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2626 }
2627 }
2628
2629 // Check if the sum of begin offset and size in a given dimension
2630 // does not exceed the size of corresponding input
2631 const TensorShape& inputShape = inputTensorInfo.GetShape();
2632 for(unsigned int i = 0u; i < rank; ++i)
2633 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002634 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002635 {
2636 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2637 std::to_string(i) + " exceeds input size.");
2638 }
2639 }
2640}
2641
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002642} // namespace armnn