blob: e49fd09be0829c8ea0d3b54ea5b2e75bbd18f85d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
33 case DataType::QuantisedAsymm8:
34 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010035 case DataType::QuantisedSymm16:
36 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000037 default:
38 BOOST_ASSERT_MSG(false, "Invalid input data type");
39 return DataType::Float32;
40 }
41}
42
43namespace
44{
45
46//---------------------------------------------------------------
47//android ndk does not support std::to_string function.
48template <typename T>
49std::string to_string(T value)
50{
51 std::ostringstream os;
52 os << value;
53 return os.str();
54}
55
56//---------------------------------------------------------------
57void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
58{
59 if (!ptr)
60 {
61 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
62 paramName + " parameter must be set.");
63 }
64}
65
66//---------------------------------------------------------------
67void ValidateTensorShapesMatch(const TensorInfo& first,
68 const TensorInfo& second,
69 std::string const& descName,
70 std::string const& firstName,
71 std::string const& secondName)
72{
73 if (first.GetShape() != second.GetShape())
74 {
75 throw InvalidArgumentException(descName + ": "
76 + firstName + " & " + secondName + " must have identical shapes");
77 }
78}
79
80//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010081void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000082{
Sadik Armaganeff363d2019-04-05 15:25:46 +010083 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000084 {
85 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010086 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000087 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
88 }
89}
90
91//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010092void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000093{
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000095 {
96 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010097 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000098 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
99 }
100}
101
102//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100103void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000104 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100105 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000106 std::string const& tensorName)
107{
108 if (tensor.GetNumDimensions() != numDimensions)
109 {
110 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
111 to_string(tensor.GetNumDimensions()) + " dimensions for " +
112 tensorName + " tensor.");
113 }
114}
115
116//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100117void ValidateTensorNumElements(const TensorInfo& tensor,
118 std::string const& descName,
119 unsigned int numElements,
120 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100121{
122 if (tensor.GetNumElements() != numElements)
123 {
124 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100125 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126 tensorName + " tensor.");
127 }
128}
129
130//---------------------------------------------------------------
131void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100132 unsigned int numDimension,
133 unsigned int numElements,
134 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100135{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100136 const std::string functionName{"ValidateTensorNumDimNumElem"};
137 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
138 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100139}
140
141//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000142void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
143 const std::string& descName, std::string const& tensorName)
144{
145 if (tensor.GetDataType() != dataType)
146 {
147 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
148 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
149 }
150}
151
152//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100153void ValidateTensorQuantizationSpace(const TensorInfo& first,
154 const TensorInfo& second,
155 const std::string& descName,
156 std::string const& firstName,
157 std::string const& secondName)
158{
159 if (!first.IsQuantized() ||
160 !second.IsQuantized())
161 {
162 // Not a quantized type, ignore the validation
163 return;
164 }
165
166 DataType firstDataType = first.GetDataType();
167 DataType secondDataType = second.GetDataType();
168
169 if (firstDataType != secondDataType)
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must be of the same quantized type, " +
173 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
174 secondName + " is " + GetDataTypeName(secondDataType));
175 }
176
177 if (!first.IsTypeSpaceMatch(second))
178 {
179 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
180 " must have the same quantization space, " +
181 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
182 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
183 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
184 " and scale " + to_string(second.GetQuantizationScale()));
185 }
186}
187
188//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100189void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
190 const TensorInfo& inputTensorInfo,
191 const TensorInfo& weightsTensorInfo,
192 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000193{
194 if (biasTensor.GetQuantizationOffset() != 0)
195 {
196 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
197 to_string(biasTensor.GetQuantizationOffset()));
198 }
199 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
kevmay016c46dd32018-12-17 15:32:45 +0000200 if (std::abs(biasTensor.GetQuantizationScale() - expectedScale) > 0.00000001f)
telsoa014fcda012018-03-09 14:13:49 +0000201 {
202 // Print the float values with extra precision to see very small differences
203 std::stringstream msg;
204 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
205 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
206 biasTensor.GetQuantizationScale();
207 throw InvalidArgumentException(msg.str());
208 }
209}
210
211//---------------------------------------------------------------
212void ValidateTensors(const std::vector<ITensorHandle*>& vec,
213 unsigned int numExpected,
214 const std::string& descName,
215 const std::string& varName)
216{
217 if (vec.empty() && numExpected > 0)
218 {
219 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
220 }
221
222 for (unsigned int i = 0; i < numExpected; ++i)
223 {
224 if (!vec[i])
225 {
226 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
227 }
228 }
229}
230
231//---------------------------------------------------------------
232void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
233 const TensorInfo& second,
234 const TensorInfo& output,
235 std::string const& descName,
236 std::string const& firstName,
237 std::string const& secondName)
238{
239 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
240 // broadcasted.
241 if (first.GetNumDimensions() != second.GetNumDimensions())
242 {
243 throw InvalidArgumentException(descName + ": Tensors "
244 + firstName + " & " + secondName
245 + " must have the same number of dimensions in order to be broadcasted");
246 }
247 uint32_t numDims = first.GetNumDimensions();
248 std::vector<uint32_t> outputDims(numDims, 0u);
249 for (uint32_t i = 0; i < numDims; i++)
250 {
251 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
252 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
253 if (dimsNotEqual && dimsNotOne)
254 {
255 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
256 }
257 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
258 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100259 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000260 if (broadcastShape != output.GetShape())
261 {
262 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
263 + firstName + " & " + secondName
264 + " does not match the output shape");
265 }
266}
267
268//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100269void ValidateDataTypes(const TensorInfo& info,
270 const std::vector<armnn::DataType>& supportedTypes,
271 std::string const& descName)
272{
273 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
274 if (iterator == supportedTypes.end())
275 {
276 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
277 }
278}
279
James Conroy4d1ff582019-06-10 17:06:39 +0100280//---------------------------------------------------------------
281void ValidateTensorDataTypesMatch(const TensorInfo& first,
282 const TensorInfo& second,
283 std::string const& descName,
284 std::string const& firstName,
285 std::string const& secondName)
286{
287 if (first.GetDataType() != second.GetDataType())
288 {
289 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
290 " must have identical data types.");
291 }
292}
293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100294//---------------------------------------------------------------
295void ValidateTensorNumElementsMatch(const TensorInfo& first,
296 const TensorInfo& second,
297 std::string const& descName,
298 std::string const& firstName,
299 std::string const& secondName)
300{
301 if (first.GetNumElements() != second.GetNumElements())
302 {
303 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
304 " must have the same number of elements.");
305 }
306}
307
308} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000309
310void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
311 unsigned int numExpectedIn, unsigned int numExpectedOut) const
312{
313 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
314 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
315}
316
317//---------------------------------------------------------------
318void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
319{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100320 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100322 ValidateNumInputs(workloadInfo, descriptorName, 1);
323 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100325 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
326 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
327
328 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
329 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000330
331 if (m_Inputs.size() != m_Outputs.size())
332 {
333 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100334 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
335 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000336 }
337
338 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
339 {
340 if (!m_Inputs[i])
341 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100342 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
343 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000344 }
345
346 if (!m_Outputs[i])
347 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100348 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
349 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000350 }
351 }
352}
353
Derek Lambertif674aa02019-08-01 15:56:25 +0100354//---------------------------------------------------------------
355void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
356{
357 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
358 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
359
360 if (workloadInfo.m_InputTensorInfos.size() != 1)
361 {
362 throw InvalidArgumentException(boost::str(
363 boost::format("Number of input infos (%1%) is not 1.")
364 % workloadInfo.m_InputTensorInfos.size()));
365
366 }
367
368 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
369 {
370 throw InvalidArgumentException(boost::str(
371 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
372 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
373 }
374
375 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
376 {
377 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
378 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
379 {
380 throw InvalidArgumentException(boost::str(
381 boost::format("Number of elements for tensor input and output %1% does not match")
382 % i ));
383 }
384 }
385
386 if (m_Inputs.size() != 1)
387 {
388 throw InvalidArgumentException(boost::str(
389 boost::format("Number of inputs (%1%) is not 1.")
390 % m_Inputs.size()));
391 }
392
393 if (m_Inputs.size() != m_Outputs.size())
394 {
395 throw InvalidArgumentException(boost::str(
396 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
397 % m_Inputs.size() % m_Outputs.size()));
398 }
399
400 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
401 {
402 if (!m_Inputs[i])
403 {
404 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
405 }
406
407 if (!m_Outputs[i])
408 {
409 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
410 }
411 }
412}
413
414//---------------------------------------------------------------
415void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
416{
417 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
418 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
419
Derek Lambertif674aa02019-08-01 15:56:25 +0100420 if (m_Inputs.size() != 1)
421 {
422 throw InvalidArgumentException(boost::str(
423 boost::format("Number of inputs (%1%) is not 1.")
424 % m_Inputs.size()));
425 }
426
427 if (m_Outputs.size() != 0)
428 {
429 throw InvalidArgumentException(boost::str(
430 boost::format("Number of outputs (%1%) is not 0.")
431 % m_Inputs.size() % m_Outputs.size()));
432 }
433
434 if (!m_Inputs[0])
435 {
436 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
437 }
438}
439
440//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000441void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
442{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100443 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100444
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100445 ValidateNumInputs(workloadInfo, descriptorName, 1);
446 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100447
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100448 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
449 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100450
451 std::vector<DataType> supportedTypes =
452 {
James Conroyd47a0642019-09-17 14:22:06 +0100453 DataType::Float16,
454 DataType::Float32,
455 DataType::QuantisedAsymm8,
456 DataType::QuantisedSymm16
nikraj01248683f2019-05-29 16:46:50 +0100457 };
458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100459 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
460 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
461 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000462}
463
Nikhil Rajee391d52019-09-05 17:50:44 +0100464void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
465{
466 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
467
468 ValidateNumInputs(workloadInfo, descriptorName, 1);
469 ValidateNumOutputs(workloadInfo, descriptorName, 1);
470
471 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
472 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
473
Nikhil Raj68c2c902019-09-19 11:21:11 +0100474 if (outputTensorInfo.GetDataType() != DataType::Signed32)
475 {
476 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
477 }
478
James Conroyd47a0642019-09-17 14:22:06 +0100479 std::vector<DataType> supportedInputTypes =
480 {
481 DataType::Float16,
482 DataType::Float32,
483 DataType::QuantisedAsymm8,
484 DataType::QuantisedSymm16
485 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100486
James Conroyd47a0642019-09-17 14:22:06 +0100487 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
Nikhil Rajee391d52019-09-05 17:50:44 +0100488}
489
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100490void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
491{
492 const std::string descriptorName{"SoftmaxQueueDescriptor"};
493
494 ValidateNumInputs(workloadInfo, descriptorName, 1);
495 ValidateNumOutputs(workloadInfo, descriptorName, 1);
496
497 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
498 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
499
500 std::vector<DataType> supportedTypes =
501 {
James Conroyd47a0642019-09-17 14:22:06 +0100502 DataType::Float16,
503 DataType::Float32,
504 DataType::QuantisedAsymm8,
505 DataType::QuantisedSymm16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100506 };
507
508 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
509 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
510 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
511}
512
telsoa014fcda012018-03-09 14:13:49 +0000513void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
514{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100515 const std::string descriptorName{"SplitterQueueDescriptor"};
516
517 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000518
Ruomei Yan25339c32019-05-28 16:48:20 +0100519 // Check the supported data types
520 std::vector<DataType> supportedTypes =
521 {
James Conroyd47a0642019-09-17 14:22:06 +0100522 DataType::Float32,
523 DataType::Float16,
524 DataType::Boolean,
525 DataType::Signed32,
526 DataType::QuantisedAsymm8,
527 DataType::QuantisedSymm16
Ruomei Yan25339c32019-05-28 16:48:20 +0100528 };
529
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100530 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
531 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100532 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100533 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
534 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
535
536 const std::string outputName = "output_" + std::to_string(i);
537 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100538 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100539
telsoa014fcda012018-03-09 14:13:49 +0000540 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
541 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100542 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000543 }
544
545 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
546 {
547 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100548 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000549 "has to match number of workloadInfo.m_OutputTensorInfos. "
550 "Number of windows: " +
551 to_string(m_ViewOrigins.size()) +
552 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
553 }
554
telsoa01c577f2c2018-08-31 09:22:23 +0100555 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000556 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
557 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
558 {
telsoa01c577f2c2018-08-31 09:22:23 +0100559 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000560 ViewOrigin const& e = m_ViewOrigins[w];
561 if (e.m_Origin.size() != inputDims)
562 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100563 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000564 "have the same dimensionality as the input tensor. "
565 "Window origin (index: " +
566 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
567 " dimensions, the input "
568 "tensor has " +
569 to_string(inputDims) + " dimensions.");
570 }
571 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
572 {
573 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
574 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
575 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100576 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000577 "be smaller or equal than the size of the input in that coord.");
578 }
579 }
580 }
581}
582
Jim Flynne242f2d2019-05-22 14:24:13 +0100583void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000584{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100585 const std::string descriptorName{"ConcatQueueDescriptor"};
586
587 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000588
589 if (m_Inputs.size() <= 0)
590 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100591 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000592 }
593 if (m_Outputs.size() <= 0)
594 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100595 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000596 }
597
598 if (workloadInfo.m_InputTensorInfos.size() <= 0)
599 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100600 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000601 }
602 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
603 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100604 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000605 }
606
Nikhil Raj8599a412018-11-19 14:51:07 +0000607 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
608 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100609 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000610 }
611
612 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
613 {
614 return;
615 }
616
telsoa014fcda012018-03-09 14:13:49 +0000617 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
618 {
619 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100620 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000621 "has to match number of workloadInfo.m_InputTensorInfos. "
622 "Number of windows: " +
623 to_string(m_ViewOrigins.size()) +
624 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
625 }
626
telsoa01c577f2c2018-08-31 09:22:23 +0100627 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000628 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
629 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
630 {
telsoa01c577f2c2018-08-31 09:22:23 +0100631 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000632 ViewOrigin const& e = m_ViewOrigins[w];
633 if (e.m_Origin.size() != outputDims)
634 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100635 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000636 "have the same dimensionality as the output tensor. "
637 "Window origin (index: " +
638 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
639 " dimensions, the output "
640 "tensor has " +
641 to_string(outputDims) + " dimensions.");
642 }
telsoa01c577f2c2018-08-31 09:22:23 +0100643 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000644 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
645 {
646 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
647 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
648 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100649 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000650 "be smaller or equal than the size of the output in that coord.");
651 }
652 }
653 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100654
655 // Check the supported data types
656 std::vector<DataType> supportedTypes =
657 {
James Conroyd47a0642019-09-17 14:22:06 +0100658 DataType::Float32,
659 DataType::Float16,
660 DataType::Boolean,
661 DataType::Signed32,
662 DataType::QuantisedAsymm8,
663 DataType::QuantisedSymm16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100664 };
665
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100666 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
667 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100668 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100669 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
670 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
671
672 const std::string inputName = "input_" + std::to_string(i);
673 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100674 }
telsoa014fcda012018-03-09 14:13:49 +0000675}
676
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100677void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
678{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100679 const std::string descriptorName{"StackQueueDescriptor"};
680
681 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100682
683 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
684 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100685 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100686 }
687
688 // All inputs must have the same shape, which is defined in parameters
689 const TensorShape& inputShape = m_Parameters.m_InputShape;
690 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
691 {
692 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
693 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100694 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100695 }
696 }
697
Matthew Jacksondba634f2019-08-15 15:14:18 +0100698 if (inputShape.GetNumDimensions() > 4)
699 {
700 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
701 }
702
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100703 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
704 // since the output tensor has an additional dimension.
705 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
706 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100707 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100708 "than the number of input dimensions.");
709 }
710
711 // Output shape must be as inferred from the input shape
712 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
713 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
714 {
715 if (outputShape[i] != inputShape[i])
716 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100717 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100718 "match shape inferred from input tensor.");
719 }
720 }
721
722 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
723 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100724 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100725 "match shape inferred from input tensor.");
726 }
727
728 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
729 {
730 if (outputShape[i] != inputShape[i-1])
731 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100732 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100733 "match shape inferred from input tensor.");
734 }
735 }
736
Matthew Jacksondba634f2019-08-15 15:14:18 +0100737 if (outputShape.GetNumDimensions() > 5)
738 {
739 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
740 }
741
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100742 // Check the supported data types
743 std::vector<DataType> supportedTypes =
744 {
James Conroyd47a0642019-09-17 14:22:06 +0100745 DataType::Float32,
746 DataType::Float16,
747 DataType::Boolean,
748 DataType::Signed32,
749 DataType::QuantisedAsymm8,
750 DataType::QuantisedSymm16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100751 };
752
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100753 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100754
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100755 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100756 {
757 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
758 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100759 descriptorName,
760 "input_0",
761 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100762 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100763
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100764 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
765 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100766 descriptorName,
767 "input_0",
768 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100769}
770
telsoa014fcda012018-03-09 14:13:49 +0000771void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
772{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100773 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000774
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100775 ValidateNumInputs(workloadInfo, descriptorName, 1);
776 ValidateNumOutputs(workloadInfo, descriptorName, 1);
777
778 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
779 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
780
781 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
782
783 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000784 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000786 }
787
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100788 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100790 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
791 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000792
793 if (m_Parameters.m_BiasEnabled)
794 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100795 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000796
telsoa01c577f2c2018-08-31 09:22:23 +0100797 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100798 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
799 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000800
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100801 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
802 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000803 }
804
Francis Murtagh46c09d02019-05-28 08:15:28 +0100805 // Check the supported data types
806 std::vector<DataType> supportedTypes =
807 {
James Conroyd47a0642019-09-17 14:22:06 +0100808 DataType::Float32,
809 DataType::Float16,
810 DataType::QuantisedAsymm8,
811 DataType::QuantisedSymm16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100812 };
813
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100814 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
815 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000816}
817
telsoa014fcda012018-03-09 14:13:49 +0000818void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
819{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100820 const std::string descriptorName{"NormalizationQueueDescriptor"};
821
822 ValidateNumInputs(workloadInfo, descriptorName, 1);
823 ValidateNumOutputs(workloadInfo, descriptorName, 1);
824
825 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
826 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100827
828 // Check the supported data types
829 std::vector<DataType> supportedTypes =
830 {
831 DataType::Float16,
832 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100833 DataType::QuantisedAsymm8,
834 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100835 };
836
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100837 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100838
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100839 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100840
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000842}
843
844void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
845{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100846 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000847
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100848 ValidateNumInputs(workloadInfo, descriptorName, 2);
849 ValidateNumOutputs(workloadInfo, descriptorName, 1);
850
851 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
852 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
853 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
854
855 std::vector<DataType> supportedTypes =
856 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100857 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100858 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100859 DataType::QuantisedSymm16,
860 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100861 };
862
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100863 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
864 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
865 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100866
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100867 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
868 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100869
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100870 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
871 inputTensorInfo1,
872 outputTensorInfo,
873 descriptorName,
874 "input_0",
875 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000876}
877
telsoa014fcda012018-03-09 14:13:49 +0000878void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
879{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100880 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +0100881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100882 ValidateNumInputs(workloadInfo, descriptorName, 2);
883 ValidateNumOutputs(workloadInfo, descriptorName, 1);
884
885 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
886 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
887 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
888
889 std::vector<DataType> supportedTypes =
890 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100891 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100892 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100893 DataType::QuantisedSymm16,
894 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100895 };
896
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100897 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
898 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
899 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100900
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100901 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
902 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100903
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100904 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
905 inputTensorInfo1,
906 outputTensorInfo,
907 descriptorName,
908 "input_0",
909 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000910}
911
912void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
913{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100914 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100915
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100916 ValidateNumInputs(workloadInfo, descriptorName, 1);
917 ValidateNumOutputs(workloadInfo, descriptorName, 1);
918
919 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
920 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100921
922 std::vector<DataType> supportedTypes =
923 {
924 DataType::Float16,
925 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100926 DataType::QuantisedAsymm8,
927 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100928 };
929
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100930 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
931 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100932
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100933 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
934 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
935 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100936
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100937 ValidatePointer(m_Mean, descriptorName, "mean");
938 ValidatePointer(m_Variance, descriptorName, "variance");
939 ValidatePointer(m_Beta, descriptorName, "beta");
940 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000941
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100942 const TensorInfo& mean = m_Mean->GetTensorInfo();
943 const TensorInfo& variance = m_Variance->GetTensorInfo();
944 const TensorInfo& beta = m_Beta->GetTensorInfo();
945 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000946
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100947 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
948 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
949 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
950 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100952 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
953 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
954 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000955}
956
957void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
958{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100959 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000960
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100961 ValidateNumInputs(workloadInfo, descriptorName, 1);
962 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000963
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100964 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
965 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +0000966
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100967 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
968 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +0000969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100970 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100972 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
973 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000974
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100975 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
telsoa014fcda012018-03-09 14:13:49 +0000976
977 if (m_Parameters.m_BiasEnabled)
978 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100979 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000980
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100981 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
982 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
983
984 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
985 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000986 }
987
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100988 std::vector<DataType> supportedTypes =
989 {
Ruomei Yan88d44b82019-05-23 14:29:06 +0100990 DataType::Float32,
991 DataType::QuantisedAsymm8,
992 DataType::QuantisedSymm16,
993 DataType::Float16
994 };
995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100996 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
997 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
998}
Ruomei Yan88d44b82019-05-23 14:29:06 +0100999
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001000void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1001{
1002 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1003
1004 ValidateNumInputs(workloadInfo, descriptorName, 1);
1005 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1006
1007 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1008 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1009
1010 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1011 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1012
1013 ValidatePointer(m_Weight, descriptorName, "weight");
1014
1015 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1016 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1017
1018 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1019 {
1020 throw InvalidArgumentException(
1021 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1022 "cannot be smaller than 1.") % descriptorName %
1023 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1024 }
1025
1026 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1027
1028 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1029 // inputChannels * channelMultiplier should be equal to outputChannels.
1030 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1031 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1032 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1033 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1034 {
1035 throw InvalidArgumentException(
1036 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1037 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1038 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1039 numWeightInputChannels % numWeightChannelMultiplier));
1040 }
1041
1042 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
1043
1044 if (m_Parameters.m_BiasEnabled)
1045 {
1046 ValidatePointer(m_Bias, descriptorName, "bias");
1047
1048 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1049 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1050
1051 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1052 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1053 }
1054
1055 std::vector<DataType> supportedTypes =
1056 {
1057 DataType::Float32,
1058 DataType::QuantisedAsymm8,
1059 DataType::QuantisedSymm16,
1060 DataType::Float16
1061 };
1062
1063 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1064 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001065}
1066
1067void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1068{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001069 const std::string descriptorName{"PermuteQueueDescriptor"};
1070
1071 ValidateNumInputs(workloadInfo, descriptorName, 1);
1072 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001073
1074 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1075
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001076 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1077 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001078
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001079 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1080 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001081
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001082 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001083 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001084 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001085 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001086 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1087 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1088 "must match dst dimension " + to_string(mapping[i]) +
1089 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001090 }
1091 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001092
1093 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001094}
1095
1096void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1097{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001098 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001099
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001100 ValidateNumInputs(workloadInfo, descriptorName, 1);
1101 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1102
1103 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1104 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1105
1106 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1107 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001108
1109 std::vector<DataType> supportedTypes =
1110 {
1111 DataType::Float32,
1112 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001113 DataType::QuantisedAsymm8,
1114 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001115 };
1116
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001117 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1118 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001119}
1120
1121void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1122{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001123 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001125 ValidateNumInputs(workloadInfo, descriptorName, 1);
1126 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1127
1128 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1129 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1130
1131 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1132 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001133
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001134 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001135 {
1136 DataType::Float16,
1137 DataType::Float32,
1138 DataType::QuantisedAsymm8,
1139 DataType::QuantisedSymm16
1140 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001141
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001142 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1143 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001144
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001145 // ResizeBilinear only changes width and height: batch and channel count must match.
1146 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1147 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001148 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001149 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001150 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001151 boost::str(boost::format("%1%: Input batch size (%2%) "
1152 "does not match output batch size (%3%)") %
1153 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001154 }
1155
Teresa Charlin970f43b2019-07-01 13:51:07 +01001156 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001157 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1158 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001159 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001160 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001161 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001162 boost::str(boost::format("%1%: Input channel count (%2%) "
1163 "does not match output channel count (%3%)") %
1164 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001165 }
1166}
1167
1168void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1169{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001170 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001171
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 ValidateNumInputs(workloadInfo, descriptorName, 1);
1173 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1174
1175 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1176 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1177
1178 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1179 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001180
1181 std::vector<DataType> supportedTypes =
1182 {
1183 DataType::Float16,
1184 DataType::Float32,
1185 DataType::QuantisedAsymm8,
1186 DataType::QuantisedSymm16
1187 };
1188
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001189 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1190 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001191
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001192 // Resize only changes width and height: batch and channel count must match.
1193 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1194 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001195 if (inputBatchSize != outputBatchSize)
1196 {
1197 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001198 boost::str(boost::format("%1%: Input batch size (%2%) "
1199 "does not match output batch size (%3%)") %
1200 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001201 }
1202
1203 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001204 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1205 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001206 if (inputChannelCount != outputChannelCount)
1207 {
1208 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001209 boost::str(boost::format("%1%: Input channel count (%2%) "
1210 "does not match output channel count (%3%)") %
1211 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001212 }
1213}
1214
1215void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1216{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001217 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001218
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 ValidateNumInputs(workloadInfo, descriptorName, 1);
1220 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1221
1222 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1223 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1224
1225 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1226 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1227
1228 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1229
telsoa014fcda012018-03-09 14:13:49 +00001230 if (m_Parameters.m_Min > m_Parameters.m_Max)
1231 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001232 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001233 }
telsoa014fcda012018-03-09 14:13:49 +00001234}
1235
1236void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1237{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001238 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001241 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1242
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001243 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1244 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1245
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001246 if (inputTensorInfo.GetNumDimensions() > 4)
1247 {
1248 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1249 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001250
1251 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001252
1253 // Check the supported data types
1254 std::vector<DataType> supportedTypes =
1255 {
1256 DataType::Float32,
1257 DataType::Float16,
1258 DataType::QuantisedAsymm8,
1259 DataType::QuantisedSymm16
1260 };
1261
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001262 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1263 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1264
1265 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001266}
1267
1268void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1269{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001270 const std::string descriptorName{"ConstantQueueDescriptor"};
1271
1272 ValidateNumInputs(workloadInfo, descriptorName, 0);
1273 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001274
1275 if (!m_LayerOutput)
1276 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001277 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001278 }
1279
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001280 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1281 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001282
1283 // Check the supported data types
1284 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001285 {
1286 DataType::Float32,
1287 DataType::Float16,
1288 DataType::Signed32,
1289 DataType::QuantisedAsymm8,
1290 DataType::QuantisedSymm16
1291 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001292
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001293 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001294}
1295
1296void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1297{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001298 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001299
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001300 ValidateNumInputs(workloadInfo, descriptorName, 1);
1301 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1302
1303 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1304 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1305
1306 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001307
1308 // Check the supported data types
1309 std::vector<DataType> supportedTypes =
1310 {
1311 DataType::Float32,
1312 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001313 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001314 DataType::QuantisedAsymm8,
1315 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001316 };
1317
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001318 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1319 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001320}
1321
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001322void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1323{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001324 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001326 ValidateNumInputs(workloadInfo, descriptorName, 1);
1327 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1328
1329 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1330 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1331
1332 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1333 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001334
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001335 if (m_Parameters.m_BlockShape.size() != 2)
1336 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001337 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001338 }
1339
1340 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1341 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001342 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1343 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001344 }
1345
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001346 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001347
1348 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001349 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001350
Matthew Bentham8800c002018-11-19 13:19:28 +00001351 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001352
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001353 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1354 widthPad.first + widthPad.second;
1355 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1356 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001357
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001358 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1359 inputShape[dimensionIndices.GetChannelsIndex()];
1360 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001361
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001362 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001363 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001364 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001365 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001366 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001367 }
1368
1369 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001370 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001371 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1372 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001373 }
nikraj01120522a2019-05-31 11:33:07 +01001374
1375 std::vector<DataType> supportedTypes =
1376 {
1377 DataType::Float16,
1378 DataType::Float32,
1379 DataType::QuantisedAsymm8,
1380 DataType::QuantisedSymm16
1381 };
1382
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001383 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1384 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001385}
1386
Keith Davisa57eccb2019-06-14 17:33:22 +01001387void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1388{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001389 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001390
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001391 ValidateNumInputs(workloadInfo, descriptorName, 1);
1392 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001393
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001394 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1395 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1396
1397 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1398 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001399
1400 std::vector<DataType> supportedTypes =
1401 {
1402 DataType::Float32,
1403 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001404 DataType::QuantisedAsymm8,
1405 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001406 };
1407
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001408 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1409 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001410
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001411 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1412
1413 if (m_Parameters.m_BlockSize == 0)
1414 {
1415 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1416 }
1417
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001418 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1419 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1420 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1421 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001422
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001423 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001424 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001425 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001426 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1427 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001428 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001429
1430 const TensorShape& outputShape = outputTensorInfo.GetShape();
1431 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1432 {
1433 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1434 "must be divisible by the square of block size." );
1435 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001436}
1437
telsoa014fcda012018-03-09 14:13:49 +00001438void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1439{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001440 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001441
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 ValidateNumInputs(workloadInfo, descriptorName, 1);
1443 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1444
1445 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1446 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001447
1448 std::vector<DataType> supportedTypes =
1449 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001450 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001451 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001452 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001453 };
1454
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001456
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001458 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001459 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001460 }
1461}
1462
telsoa01c577f2c2018-08-31 09:22:23 +01001463void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1464{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001465 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1466
1467 const std::string descriptorName{"LstmQueueDescriptor"};
1468
1469 // check dimensions of all inputs and outputs
1470 if (workloadInfo.m_InputTensorInfos.size() != 3)
1471 {
1472 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1473 }
1474 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1475 {
1476 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1477 }
1478
1479 std::vector<DataType> supportedTypes =
1480 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001481 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001482 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001483 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001484 };
1485
Jan Eilers38e05bd2019-06-26 13:10:09 +01001486 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001487 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1488
Jan Eilers38e05bd2019-06-26 13:10:09 +01001489 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001490 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001491 {
1492 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1493 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001494 descriptorName,
1495 "input_0",
1496 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001497 }
1498 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001500 {
1501 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1502 workloadInfo.m_OutputTensorInfos[i],
1503 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001504 "input_0",
1505 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001506 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001507
Jan Eilers38e05bd2019-06-26 13:10:09 +01001508 // TODO: check clipping parameter is valid
1509
1510 // Inferring batch size, number of outputs and number of cells from the inputs.
1511 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1512 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1513 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1514 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1515 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1516 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1517 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1518
Jan Eilers38e05bd2019-06-26 13:10:09 +01001519 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001520 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1521 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001522 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001523 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1524 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001525 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001526 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1527 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001528 // scratchBufferTensor
1529 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001530 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1531 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001532 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001533 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1534 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001535 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001536 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1537 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001538 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001539 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1540 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001541
1542
1543 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1544 if ( m_InputToInputWeights )
1545 {
1546 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1547 (n_cell * n_input), "InputLayerNormWeights");
1548 }
1549
1550 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1551 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1552 (n_cell * n_input), "InputToForgetWeights");
1553
1554 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1555 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1556 (n_cell * n_input), "InputToCellWeights");
1557
1558 if ( m_RecurrentToInputWeights )
1559 {
1560 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1561 (n_cell * n_output), "RecurrentToInputWeights");
1562 }
1563
1564 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1565 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1566 (n_cell * n_output), "RecurrentToForgetWeights");
1567
1568 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1569 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1570 (n_cell * n_output), "RecurrentToCellWeights");
1571
1572 // Make sure the input-gate's parameters are either both present (regular
1573 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1574 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1575 !m_Parameters.m_CifgEnabled) ||
1576 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1577 m_Parameters.m_CifgEnabled));
1578 if (!cifg_weights_all_or_none)
1579 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001580 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1581 "RecurrentToInputWeights must either both be present (regular LSTM) "
1582 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1583 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001584 }
1585
1586 if ( m_CellToInputWeights )
1587 {
1588 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1589 n_cell, "CellToInputWeights");
1590 }
1591 if ( m_CellToForgetWeights )
1592 {
1593 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1594 n_cell, "CellToForgetWeights");
1595 }
1596 if ( m_CellToOutputWeights )
1597 {
1598 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1599 n_cell, "CellToOutputWeights");
1600 }
1601
1602 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1603 bool peephole_weights_all_or_none =
1604 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1605 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1606 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1607 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1608 if (!peephole_weights_all_or_none)
1609 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001610 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001611 }
1612
1613 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1614 if (m_Parameters.m_CifgEnabled)
1615 {
1616 if (m_InputGateBias)
1617 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001618 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001619 }
1620 }
1621 else
1622 {
1623 if (!m_InputGateBias)
1624 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001625 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1626 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001627 }
1628 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1629 n_cell, "InputGateBias");
1630 }
1631
1632 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1633 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1634
1635 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1636 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1637
1638 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1639 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1640
1641 if (m_ProjectionWeights)
1642 {
1643 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1644 (n_cell * n_output), "ProjectionWeights");
1645 }
1646 if (m_ProjectionBias)
1647 {
1648 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1649 }
1650
1651 // Making sure the projection tensors are consistent:
1652 // 1) If projection weight is not present, then projection bias should not be
1653 // present.
1654 // 2) If projection weight is present, then projection bias is optional.
1655 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1656 !m_Parameters.m_ProjectionEnabled)
1657 || (m_ProjectionWeights && !m_ProjectionBias &&
1658 m_Parameters.m_ProjectionEnabled)
1659 || (m_ProjectionWeights && m_ProjectionBias &&
1660 m_Parameters.m_ProjectionEnabled));
1661 if (!projecton_tensors_consistent)
1662 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001663 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001664 }
1665
1666 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1667 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1668 // either all have values or none of them have values. Layer normalization is used when the values of all the
1669 // layer normalization weights are present
1670 if (m_InputLayerNormWeights)
1671 {
1672 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1673 }
1674 if (m_ForgetLayerNormWeights)
1675 {
1676 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1677 }
1678 if (m_CellLayerNormWeights)
1679 {
1680 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1681 }
1682 if (m_OutputLayerNormWeights)
1683 {
1684 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1685 }
1686
Jan Eilers38e05bd2019-06-26 13:10:09 +01001687 if (m_Parameters.m_LayerNormEnabled)
1688 {
1689 if (!m_Parameters.m_CifgEnabled)
1690 {
1691 if (!m_InputLayerNormWeights)
1692 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001693 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1694 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001695 }
1696 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1697 1, n_cell, "InputLayerNormWeights");
1698 }
1699 else if (m_InputLayerNormWeights)
1700 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001701 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1702 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001703 }
1704
1705 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1706 "ForgetLayerNormWeights");
1707 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1708
1709 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1710 "OutputLayerNormWeights");
1711 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1712
1713 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1714 "CellLayerNormWeights");
1715 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1716 }
1717 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1718 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001719 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1720 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001721 }
telsoa01c577f2c2018-08-31 09:22:23 +01001722}
1723
1724void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1725{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001727
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001728 ValidateNumInputs(workloadInfo, descriptorName, 1);
1729 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1730
1731 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1732 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1733
1734 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001735 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001736 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001737 }
1738
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001739 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001740 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001741 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001742 }
1743
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001744 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001745}
1746
1747void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1748{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001749 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001750
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001751 ValidateNumInputs(workloadInfo, descriptorName, 1);
1752 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1753
1754 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1755 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1756
1757 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001758 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001759 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001760 }
1761
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001762 if (outputTensorInfo.GetDataType() != DataType::Float32)
1763 {
1764 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1765 }
1766
1767 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001768}
1769
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001770void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1771{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001772 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001773
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001774 ValidateNumInputs(workloadInfo, descriptorName, 2);
1775 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1776
1777 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1778 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1779 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1780
1781 std::vector<DataType> supportedTypes =
1782 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001783 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001784 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001785 DataType::QuantisedSymm16,
1786 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001787 };
1788
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001789 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1790 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1791 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001792
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001793 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1794 inputTensorInfo1,
1795 outputTensorInfo,
1796 descriptorName,
1797 "input_0",
1798 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001799}
1800
David Beckc2044fe2018-09-05 15:00:38 +01001801void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1802{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001803 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01001804
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001805 ValidateNumInputs(workloadInfo, descriptorName, 2);
1806 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1807
1808 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1809 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1810 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1811
1812 std::vector<DataType> supportedTypes =
1813 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001814 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001815 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001816 DataType::QuantisedSymm16,
1817 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001818 };
1819
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001820 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1821 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1822 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001824 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1825 inputTensorInfo1,
1826 outputTensorInfo,
1827 descriptorName,
1828 "input_0",
1829 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01001830}
1831
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001832void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1833{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001834 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001835
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001836 ValidateNumInputs(workloadInfo, descriptorName, 2);
1837 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1838
1839 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1840 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1841 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1842
1843 std::vector<DataType> supportedTypes =
1844 {
Mike Kelly1da02362019-08-01 08:43:57 +01001845 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001846 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001847 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001848 DataType::QuantisedAsymm8,
1849 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001850 };
1851
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001852 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1853 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1854 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001855
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001856 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1857 inputTensorInfo1,
1858 outputTensorInfo,
1859 descriptorName,
1860 "input_0",
1861 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001862}
1863
narpra01a6bf9122018-09-10 09:50:09 +01001864void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1865{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001866 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01001867
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001868 ValidateNumInputs(workloadInfo, descriptorName, 1);
1869 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1870
1871 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1872 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01001873
1874 std::vector<DataType> supportedTypes =
1875 {
1876 DataType::Float32,
1877 DataType::Float16,
1878 DataType::QuantisedAsymm8,
1879 DataType::QuantisedSymm16
1880 };
narpra01eb061912018-09-10 17:35:27 +01001881
James Conroy4d1ff582019-06-10 17:06:39 +01001882 // First check if input tensor data type is supported, then
1883 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001884 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1885 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01001886
narpra0132b90462018-09-13 11:07:48 +01001887 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01001888 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001889 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01001890 }
narpra0132b90462018-09-13 11:07:48 +01001891 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01001892 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001893 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01001894 }
1895 else
1896 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001897 unsigned int outputDim =
1898 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
1899 ValidateTensorNumDimensions(outputTensorInfo,
1900 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01001901 outputDim > 0 ? outputDim : 1,
1902 "output");
1903 }
narpra01a6bf9122018-09-10 09:50:09 +01001904}
1905
jimfly012c9322a2018-09-19 10:59:49 +01001906void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1907{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001908 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01001909
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001910 ValidateNumInputs(workloadInfo, descriptorName, 1);
1911 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1912
1913 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1914 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01001915
jimfly012c9322a2018-09-19 10:59:49 +01001916 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001917 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
1918
jimfly012c9322a2018-09-19 10:59:49 +01001919 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001920 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
1921 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
1922 "as there are dimensions in the input tensor that is " +
1923 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
1924 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01001925 }
1926}
1927
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001928void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1929{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001930 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001931
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001932 ValidateNumInputs(workloadInfo, descriptorName, 1);
1933 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001934
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001935 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1936 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1937
Sadik Armagan2208b602019-07-31 16:36:27 +01001938 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001939 {
James Conroyd47a0642019-09-17 14:22:06 +01001940 DataType::Float32,
1941 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01001942 };
1943
1944 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001945
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001946 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
1947 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001948 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001949 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001950 }
1951}
1952
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001953void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1954{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001955 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001956
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001957 ValidateNumInputs(workloadInfo, descriptorName, 1);
1958 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001959
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001960 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1961 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001962
1963 std::vector<DataType> supportedTypes =
1964 {
James Conroyd47a0642019-09-17 14:22:06 +01001965 DataType::Float32,
1966 DataType::Float16,
1967 DataType::QuantisedAsymm8,
1968 DataType::QuantisedSymm16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001969 };
1970
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001971 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1972 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001973}
1974
Conor Kennedy430b5d82018-11-14 15:28:28 +00001975void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1976{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001977 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00001978
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001979 ValidateNumInputs(workloadInfo, descriptorName, 1);
1980 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1981
1982 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1983 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001984
1985 std::vector<DataType> supportedTypes =
1986 {
1987 DataType::Float16,
1988 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01001989 DataType::QuantisedAsymm8,
1990 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001991 };
1992
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001993 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1994 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001995
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001996 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001997
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001998 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001999 if (rank > 4)
2000 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002001 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002002 }
2003
Conor Kennedy430b5d82018-11-14 15:28:28 +00002004 // Begin, End & Stride length must be of rank(input0)
2005 if (m_Parameters.m_Begin.size() != rank)
2006 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002007 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002008 }
2009
2010 if (m_Parameters.m_End.size() != rank)
2011 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002012 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002013 }
2014
2015 if (m_Parameters.m_Stride.size() != rank)
2016 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002017 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002018 }
2019
2020 // Stride entries must be non-zero
2021 for (auto& stride : m_Parameters.m_Stride)
2022 {
2023 if (stride == 0)
2024 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002025 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002026 }
2027 }
2028}
2029
kevmay0190539692018-11-29 08:40:19 +00002030void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2031{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002032 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002033
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002034 ValidateNumInputs(workloadInfo, descriptorName, 2);
2035 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2036
2037 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2038 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2039 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2040
2041 std::vector<DataType> supportedTypes =
2042 {
Mike Kelly1da02362019-08-01 08:43:57 +01002043 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002044 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002045 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002046 DataType::QuantisedAsymm8,
2047 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002048 };
2049
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002050 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2051 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2052 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002053
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002054 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2055 inputTensorInfo1,
2056 outputTensorInfo,
2057 descriptorName,
2058 "input_0",
2059 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002060}
2061
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002062void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2063{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002064 const std::string descriptorName{"DebugQueueDescriptor"};
2065
2066 ValidateNumInputs(workloadInfo, descriptorName, 1);
2067 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002068}
2069
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002070void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2071{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002072 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002073
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002074 ValidateNumInputs(workloadInfo, descriptorName, 2);
2075 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002076
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002077 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2078 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2079 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2080
2081 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2082 inputTensorInfo1,
2083 outputTensorInfo,
2084 descriptorName,
2085 "input_0",
2086 "input_1");
2087
2088 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002089 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002091 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002092}
2093
FrancisMurtagh878f0232018-12-19 10:56:15 +00002094void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2095{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002097
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002098 ValidateNumInputs(workloadInfo, descriptorName, 2);
2099 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002100
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002101 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2102 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2103 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2104
2105 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2106 inputTensorInfo1,
2107 outputTensorInfo,
2108 descriptorName,
2109 "input_0",
2110 "input_1");
2111
2112 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002113 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002114 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002115 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002116}
2117
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002118void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2119{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002120 const std::string descriptorName{"RsqrtQueueDescriptor"};
2121
2122 ValidateNumInputs(workloadInfo, descriptorName, 1);
2123 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2124
2125 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2126 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2127
2128 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002129
2130 std::vector<DataType> supportedTypes =
2131 {
James Conroyd47a0642019-09-17 14:22:06 +01002132 DataType::Float16,
2133 DataType::Float32,
2134 DataType::QuantisedAsymm8,
2135 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002136 };
2137
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002138 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2139 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002140}
2141
narpra01b89b05f2019-01-16 09:53:09 +00002142void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2143{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002144 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002145
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 ValidateNumInputs(workloadInfo, descriptorName, 2);
2147 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002148
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002149 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2150 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002151 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002152 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002153 }
2154
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002155 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2156 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2157
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002158 std::vector<DataType> supportedTypes =
2159 {
James Conroyd47a0642019-09-17 14:22:06 +01002160 DataType::Float16,
2161 DataType::Float32,
2162 DataType::QuantisedAsymm8,
2163 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002164 };
2165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002166 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002167
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002168 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002169
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002170 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2171 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002172}
2173
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002174void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2175{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002176 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2177
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002178 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002179
2180 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2181 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002182 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002183 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2184 }
2185
2186 if (m_Anchors == nullptr)
2187 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002188 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002189 }
2190
2191 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002192 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2193 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2194
2195 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002196 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002197 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2198 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002199
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002200 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2201 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2202 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002203
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002204 const std::vector<DataType> supportedInputTypes =
2205 {
2206 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002207 DataType::Float16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002208 DataType::QuantisedAsymm8,
2209 DataType::QuantisedSymm16
2210 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002211
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002212 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2213 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2214 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2215
2216 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2217 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2218 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2219 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2220
2221 // NOTE: Output is always Float32 regardless of input type
2222 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2223 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2224 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2225 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002226
2227 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2228 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002229 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002230 "must be positive and less than or equal to 1.");
2231 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002232
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002233 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2234 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002235 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002236 "should be equal to number of classes + 1.");
2237 }
2238}
2239
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002240void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2241{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002242 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002243
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002244 ValidateNumInputs(workloadInfo, descriptorName, 1);
2245 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2246
2247 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2248 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2249
2250 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2251 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002252 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002253 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002254 }
2255
Sadik Armagan2208b602019-07-31 16:36:27 +01002256 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002257 {
James Conroyd47a0642019-09-17 14:22:06 +01002258 DataType::Float32,
2259 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002260 };
2261
2262 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002263}
2264
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002265void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2266{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002267 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002268
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002269 ValidateNumInputs(workloadInfo, descriptorName, 2);
2270 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002271
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002272 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2273 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2274 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002275
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002276 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2277 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2278
2279 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2280 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002281}
2282
Sadik Armaganeff363d2019-04-05 15:25:46 +01002283void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2284{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002285 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002287 ValidateNumInputs(workloadInfo, descriptorName, 2);
2288 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2289
2290 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2291 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2292
2293 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2294 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2295
2296 std::vector<DataType> supportedTypes =
2297 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002298 DataType::Float32,
2299 DataType::QuantisedAsymm8,
2300 DataType::QuantisedSymm16
2301 };
2302
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002303 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2304 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002305
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002306 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2307 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002308
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002309 ValidateTensorShapesMatch(inputTensorInfo0,
2310 outputTensorInfo0,
2311 descriptorName,
2312 "input_0",
2313 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002314
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 ValidateTensorShapesMatch(inputTensorInfo0,
2316 outputTensorInfo1,
2317 descriptorName,
2318 "input_0",
2319 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002320}
2321
Matteo Martincigh49124022019-01-11 13:25:59 +00002322void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2323{
2324 // This is internally generated so it should not need validation.
2325}
2326
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002327void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2328{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002329 const std::string& descriptorName{"PreluQueueDescriptor"};
2330
2331 ValidateNumInputs(workloadInfo, descriptorName, 2);
2332 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2333
2334 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2335 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2336 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002337
2338 std::vector<DataType> supportedTypes
2339 {
2340 DataType::Float16,
2341 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002342 DataType::QuantisedAsymm8,
2343 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002344 };
2345
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2347 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002348
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002349 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002350
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002351 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2352 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002353
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002354 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2355 alphaTensorInfo,
2356 outputTensorInfo,
2357 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002358 "input",
2359 "alpha");
2360}
2361
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002362void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2363{
2364 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2365
2366 ValidateNumInputs(workloadInfo, descriptorName, 1);
2367 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2368
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002369 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2370 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2371
2372 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2373 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002374
2375 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002377 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2378 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2379 ValidateTensorDataType(weightTensorInfo, inputTensorInfo.GetDataType(), descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002380
2381 if (m_Parameters.m_BiasEnabled)
2382 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002383 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002384
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002385 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
2386 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002387
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002388 ValidateTensorDataType(biasTensorInfo,
2389 GetBiasDataType(inputTensorInfo.GetDataType()),
2390 descriptorName,
2391 "bias");
2392
2393 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002394 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002395}
2396
James Conroy9c3cae82019-08-01 16:01:48 +01002397void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2398{
2399 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2400
2401 // Validate number of inputs/outputs
2402 ValidateNumInputs(workloadInfo, descriptorName, 3);
2403 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2404
2405 // Input/output tensor infos
2406 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2407 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2408 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2409
2410 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2411 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2412
2413 std::vector<DataType> inputOutputSupportedTypes =
2414 {
2415 DataType::QuantisedAsymm8
2416 };
2417
2418 std::vector<DataType> cellStateSupportedTypes =
2419 {
2420 DataType::QuantisedSymm16
2421 };
2422
2423 std::vector<DataType> weightsSupportedTypes =
2424 {
2425 DataType::QuantisedAsymm8
2426 };
2427
2428 std::vector<DataType> biasSupportedTypes =
2429 {
2430 DataType::Signed32
2431 };
2432
2433 // Validate types of input/output tensors
2434 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2435 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2436 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2437
2438 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2439 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2440
2441 // Validate matching types of input/output tensors
2442 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2443 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2444 "outputStateIn", "outputStateOut");
2445 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2446
2447 // Validate matching quantization info for input/output tensors
2448 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2449 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2450 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002451
James Conroy9c3cae82019-08-01 16:01:48 +01002452 // Infer number of batches, input size and output size from tensor dimensions
2453 const uint32_t numBatches = inputInfo.GetShape()[0];
2454 const uint32_t inputSize = inputInfo.GetShape()[1];
2455 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2456
2457 // Validate number of dimensions and number of elements for input/output tensors
2458 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2459 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2460 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2461 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2462 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2463
2464 // Validate number of dimensions and number of elements for weights tensors
2465 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2466 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2467 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2468
2469 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2470 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2471 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2472
2473 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2474 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2475 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2476
2477 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2478 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2479 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2480
2481 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2482 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2483 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2484
2485 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2486 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2487 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2488 " RecurrentToForgetWeights");
2489
2490 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2491 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2492 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2493
2494 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2495 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2496 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2497
2498 // Validate data types for weights tensors (all should match each other)
2499 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2500
2501 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2502 "inputToInputWeights", "inputToForgetWeights");
2503 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2504 "inputToInputWeights", "inputToCellWeights");
2505 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2506 "inputToInputWeights", "inputToOutputWeights");
2507
2508 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2509 "inputToInputWeights", "recurrentToInputWeights");
2510 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2511 "inputToInputWeights", "recurrentToForgeteights");
2512 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2513 "inputToInputWeights", "recurrentToCellWeights");
2514 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2515 "inputToInputWeights", "recurrentToOutputWeights");
2516
2517 // Validate matching quantization info for weight tensors (all should match each other)
2518 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2519 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2520 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2521 descriptorName, "inputToInputWeights", "inputToCellWeights");
2522 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2523 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2524
2525 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2526 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2527 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2528 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2529 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2530 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2531 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2532 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2533
2534 // Validate number of dimensions and number of elements in bias tensors
2535 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2536 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2537 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2538
2539 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2540 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2541 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2542
2543 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2544 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2545 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2546
2547 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2548 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2549 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2550
2551 // Validate data types for bias tensors (all should match each other)
2552 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2553
2554 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2555 "inputGateBias", "forgetGateBias");
2556 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2557 "inputGateBias", "cellBias");
2558 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2559 "inputGateBias", "outputGateBias");
2560
2561 // Validate bias tensor quantization info
2562 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2563 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2564 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2565 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2566}
2567
Kevin May868eb142019-09-04 17:29:31 +01002568void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2569{
2570 const std::string descriptorName{"AbsQueueDescriptor"};
2571
2572 ValidateNumInputs(workloadInfo, descriptorName, 1);
2573 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2574
2575 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2576 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2577
2578 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2579
2580 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002581 {
2582 DataType::Float16,
2583 DataType::Float32,
2584 DataType::QuantisedAsymm8,
2585 DataType::QuantisedSymm16
2586 };
Kevin May868eb142019-09-04 17:29:31 +01002587
2588 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2589 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2590}
2591
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002592void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2593{
2594 const std::string descriptorName{"SliceQueueDescriptor"};
2595
2596 ValidateNumInputs(workloadInfo, descriptorName, 1);
2597 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2598
2599 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2600 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2601
2602 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2603
2604 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2605 if (rank > 4)
2606 {
2607 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2608 }
2609
2610 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2611
2612 // Check if m_Begin and m_Size have the expected length
2613 if (m_Parameters.m_Begin.size() != rank)
2614 {
2615 throw InvalidArgumentException(descriptorName +
2616 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2617 }
2618 if (m_Parameters.m_Size.size() != rank)
2619 {
2620 throw InvalidArgumentException(descriptorName +
2621 ": Length of size descriptor must equal rank " + std::to_string(rank));
2622 }
2623
2624 // Check if the shape of the output tensor matches m_Size
2625 const TensorShape& outputShape = outputTensorInfo.GetShape();
2626 for (unsigned int i = 0u; i < rank; ++i)
2627 {
2628 if (m_Parameters.m_Size[i] != outputShape[i])
2629 {
2630 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2631 }
2632 }
2633
2634 // Check if the sum of begin offset and size in a given dimension
2635 // does not exceed the size of corresponding input
2636 const TensorShape& inputShape = inputTensorInfo.GetShape();
2637 for(unsigned int i = 0u; i < rank; ++i)
2638 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002639 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002640 {
2641 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2642 std::to_string(i) + " exceeds input size.");
2643 }
2644 }
2645}
2646
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002647void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2648{
2649 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2650
2651 ValidateNumInputs(workloadInfo, descriptorName, 1);
2652 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2653
2654 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2655 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2656
2657 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2658 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2659
2660 std::vector<DataType> supportedTypes =
2661 {
2662 DataType::Float32,
2663 DataType::Float16,
2664 DataType::QuantisedAsymm8,
2665 DataType::QuantisedSymm16
2666 };
2667
2668 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2669 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2670
2671 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2672
2673 if (m_Parameters.m_BlockSize == 0)
2674 {
2675 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2676 }
2677
2678 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2679 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2680 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2681 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2682
2683 const TensorShape& outputShape = outputInfo.GetShape();
2684 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
2685 {
2686 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
2687 "must be divisible by block size.");
2688 }
2689
2690 const TensorShape& inputShape = inputInfo.GetShape();
2691 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
2692 {
2693 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
2694 "must be divisible by the square of block size." );
2695 }
2696}
2697
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002698} // namespace armnn