blob: 1c607da707f4327bfe168a86cfbd894190dbdee3 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
33 case DataType::QuantisedAsymm8:
34 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010035 case DataType::QuantisedSymm16:
36 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000037 default:
38 BOOST_ASSERT_MSG(false, "Invalid input data type");
39 return DataType::Float32;
40 }
41}
42
43namespace
44{
45
//---------------------------------------------------------------
// Stream-based replacement for std::to_string, which the Android NDK does not provide.
// Formats `value` exactly as `operator<<` on a default-configured output stream would.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream buffer;
    buffer << value;
    std::string text = buffer.str();
    return text;
}
55
56//---------------------------------------------------------------
57void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
58{
59 if (!ptr)
60 {
61 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
62 paramName + " parameter must be set.");
63 }
64}
65
66//---------------------------------------------------------------
67void ValidateTensorShapesMatch(const TensorInfo& first,
68 const TensorInfo& second,
69 std::string const& descName,
70 std::string const& firstName,
71 std::string const& secondName)
72{
73 if (first.GetShape() != second.GetShape())
74 {
75 throw InvalidArgumentException(descName + ": "
76 + firstName + " & " + secondName + " must have identical shapes");
77 }
78}
79
80//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010081void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000082{
Sadik Armaganeff363d2019-04-05 15:25:46 +010083 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000084 {
85 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010086 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000087 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
88 }
89}
90
91//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010092void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000093{
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000095 {
96 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010097 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000098 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
99 }
100}
101
102//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100103void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000104 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100105 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000106 std::string const& tensorName)
107{
108 if (tensor.GetNumDimensions() != numDimensions)
109 {
110 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
111 to_string(tensor.GetNumDimensions()) + " dimensions for " +
112 tensorName + " tensor.");
113 }
114}
115
116//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100117void ValidateTensorNumElements(const TensorInfo& tensor,
118 std::string const& descName,
119 unsigned int numElements,
120 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100121{
122 if (tensor.GetNumElements() != numElements)
123 {
124 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
125 to_string(tensor.GetNumDimensions()) + " elements for " +
126 tensorName + " tensor.");
127 }
128}
129
130//---------------------------------------------------------------
131void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100132 unsigned int numDimension,
133 unsigned int numElements,
134 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100135{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100136 const std::string functionName{"ValidateTensorNumDimNumElem"};
137 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
138 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100139}
140
141//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000142void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
143 const std::string& descName, std::string const& tensorName)
144{
145 if (tensor.GetDataType() != dataType)
146 {
147 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
148 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
149 }
150}
151
152//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100153void ValidateTensorQuantizationSpace(const TensorInfo& first,
154 const TensorInfo& second,
155 const std::string& descName,
156 std::string const& firstName,
157 std::string const& secondName)
158{
159 if (!first.IsQuantized() ||
160 !second.IsQuantized())
161 {
162 // Not a quantized type, ignore the validation
163 return;
164 }
165
166 DataType firstDataType = first.GetDataType();
167 DataType secondDataType = second.GetDataType();
168
169 if (firstDataType != secondDataType)
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must be of the same quantized type, " +
173 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
174 secondName + " is " + GetDataTypeName(secondDataType));
175 }
176
177 if (!first.IsTypeSpaceMatch(second))
178 {
179 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
180 " must have the same quantization space, " +
181 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
182 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
183 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
184 " and scale " + to_string(second.GetQuantizationScale()));
185 }
186}
187
188//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100189void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
190 const TensorInfo& inputTensorInfo,
191 const TensorInfo& weightsTensorInfo,
192 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000193{
194 if (biasTensor.GetQuantizationOffset() != 0)
195 {
196 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
197 to_string(biasTensor.GetQuantizationOffset()));
198 }
199 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
kevmay016c46dd32018-12-17 15:32:45 +0000200 if (std::abs(biasTensor.GetQuantizationScale() - expectedScale) > 0.00000001f)
telsoa014fcda012018-03-09 14:13:49 +0000201 {
202 // Print the float values with extra precision to see very small differences
203 std::stringstream msg;
204 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
205 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
206 biasTensor.GetQuantizationScale();
207 throw InvalidArgumentException(msg.str());
208 }
209}
210
211//---------------------------------------------------------------
212void ValidateTensors(const std::vector<ITensorHandle*>& vec,
213 unsigned int numExpected,
214 const std::string& descName,
215 const std::string& varName)
216{
217 if (vec.empty() && numExpected > 0)
218 {
219 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
220 }
221
222 for (unsigned int i = 0; i < numExpected; ++i)
223 {
224 if (!vec[i])
225 {
226 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
227 }
228 }
229}
230
231//---------------------------------------------------------------
232void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
233 const TensorInfo& second,
234 const TensorInfo& output,
235 std::string const& descName,
236 std::string const& firstName,
237 std::string const& secondName)
238{
239 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
240 // broadcasted.
241 if (first.GetNumDimensions() != second.GetNumDimensions())
242 {
243 throw InvalidArgumentException(descName + ": Tensors "
244 + firstName + " & " + secondName
245 + " must have the same number of dimensions in order to be broadcasted");
246 }
247 uint32_t numDims = first.GetNumDimensions();
248 std::vector<uint32_t> outputDims(numDims, 0u);
249 for (uint32_t i = 0; i < numDims; i++)
250 {
251 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
252 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
253 if (dimsNotEqual && dimsNotOne)
254 {
255 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
256 }
257 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
258 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100259 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000260 if (broadcastShape != output.GetShape())
261 {
262 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
263 + firstName + " & " + secondName
264 + " does not match the output shape");
265 }
266}
267
268//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100269void ValidateDataTypes(const TensorInfo& info,
270 const std::vector<armnn::DataType>& supportedTypes,
271 std::string const& descName)
272{
273 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
274 if (iterator == supportedTypes.end())
275 {
276 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
277 }
278}
279
James Conroy4d1ff582019-06-10 17:06:39 +0100280//---------------------------------------------------------------
281void ValidateTensorDataTypesMatch(const TensorInfo& first,
282 const TensorInfo& second,
283 std::string const& descName,
284 std::string const& firstName,
285 std::string const& secondName)
286{
287 if (first.GetDataType() != second.GetDataType())
288 {
289 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
290 " must have identical data types.");
291 }
292}
293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100294//---------------------------------------------------------------
295void ValidateTensorNumElementsMatch(const TensorInfo& first,
296 const TensorInfo& second,
297 std::string const& descName,
298 std::string const& firstName,
299 std::string const& secondName)
300{
301 if (first.GetNumElements() != second.GetNumElements())
302 {
303 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
304 " must have the same number of elements.");
305 }
306}
307
308} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000309
310void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
311 unsigned int numExpectedIn, unsigned int numExpectedOut) const
312{
313 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
314 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
315}
316
317//---------------------------------------------------------------
318void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
319{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100320 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100322 ValidateNumInputs(workloadInfo, descriptorName, 1);
323 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100325 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
326 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
327
328 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
329 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000330
331 if (m_Inputs.size() != m_Outputs.size())
332 {
333 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100334 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
335 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000336 }
337
338 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
339 {
340 if (!m_Inputs[i])
341 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100342 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
343 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000344 }
345
346 if (!m_Outputs[i])
347 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100348 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
349 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000350 }
351 }
352}
353
Derek Lambertif674aa02019-08-01 15:56:25 +0100354//---------------------------------------------------------------
355void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
356{
357 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
358 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
359
360 if (workloadInfo.m_InputTensorInfos.size() != 1)
361 {
362 throw InvalidArgumentException(boost::str(
363 boost::format("Number of input infos (%1%) is not 1.")
364 % workloadInfo.m_InputTensorInfos.size()));
365
366 }
367
368 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
369 {
370 throw InvalidArgumentException(boost::str(
371 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
372 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
373 }
374
375 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
376 {
377 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
378 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
379 {
380 throw InvalidArgumentException(boost::str(
381 boost::format("Number of elements for tensor input and output %1% does not match")
382 % i ));
383 }
384 }
385
386 if (m_Inputs.size() != 1)
387 {
388 throw InvalidArgumentException(boost::str(
389 boost::format("Number of inputs (%1%) is not 1.")
390 % m_Inputs.size()));
391 }
392
393 if (m_Inputs.size() != m_Outputs.size())
394 {
395 throw InvalidArgumentException(boost::str(
396 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
397 % m_Inputs.size() % m_Outputs.size()));
398 }
399
400 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
401 {
402 if (!m_Inputs[i])
403 {
404 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
405 }
406
407 if (!m_Outputs[i])
408 {
409 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
410 }
411 }
412}
413
414//---------------------------------------------------------------
415void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
416{
417 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
418 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
419
420 if (workloadInfo.m_InputTensorInfos.size() != 1)
421 {
422 throw InvalidArgumentException(boost::str(
423 boost::format("Number of input infos (%1%) is not 1.")
424 % workloadInfo.m_InputTensorInfos.size()));
425
426 }
427
428 if (workloadInfo.m_OutputTensorInfos.size() != 0)
429 {
430 throw InvalidArgumentException(boost::str(
431 boost::format("Number of output infos (%1%) is not 0.")
432 % workloadInfo.m_InputTensorInfos.size()));
433
434 }
435
436 if (m_Inputs.size() != 1)
437 {
438 throw InvalidArgumentException(boost::str(
439 boost::format("Number of inputs (%1%) is not 1.")
440 % m_Inputs.size()));
441 }
442
443 if (m_Outputs.size() != 0)
444 {
445 throw InvalidArgumentException(boost::str(
446 boost::format("Number of outputs (%1%) is not 0.")
447 % m_Inputs.size() % m_Outputs.size()));
448 }
449
450 if (!m_Inputs[0])
451 {
452 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
453 }
454}
455
456//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000457void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
458{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100459 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100460
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100461 ValidateNumInputs(workloadInfo, descriptorName, 1);
462 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100463
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100464 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
465 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100466
467 std::vector<DataType> supportedTypes =
468 {
469 DataType::Float16,
470 DataType::Float32,
471 DataType::QuantisedAsymm8,
472 DataType::QuantisedSymm16
473 };
474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100475 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
476 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
477 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000478}
479
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100480void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
481{
482 const std::string descriptorName{"SoftmaxQueueDescriptor"};
483
484 ValidateNumInputs(workloadInfo, descriptorName, 1);
485 ValidateNumOutputs(workloadInfo, descriptorName, 1);
486
487 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
488 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
489
490 std::vector<DataType> supportedTypes =
491 {
492 DataType::Float16,
493 DataType::Float32,
494 DataType::QuantisedAsymm8,
495 DataType::QuantisedSymm16
496 };
497
498 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
499 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
500 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
501}
502
telsoa014fcda012018-03-09 14:13:49 +0000503void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
504{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100505 const std::string descriptorName{"SplitterQueueDescriptor"};
506
507 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000508
Ruomei Yan25339c32019-05-28 16:48:20 +0100509 // Check the supported data types
510 std::vector<DataType> supportedTypes =
511 {
512 DataType::Float32,
513 DataType::Float16,
514 DataType::Boolean,
515 DataType::Signed32,
516 DataType::QuantisedAsymm8,
517 DataType::QuantisedSymm16
518 };
519
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100520 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
521 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100522 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100523 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
524 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
525
526 const std::string outputName = "output_" + std::to_string(i);
527 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100528 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100529
telsoa014fcda012018-03-09 14:13:49 +0000530 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
531 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100532 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000533 }
534
535 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
536 {
537 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100538 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000539 "has to match number of workloadInfo.m_OutputTensorInfos. "
540 "Number of windows: " +
541 to_string(m_ViewOrigins.size()) +
542 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
543 }
544
telsoa01c577f2c2018-08-31 09:22:23 +0100545 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000546 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
547 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
548 {
telsoa01c577f2c2018-08-31 09:22:23 +0100549 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000550 ViewOrigin const& e = m_ViewOrigins[w];
551 if (e.m_Origin.size() != inputDims)
552 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100553 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000554 "have the same dimensionality as the input tensor. "
555 "Window origin (index: " +
556 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
557 " dimensions, the input "
558 "tensor has " +
559 to_string(inputDims) + " dimensions.");
560 }
561 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
562 {
563 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
564 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
565 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100566 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000567 "be smaller or equal than the size of the input in that coord.");
568 }
569 }
570 }
571}
572
Jim Flynne242f2d2019-05-22 14:24:13 +0100573void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000574{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100575 const std::string descriptorName{"ConcatQueueDescriptor"};
576
577 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000578
579 if (m_Inputs.size() <= 0)
580 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100581 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000582 }
583 if (m_Outputs.size() <= 0)
584 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100585 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000586 }
587
588 if (workloadInfo.m_InputTensorInfos.size() <= 0)
589 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100590 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000591 }
592 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
593 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100594 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000595 }
596
Nikhil Raj8599a412018-11-19 14:51:07 +0000597 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
598 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100599 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000600 }
601
602 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
603 {
604 return;
605 }
606
telsoa014fcda012018-03-09 14:13:49 +0000607 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
608 {
609 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100610 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000611 "has to match number of workloadInfo.m_InputTensorInfos. "
612 "Number of windows: " +
613 to_string(m_ViewOrigins.size()) +
614 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
615 }
616
telsoa01c577f2c2018-08-31 09:22:23 +0100617 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000618 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
619 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
620 {
telsoa01c577f2c2018-08-31 09:22:23 +0100621 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000622 ViewOrigin const& e = m_ViewOrigins[w];
623 if (e.m_Origin.size() != outputDims)
624 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100625 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000626 "have the same dimensionality as the output tensor. "
627 "Window origin (index: " +
628 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
629 " dimensions, the output "
630 "tensor has " +
631 to_string(outputDims) + " dimensions.");
632 }
telsoa01c577f2c2018-08-31 09:22:23 +0100633 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000634 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
635 {
636 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
637 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
638 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100639 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000640 "be smaller or equal than the size of the output in that coord.");
641 }
642 }
643 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100644
645 // Check the supported data types
646 std::vector<DataType> supportedTypes =
647 {
648 DataType::Float32,
649 DataType::Float16,
650 DataType::Boolean,
651 DataType::Signed32,
652 DataType::QuantisedAsymm8,
653 DataType::QuantisedSymm16
654 };
655
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100656 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
657 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100658 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100659 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
660 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
661
662 const std::string inputName = "input_" + std::to_string(i);
663 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100664 }
telsoa014fcda012018-03-09 14:13:49 +0000665}
666
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100667void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
668{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100669 const std::string descriptorName{"StackQueueDescriptor"};
670
671 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100672
673 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
674 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100675 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100676 }
677
678 // All inputs must have the same shape, which is defined in parameters
679 const TensorShape& inputShape = m_Parameters.m_InputShape;
680 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
681 {
682 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
683 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100684 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100685 }
686 }
687
688 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
689 // since the output tensor has an additional dimension.
690 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
691 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100692 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100693 "than the number of input dimensions.");
694 }
695
696 // Output shape must be as inferred from the input shape
697 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
698 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
699 {
700 if (outputShape[i] != inputShape[i])
701 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100702 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100703 "match shape inferred from input tensor.");
704 }
705 }
706
707 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
708 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100709 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100710 "match shape inferred from input tensor.");
711 }
712
713 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
714 {
715 if (outputShape[i] != inputShape[i-1])
716 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100717 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100718 "match shape inferred from input tensor.");
719 }
720 }
721
722 // Check the supported data types
723 std::vector<DataType> supportedTypes =
724 {
725 DataType::Float32,
726 DataType::Float16,
727 DataType::Boolean,
728 DataType::Signed32,
729 DataType::QuantisedAsymm8,
730 DataType::QuantisedSymm16
731 };
732
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100733 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100734
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100735 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100736 {
737 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
738 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100739 descriptorName,
740 "input_0",
741 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100742 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100743
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100744 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
745 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100746 descriptorName,
747 "input_0",
748 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100749}
750
telsoa014fcda012018-03-09 14:13:49 +0000751void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
752{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100753 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000754
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 ValidateNumInputs(workloadInfo, descriptorName, 1);
756 ValidateNumOutputs(workloadInfo, descriptorName, 1);
757
758 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
759 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
760
761 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
762
763 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000764 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100765 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000766 }
767
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000769
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100770 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
771 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000772
773 if (m_Parameters.m_BiasEnabled)
774 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100775 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000776
telsoa01c577f2c2018-08-31 09:22:23 +0100777 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100778 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
779 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000780
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100781 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
782 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000783 }
784
Francis Murtagh46c09d02019-05-28 08:15:28 +0100785 // Check the supported data types
786 std::vector<DataType> supportedTypes =
787 {
788 DataType::Float32,
789 DataType::Float16,
790 DataType::QuantisedAsymm8,
791 DataType::QuantisedSymm16
792 };
793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100794 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
795 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000796}
797
telsoa014fcda012018-03-09 14:13:49 +0000798void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
799{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100800 const std::string descriptorName{"NormalizationQueueDescriptor"};
801
802 ValidateNumInputs(workloadInfo, descriptorName, 1);
803 ValidateNumOutputs(workloadInfo, descriptorName, 1);
804
805 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
806 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100807
808 // Check the supported data types
809 std::vector<DataType> supportedTypes =
810 {
811 DataType::Float16,
812 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100813 DataType::QuantisedAsymm8,
814 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100815 };
816
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100817 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100818
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100819 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100820
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100821 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000822}
823
824void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
825{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100826 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000827
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100828 ValidateNumInputs(workloadInfo, descriptorName, 2);
829 ValidateNumOutputs(workloadInfo, descriptorName, 1);
830
831 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
832 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
833 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
834
835 std::vector<DataType> supportedTypes =
836 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100837 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100838 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100839 DataType::QuantisedSymm16,
840 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100841 };
842
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
844 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
845 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100846
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100847 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
848 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100849
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100850 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
851 inputTensorInfo1,
852 outputTensorInfo,
853 descriptorName,
854 "input_0",
855 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000856}
857
telsoa014fcda012018-03-09 14:13:49 +0000858void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
859{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100860 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +0100861
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100862 ValidateNumInputs(workloadInfo, descriptorName, 2);
863 ValidateNumOutputs(workloadInfo, descriptorName, 1);
864
865 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
866 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
867 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
868
869 std::vector<DataType> supportedTypes =
870 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100871 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100872 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100873 DataType::QuantisedSymm16,
874 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100875 };
876
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
878 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
879 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100880
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100881 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
882 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100883
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100884 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
885 inputTensorInfo1,
886 outputTensorInfo,
887 descriptorName,
888 "input_0",
889 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000890}
891
892void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
893{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100894 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100895
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100896 ValidateNumInputs(workloadInfo, descriptorName, 1);
897 ValidateNumOutputs(workloadInfo, descriptorName, 1);
898
899 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
900 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100901
902 std::vector<DataType> supportedTypes =
903 {
904 DataType::Float16,
905 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100906 DataType::QuantisedAsymm8,
907 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100908 };
909
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100910 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
911 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100912
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100913 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
914 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
915 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100916
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100917 ValidatePointer(m_Mean, descriptorName, "mean");
918 ValidatePointer(m_Variance, descriptorName, "variance");
919 ValidatePointer(m_Beta, descriptorName, "beta");
920 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000921
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100922 const TensorInfo& mean = m_Mean->GetTensorInfo();
923 const TensorInfo& variance = m_Variance->GetTensorInfo();
924 const TensorInfo& beta = m_Beta->GetTensorInfo();
925 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000926
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100927 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
928 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
929 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
930 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100932 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
933 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
934 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000935}
936
937void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
938{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100939 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000940
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100941 ValidateNumInputs(workloadInfo, descriptorName, 1);
942 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000943
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100944 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
945 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +0000946
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100947 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
948 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +0000949
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100952 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
953 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000954
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100955 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
telsoa014fcda012018-03-09 14:13:49 +0000956
957 if (m_Parameters.m_BiasEnabled)
958 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000960
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
962 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
963
964 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
965 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000966 }
967
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100968 std::vector<DataType> supportedTypes =
969 {
Ruomei Yan88d44b82019-05-23 14:29:06 +0100970 DataType::Float32,
971 DataType::QuantisedAsymm8,
972 DataType::QuantisedSymm16,
973 DataType::Float16
974 };
975
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100976 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
977 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
978}
Ruomei Yan88d44b82019-05-23 14:29:06 +0100979
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100980void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
981{
982 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
983
984 ValidateNumInputs(workloadInfo, descriptorName, 1);
985 ValidateNumOutputs(workloadInfo, descriptorName, 1);
986
987 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
988 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
989
990 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
991 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
992
993 ValidatePointer(m_Weight, descriptorName, "weight");
994
995 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
996 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
997
998 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
999 {
1000 throw InvalidArgumentException(
1001 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1002 "cannot be smaller than 1.") % descriptorName %
1003 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1004 }
1005
1006 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1007
1008 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1009 // inputChannels * channelMultiplier should be equal to outputChannels.
1010 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1011 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1012 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1013 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1014 {
1015 throw InvalidArgumentException(
1016 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1017 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1018 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1019 numWeightInputChannels % numWeightChannelMultiplier));
1020 }
1021
1022 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
1023
1024 if (m_Parameters.m_BiasEnabled)
1025 {
1026 ValidatePointer(m_Bias, descriptorName, "bias");
1027
1028 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1029 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1030
1031 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1032 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1033 }
1034
1035 std::vector<DataType> supportedTypes =
1036 {
1037 DataType::Float32,
1038 DataType::QuantisedAsymm8,
1039 DataType::QuantisedSymm16,
1040 DataType::Float16
1041 };
1042
1043 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1044 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001045}
1046
1047void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1048{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001049 const std::string descriptorName{"PermuteQueueDescriptor"};
1050
1051 ValidateNumInputs(workloadInfo, descriptorName, 1);
1052 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001053
1054 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1055
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001056 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1057 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001058
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001059 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1060 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001061
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001062 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001063 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001065 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1067 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1068 "must match dst dimension " + to_string(mapping[i]) +
1069 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001070 }
1071 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001072
1073 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001074}
1075
1076void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1077{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001078 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001080 ValidateNumInputs(workloadInfo, descriptorName, 1);
1081 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1082
1083 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1084 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1085
1086 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1087 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001088
1089 std::vector<DataType> supportedTypes =
1090 {
1091 DataType::Float32,
1092 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001093 DataType::QuantisedAsymm8,
1094 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001095 };
1096
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001097 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1098 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001099}
1100
1101void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1102{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001103 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001104
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001105 ValidateNumInputs(workloadInfo, descriptorName, 1);
1106 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1107
1108 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1109 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1110
1111 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1112 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001113
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001114 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001115 {
1116 DataType::Float16,
1117 DataType::Float32,
1118 DataType::QuantisedAsymm8,
1119 DataType::QuantisedSymm16
1120 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001121
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001122 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1123 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001125 // ResizeBilinear only changes width and height: batch and channel count must match.
1126 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1127 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001128 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001129 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001130 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001131 boost::str(boost::format("%1%: Input batch size (%2%) "
1132 "does not match output batch size (%3%)") %
1133 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001134 }
1135
Teresa Charlin970f43b2019-07-01 13:51:07 +01001136 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001137 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1138 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001139 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001140 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001141 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001142 boost::str(boost::format("%1%: Input channel count (%2%) "
1143 "does not match output channel count (%3%)") %
1144 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001145 }
1146}
1147
1148void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1149{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001151
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001152 ValidateNumInputs(workloadInfo, descriptorName, 1);
1153 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1154
1155 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1156 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1157
1158 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1159 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001160
1161 std::vector<DataType> supportedTypes =
1162 {
1163 DataType::Float16,
1164 DataType::Float32,
1165 DataType::QuantisedAsymm8,
1166 DataType::QuantisedSymm16
1167 };
1168
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001169 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1170 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 // Resize only changes width and height: batch and channel count must match.
1173 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1174 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001175 if (inputBatchSize != outputBatchSize)
1176 {
1177 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001178 boost::str(boost::format("%1%: Input batch size (%2%) "
1179 "does not match output batch size (%3%)") %
1180 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001181 }
1182
1183 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001184 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1185 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001186 if (inputChannelCount != outputChannelCount)
1187 {
1188 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001189 boost::str(boost::format("%1%: Input channel count (%2%) "
1190 "does not match output channel count (%3%)") %
1191 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001192 }
1193}
1194
1195void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1196{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001197 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001198
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001199 ValidateNumInputs(workloadInfo, descriptorName, 1);
1200 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1201
1202 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1203 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1204
1205 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1206 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1207
1208 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1209
telsoa014fcda012018-03-09 14:13:49 +00001210 if (m_Parameters.m_Min > m_Parameters.m_Max)
1211 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001212 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001213 }
telsoa014fcda012018-03-09 14:13:49 +00001214}
1215
1216void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1217{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001218 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001219
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001220 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001221 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1222
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001223 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1224 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1225
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001226 if (inputTensorInfo.GetNumDimensions() > 4)
1227 {
1228 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1229 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001230
1231 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001232
1233 // Check the supported data types
1234 std::vector<DataType> supportedTypes =
1235 {
1236 DataType::Float32,
1237 DataType::Float16,
1238 DataType::QuantisedAsymm8,
1239 DataType::QuantisedSymm16
1240 };
1241
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001242 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1243 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1244
1245 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001246}
1247
1248void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1249{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250 const std::string descriptorName{"ConstantQueueDescriptor"};
1251
1252 ValidateNumInputs(workloadInfo, descriptorName, 0);
1253 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001254
1255 if (!m_LayerOutput)
1256 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001258 }
1259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1261 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001262
1263 // Check the supported data types
1264 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001265 {
1266 DataType::Float32,
1267 DataType::Float16,
1268 DataType::Signed32,
1269 DataType::QuantisedAsymm8,
1270 DataType::QuantisedSymm16
1271 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001272
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001273 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001274}
1275
1276void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1277{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001278 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001279
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001280 ValidateNumInputs(workloadInfo, descriptorName, 1);
1281 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1282
1283 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1284 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1285
1286 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001287
1288 // Check the supported data types
1289 std::vector<DataType> supportedTypes =
1290 {
1291 DataType::Float32,
1292 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001293 DataType::QuantisedAsymm8,
1294 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001295 };
1296
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1298 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001299}
1300
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001301void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1302{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001303 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001304
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001305 ValidateNumInputs(workloadInfo, descriptorName, 1);
1306 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1307
1308 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1309 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1310
1311 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1312 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001313
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001314 if (m_Parameters.m_BlockShape.size() != 2)
1315 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001316 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001317 }
1318
1319 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1320 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001321 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1322 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001323 }
1324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001325 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001326
1327 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001328 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001329
Matthew Bentham8800c002018-11-19 13:19:28 +00001330 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001331
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001332 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1333 widthPad.first + widthPad.second;
1334 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1335 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001336
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001337 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1338 inputShape[dimensionIndices.GetChannelsIndex()];
1339 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001340
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001341 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001342 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001343 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001344 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001345 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001346 }
1347
1348 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001349 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001350 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1351 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001352 }
nikraj01120522a2019-05-31 11:33:07 +01001353
1354 std::vector<DataType> supportedTypes =
1355 {
1356 DataType::Float16,
1357 DataType::Float32,
1358 DataType::QuantisedAsymm8,
1359 DataType::QuantisedSymm16
1360 };
1361
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001362 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1363 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001364}
1365
Keith Davisa57eccb2019-06-14 17:33:22 +01001366void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1367{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001368 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001369
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001370 ValidateNumInputs(workloadInfo, descriptorName, 1);
1371 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001372
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001373 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1374 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1375
1376 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1377 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001378
1379 std::vector<DataType> supportedTypes =
1380 {
1381 DataType::Float32,
1382 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001383 DataType::QuantisedAsymm8,
1384 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001385 };
1386
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001387 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1388 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001390 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1391 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1392 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1393 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001394
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001395 const TensorShape& inputShape = inputTensorInfo.GetShape();
Keith Davisa57eccb2019-06-14 17:33:22 +01001396
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001397 const unsigned int numInputElements =
1398 inputShape[0] * inputShape[wIndex] * inputShape[hIndex] * inputShape[cIndex];
1399 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
1400
1401 if (numOutputElements != numInputElements)
Keith Davisa57eccb2019-06-14 17:33:22 +01001402 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001403 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1404 std::to_string(numInputElements) + " but output tensor has " +
1405 std::to_string(numOutputElements) + " elements.");
Keith Davisa57eccb2019-06-14 17:33:22 +01001406 }
1407
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001408 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001409 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001410 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1411 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001412 }
1413}
1414
telsoa014fcda012018-03-09 14:13:49 +00001415void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1416{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001417 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001418
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001419 ValidateNumInputs(workloadInfo, descriptorName, 1);
1420 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1421
1422 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1423 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001424
1425 std::vector<DataType> supportedTypes =
1426 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001427 DataType::Float32,
1428 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001429 };
1430
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001431 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001432
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001433 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001434 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001435 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001436 }
1437}
1438
telsoa01c577f2c2018-08-31 09:22:23 +01001439void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1440{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001441 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1442
1443 const std::string descriptorName{"LstmQueueDescriptor"};
1444
1445 // check dimensions of all inputs and outputs
1446 if (workloadInfo.m_InputTensorInfos.size() != 3)
1447 {
1448 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1449 }
1450 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1451 {
1452 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1453 }
1454
1455 std::vector<DataType> supportedTypes =
1456 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001457 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001458 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001459 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001460 };
1461
Jan Eilers38e05bd2019-06-26 13:10:09 +01001462 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001463 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1464
Jan Eilers38e05bd2019-06-26 13:10:09 +01001465 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001466 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001467 {
1468 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1469 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001470 descriptorName,
1471 "input_0",
1472 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001473 }
1474 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001475 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001476 {
1477 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1478 workloadInfo.m_OutputTensorInfos[i],
1479 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001480 "input_0",
1481 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001482 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001483
Jan Eilers38e05bd2019-06-26 13:10:09 +01001484 // TODO: check clipping parameter is valid
1485
1486 // Inferring batch size, number of outputs and number of cells from the inputs.
1487 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1488 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1489 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1490 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1491 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1492 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1493 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1494
Jan Eilers38e05bd2019-06-26 13:10:09 +01001495 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1497 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001498 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1500 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001501 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001502 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1503 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001504 // scratchBufferTensor
1505 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1507 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001508 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001509 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1510 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001511 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001512 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1513 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001514 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001515 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1516 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001517
1518
1519 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1520 if ( m_InputToInputWeights )
1521 {
1522 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1523 (n_cell * n_input), "InputLayerNormWeights");
1524 }
1525
1526 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1527 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1528 (n_cell * n_input), "InputToForgetWeights");
1529
1530 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1531 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1532 (n_cell * n_input), "InputToCellWeights");
1533
1534 if ( m_RecurrentToInputWeights )
1535 {
1536 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1537 (n_cell * n_output), "RecurrentToInputWeights");
1538 }
1539
1540 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1541 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1542 (n_cell * n_output), "RecurrentToForgetWeights");
1543
1544 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1545 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1546 (n_cell * n_output), "RecurrentToCellWeights");
1547
1548 // Make sure the input-gate's parameters are either both present (regular
1549 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1550 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1551 !m_Parameters.m_CifgEnabled) ||
1552 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1553 m_Parameters.m_CifgEnabled));
1554 if (!cifg_weights_all_or_none)
1555 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001556 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1557 "RecurrentToInputWeights must either both be present (regular LSTM) "
1558 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1559 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001560 }
1561
1562 if ( m_CellToInputWeights )
1563 {
1564 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1565 n_cell, "CellToInputWeights");
1566 }
1567 if ( m_CellToForgetWeights )
1568 {
1569 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1570 n_cell, "CellToForgetWeights");
1571 }
1572 if ( m_CellToOutputWeights )
1573 {
1574 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1575 n_cell, "CellToOutputWeights");
1576 }
1577
1578 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1579 bool peephole_weights_all_or_none =
1580 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1581 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1582 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1583 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1584 if (!peephole_weights_all_or_none)
1585 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001586 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001587 }
1588
1589 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1590 if (m_Parameters.m_CifgEnabled)
1591 {
1592 if (m_InputGateBias)
1593 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001594 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001595 }
1596 }
1597 else
1598 {
1599 if (!m_InputGateBias)
1600 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001601 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1602 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001603 }
1604 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1605 n_cell, "InputGateBias");
1606 }
1607
1608 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1609 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1610
1611 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1612 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1613
1614 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1615 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1616
1617 if (m_ProjectionWeights)
1618 {
1619 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1620 (n_cell * n_output), "ProjectionWeights");
1621 }
1622 if (m_ProjectionBias)
1623 {
1624 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1625 }
1626
1627 // Making sure the projection tensors are consistent:
1628 // 1) If projection weight is not present, then projection bias should not be
1629 // present.
1630 // 2) If projection weight is present, then projection bias is optional.
1631 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1632 !m_Parameters.m_ProjectionEnabled)
1633 || (m_ProjectionWeights && !m_ProjectionBias &&
1634 m_Parameters.m_ProjectionEnabled)
1635 || (m_ProjectionWeights && m_ProjectionBias &&
1636 m_Parameters.m_ProjectionEnabled));
1637 if (!projecton_tensors_consistent)
1638 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001640 }
1641
1642 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1643 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1644 // either all have values or none of them have values. Layer normalization is used when the values of all the
1645 // layer normalization weights are present
1646 if (m_InputLayerNormWeights)
1647 {
1648 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1649 }
1650 if (m_ForgetLayerNormWeights)
1651 {
1652 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1653 }
1654 if (m_CellLayerNormWeights)
1655 {
1656 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1657 }
1658 if (m_OutputLayerNormWeights)
1659 {
1660 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1661 }
1662
Jan Eilers38e05bd2019-06-26 13:10:09 +01001663 if (m_Parameters.m_LayerNormEnabled)
1664 {
1665 if (!m_Parameters.m_CifgEnabled)
1666 {
1667 if (!m_InputLayerNormWeights)
1668 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001669 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1670 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001671 }
1672 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1673 1, n_cell, "InputLayerNormWeights");
1674 }
1675 else if (m_InputLayerNormWeights)
1676 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001677 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1678 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001679 }
1680
1681 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1682 "ForgetLayerNormWeights");
1683 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1684
1685 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1686 "OutputLayerNormWeights");
1687 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1688
1689 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1690 "CellLayerNormWeights");
1691 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1692 }
1693 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1694 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001695 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1696 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001697 }
telsoa01c577f2c2018-08-31 09:22:23 +01001698}
1699
1700void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1701{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001702 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001703
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001704 ValidateNumInputs(workloadInfo, descriptorName, 1);
1705 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1706
1707 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1708 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1709
1710 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001711 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001712 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001713 }
1714
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001715 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001716 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001717 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001718 }
1719
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001720 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001721}
1722
1723void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1724{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001725 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001726
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001727 ValidateNumInputs(workloadInfo, descriptorName, 1);
1728 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1729
1730 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1731 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1732
1733 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001734 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001735 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001736 }
1737
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001738 if (outputTensorInfo.GetDataType() != DataType::Float32)
1739 {
1740 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1741 }
1742
1743 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001744}
1745
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001746void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1747{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001748 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001749
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001750 ValidateNumInputs(workloadInfo, descriptorName, 2);
1751 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1752
1753 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1754 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1755 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1756
1757 std::vector<DataType> supportedTypes =
1758 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001759 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001760 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001761 DataType::QuantisedSymm16,
1762 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001763 };
1764
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001765 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1766 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1767 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001768
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001769 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1770 inputTensorInfo1,
1771 outputTensorInfo,
1772 descriptorName,
1773 "input_0",
1774 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001775}
1776
David Beckc2044fe2018-09-05 15:00:38 +01001777void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1778{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01001780
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001781 ValidateNumInputs(workloadInfo, descriptorName, 2);
1782 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1783
1784 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1785 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1786 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1787
1788 std::vector<DataType> supportedTypes =
1789 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001790 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001791 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001792 DataType::QuantisedSymm16,
1793 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001794 };
1795
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001796 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1797 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1798 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001799
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001800 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1801 inputTensorInfo1,
1802 outputTensorInfo,
1803 descriptorName,
1804 "input_0",
1805 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01001806}
1807
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001808void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1809{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001810 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001811
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001812 ValidateNumInputs(workloadInfo, descriptorName, 2);
1813 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1814
1815 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1816 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1817 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1818
1819 std::vector<DataType> supportedTypes =
1820 {
Mike Kelly1da02362019-08-01 08:43:57 +01001821 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001822 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001823 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001824 DataType::QuantisedAsymm8,
1825 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001826 };
1827
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001828 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1829 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1830 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001831
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001832 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1833 inputTensorInfo1,
1834 outputTensorInfo,
1835 descriptorName,
1836 "input_0",
1837 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001838}
1839
narpra01a6bf9122018-09-10 09:50:09 +01001840void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1841{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001842 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01001843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 ValidateNumInputs(workloadInfo, descriptorName, 1);
1845 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1846
1847 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1848 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01001849
1850 std::vector<DataType> supportedTypes =
1851 {
1852 DataType::Float32,
1853 DataType::Float16,
1854 DataType::QuantisedAsymm8,
1855 DataType::QuantisedSymm16
1856 };
narpra01eb061912018-09-10 17:35:27 +01001857
James Conroy4d1ff582019-06-10 17:06:39 +01001858 // First check if input tensor data type is supported, then
1859 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1861 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01001862
narpra0132b90462018-09-13 11:07:48 +01001863 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01001864 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001865 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01001866 }
narpra0132b90462018-09-13 11:07:48 +01001867 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01001868 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001869 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01001870 }
1871 else
1872 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001873 unsigned int outputDim =
1874 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
1875 ValidateTensorNumDimensions(outputTensorInfo,
1876 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01001877 outputDim > 0 ? outputDim : 1,
1878 "output");
1879 }
narpra01a6bf9122018-09-10 09:50:09 +01001880}
1881
jimfly012c9322a2018-09-19 10:59:49 +01001882void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1883{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001884 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01001885
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001886 ValidateNumInputs(workloadInfo, descriptorName, 1);
1887 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1888
1889 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1890 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01001891
jimfly012c9322a2018-09-19 10:59:49 +01001892 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
1894
jimfly012c9322a2018-09-19 10:59:49 +01001895 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001896 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
1897 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
1898 "as there are dimensions in the input tensor that is " +
1899 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
1900 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01001901 }
1902}
1903
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001904void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1905{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001906 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001907
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001908 ValidateNumInputs(workloadInfo, descriptorName, 1);
1909 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001910
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001911 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1912 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1913
Sadik Armagan2208b602019-07-31 16:36:27 +01001914 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001915 {
Sadik Armagan2208b602019-07-31 16:36:27 +01001916 DataType::Float32,
1917 DataType::Float16
1918 };
1919
1920 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001921
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001922 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
1923 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001924 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001925 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001926 }
1927}
1928
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001929void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1930{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001931 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001932
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001933 ValidateNumInputs(workloadInfo, descriptorName, 1);
1934 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001935
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001936 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1937 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001938
1939 std::vector<DataType> supportedTypes =
1940 {
1941 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001942 DataType::Float16,
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001943 DataType::QuantisedAsymm8,
1944 DataType::QuantisedSymm16
1945 };
1946
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001947 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1948 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001949}
1950
Conor Kennedy430b5d82018-11-14 15:28:28 +00001951void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1952{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001953 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00001954
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001955 ValidateNumInputs(workloadInfo, descriptorName, 1);
1956 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1957
1958 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1959 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001960
1961 std::vector<DataType> supportedTypes =
1962 {
1963 DataType::Float16,
1964 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01001965 DataType::QuantisedAsymm8,
1966 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001967 };
1968
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001969 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1970 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001972 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001973
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001974 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001975 if (rank > 4)
1976 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001977 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001978 }
1979
Conor Kennedy430b5d82018-11-14 15:28:28 +00001980 // Begin, End & Stride length must be of rank(input0)
1981 if (m_Parameters.m_Begin.size() != rank)
1982 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001983 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001984 }
1985
1986 if (m_Parameters.m_End.size() != rank)
1987 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001988 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001989 }
1990
1991 if (m_Parameters.m_Stride.size() != rank)
1992 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001993 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001994 }
1995
1996 // Stride entries must be non-zero
1997 for (auto& stride : m_Parameters.m_Stride)
1998 {
1999 if (stride == 0)
2000 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002001 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002002 }
2003 }
2004}
2005
kevmay0190539692018-11-29 08:40:19 +00002006void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2007{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002008 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002009
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002010 ValidateNumInputs(workloadInfo, descriptorName, 2);
2011 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2012
2013 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2014 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2015 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2016
2017 std::vector<DataType> supportedTypes =
2018 {
Mike Kelly1da02362019-08-01 08:43:57 +01002019 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002020 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002021 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002022 DataType::QuantisedAsymm8,
2023 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002024 };
2025
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002026 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2027 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2028 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002029
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002030 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2031 inputTensorInfo1,
2032 outputTensorInfo,
2033 descriptorName,
2034 "input_0",
2035 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002036}
2037
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002038void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2039{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002040 const std::string descriptorName{"DebugQueueDescriptor"};
2041
2042 ValidateNumInputs(workloadInfo, descriptorName, 1);
2043 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002044}
2045
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002046void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2047{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002048 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002049
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002050 ValidateNumInputs(workloadInfo, descriptorName, 2);
2051 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002052
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002053 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2054 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2055 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2056
2057 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2058 inputTensorInfo1,
2059 outputTensorInfo,
2060 descriptorName,
2061 "input_0",
2062 "input_1");
2063
2064 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002065 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002066 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002067 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002068}
2069
FrancisMurtagh878f0232018-12-19 10:56:15 +00002070void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2071{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002072 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002073
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002074 ValidateNumInputs(workloadInfo, descriptorName, 2);
2075 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002076
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002077 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2078 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2079 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2080
2081 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2082 inputTensorInfo1,
2083 outputTensorInfo,
2084 descriptorName,
2085 "input_0",
2086 "input_1");
2087
2088 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002089 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002091 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002092}
2093
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002094void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2095{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096 const std::string descriptorName{"RsqrtQueueDescriptor"};
2097
2098 ValidateNumInputs(workloadInfo, descriptorName, 1);
2099 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2100
2101 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2102 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2103
2104 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002105
2106 std::vector<DataType> supportedTypes =
2107 {
2108 DataType::Float16,
2109 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01002110 DataType::QuantisedAsymm8,
2111 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002112 };
2113
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002114 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2115 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002116}
2117
narpra01b89b05f2019-01-16 09:53:09 +00002118void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2119{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002120 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002121
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002122 ValidateNumInputs(workloadInfo, descriptorName, 2);
2123 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002125 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2126 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002127 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002129 }
2130
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002131 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2132 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2133
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002134 std::vector<DataType> supportedTypes =
2135 {
2136 DataType::Float16,
2137 DataType::Float32,
2138 DataType::QuantisedAsymm8,
2139 DataType::QuantisedSymm16
2140 };
2141
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002142 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002143
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002144 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002145
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2147 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002148}
2149
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002150void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2151{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002152 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2153
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002154 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002155
2156 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2157 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002158 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002159 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2160 }
2161
2162 if (m_Anchors == nullptr)
2163 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002164 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002165 }
2166
2167 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002168 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2169 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2170
2171 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002172 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002173 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2174 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002175
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002176 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2177 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2178 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002179
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002180 const std::vector<DataType> supportedInputTypes =
2181 {
2182 DataType::Float32,
2183 DataType::QuantisedAsymm8,
2184 DataType::QuantisedSymm16
2185 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002186
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002187 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2188 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2189 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2190
2191 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2192 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2193 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2194 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2195
2196 // NOTE: Output is always Float32 regardless of input type
2197 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2198 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2199 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2200 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002201
2202 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2203 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002204 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002205 "must be positive and less than or equal to 1.");
2206 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002207
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002208 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2209 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002210 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002211 "should be equal to number of classes + 1.");
2212 }
2213}
2214
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002215void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2216{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002217 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002219 ValidateNumInputs(workloadInfo, descriptorName, 1);
2220 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2221
2222 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2223 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2224
2225 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2226 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002227 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002228 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002229 }
2230
Sadik Armagan2208b602019-07-31 16:36:27 +01002231 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002232 {
Sadik Armagan2208b602019-07-31 16:36:27 +01002233 DataType::Float32,
2234 DataType::Float16
2235 };
2236
2237 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002238}
2239
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002240void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2241{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002242 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002243
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002244 ValidateNumInputs(workloadInfo, descriptorName, 2);
2245 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002246
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002247 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2248 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2249 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002250
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002251 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2252 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2253
2254 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2255 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002256}
2257
Sadik Armaganeff363d2019-04-05 15:25:46 +01002258void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2259{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002261
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002262 ValidateNumInputs(workloadInfo, descriptorName, 2);
2263 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2264
2265 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2266 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2267
2268 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2269 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2270
2271 std::vector<DataType> supportedTypes =
2272 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002273 DataType::Float32,
2274 DataType::QuantisedAsymm8,
2275 DataType::QuantisedSymm16
2276 };
2277
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002278 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2279 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002280
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002281 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2282 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 ValidateTensorShapesMatch(inputTensorInfo0,
2285 outputTensorInfo0,
2286 descriptorName,
2287 "input_0",
2288 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002289
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002290 ValidateTensorShapesMatch(inputTensorInfo0,
2291 outputTensorInfo1,
2292 descriptorName,
2293 "input_0",
2294 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002295}
2296
Matteo Martincigh49124022019-01-11 13:25:59 +00002297void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2298{
2299 // This is internally generated so it should not need validation.
2300}
2301
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002302void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2303{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002304 const std::string& descriptorName{"PreluQueueDescriptor"};
2305
2306 ValidateNumInputs(workloadInfo, descriptorName, 2);
2307 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2308
2309 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2310 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002312
2313 std::vector<DataType> supportedTypes
2314 {
2315 DataType::Float16,
2316 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002317 DataType::QuantisedAsymm8,
2318 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002319 };
2320
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002321 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2322 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002323
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002324 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002326 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2327 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002328
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002329 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2330 alphaTensorInfo,
2331 outputTensorInfo,
2332 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002333 "input",
2334 "alpha");
2335}
2336
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002337void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2338{
2339 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2340
2341 ValidateNumInputs(workloadInfo, descriptorName, 1);
2342 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2343
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002344 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2345 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2346
2347 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2348 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002349
2350 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002351
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002352 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2353 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2354 ValidateTensorDataType(weightTensorInfo, inputTensorInfo.GetDataType(), descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002355
2356 if (m_Parameters.m_BiasEnabled)
2357 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002358 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002359
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002360 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
2361 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002362
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002363 ValidateTensorDataType(biasTensorInfo,
2364 GetBiasDataType(inputTensorInfo.GetDataType()),
2365 descriptorName,
2366 "bias");
2367
2368 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002369 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002370}
2371
James Conroy9c3cae82019-08-01 16:01:48 +01002372void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2373{
2374 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2375
2376 // Validate number of inputs/outputs
2377 ValidateNumInputs(workloadInfo, descriptorName, 3);
2378 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2379
2380 // Input/output tensor infos
2381 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2382 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2383 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2384
2385 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2386 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2387
2388 std::vector<DataType> inputOutputSupportedTypes =
2389 {
2390 DataType::QuantisedAsymm8
2391 };
2392
2393 std::vector<DataType> cellStateSupportedTypes =
2394 {
2395 DataType::QuantisedSymm16
2396 };
2397
2398 std::vector<DataType> weightsSupportedTypes =
2399 {
2400 DataType::QuantisedAsymm8
2401 };
2402
2403 std::vector<DataType> biasSupportedTypes =
2404 {
2405 DataType::Signed32
2406 };
2407
2408 // Validate types of input/output tensors
2409 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2410 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2411 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2412
2413 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2414 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2415
2416 // Validate matching types of input/output tensors
2417 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2418 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2419 "outputStateIn", "outputStateOut");
2420 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2421
2422 // Validate matching quantization info for input/output tensors
2423 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2424 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2425 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2426
2427 // Infer number of batches, input size and output size from tensor dimensions
2428 const uint32_t numBatches = inputInfo.GetShape()[0];
2429 const uint32_t inputSize = inputInfo.GetShape()[1];
2430 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2431
2432 // Validate number of dimensions and number of elements for input/output tensors
2433 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2434 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2435 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2436 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2437 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2438
2439 // Validate number of dimensions and number of elements for weights tensors
2440 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2441 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2442 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2443
2444 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2445 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2446 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2447
2448 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2449 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2450 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2451
2452 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2453 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2454 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2455
2456 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2457 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2458 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2459
2460 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2461 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2462 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2463 " RecurrentToForgetWeights");
2464
2465 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2466 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2467 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2468
2469 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2470 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2471 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2472
2473 // Validate data types for weights tensors (all should match each other)
2474 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2475
2476 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2477 "inputToInputWeights", "inputToForgetWeights");
2478 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2479 "inputToInputWeights", "inputToCellWeights");
2480 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2481 "inputToInputWeights", "inputToOutputWeights");
2482
2483 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2484 "inputToInputWeights", "recurrentToInputWeights");
2485 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2486 "inputToInputWeights", "recurrentToForgeteights");
2487 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2488 "inputToInputWeights", "recurrentToCellWeights");
2489 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2490 "inputToInputWeights", "recurrentToOutputWeights");
2491
2492 // Validate matching quantization info for weight tensors (all should match each other)
2493 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2494 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2495 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2496 descriptorName, "inputToInputWeights", "inputToCellWeights");
2497 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2498 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2499
2500 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2501 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2502 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2503 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2504 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2505 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2506 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2507 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2508
2509 // Validate number of dimensions and number of elements in bias tensors
2510 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2511 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2512 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2513
2514 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2515 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2516 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2517
2518 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2519 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2520 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2521
2522 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2523 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2524 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2525
2526 // Validate data types for bias tensors (all should match each other)
2527 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2528
2529 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2530 "inputGateBias", "forgetGateBias");
2531 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2532 "inputGateBias", "cellBias");
2533 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2534 "inputGateBias", "outputGateBias");
2535
2536 // Validate bias tensor quantization info
2537 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2538 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2539 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2540 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2541}
2542
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002543} // namespace armnn