blob: fed159bd606016c6aa9a5edba9ef4e36d0279323 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Matteo Martincigh21350152018-11-28 16:22:22 +000019using namespace armnnUtils;
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24//---------------------------------------------------------------
25DataType GetBiasDataType(DataType inputDataType)
26{
27 switch (inputDataType)
28 {
telsoa01c577f2c2018-08-31 09:22:23 +010029 case DataType::Float16:
30 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000031 case DataType::Float32:
32 return DataType::Float32;
33 case DataType::QuantisedAsymm8:
34 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010035 case DataType::QuantisedSymm16:
36 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000037 default:
38 BOOST_ASSERT_MSG(false, "Invalid input data type");
39 return DataType::Float32;
40 }
41}
42
43namespace
44{
45
//---------------------------------------------------------------
// Stream-based stand-in for std::to_string, which the Android NDK does not
// provide. Works for any type with an operator<< overload for std::ostream,
// using the stream's default formatting.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream out;
    out << value;
    std::string text = out.str();
    return text;
}
55
56//---------------------------------------------------------------
57void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
58{
59 if (!ptr)
60 {
61 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
62 paramName + " parameter must be set.");
63 }
64}
65
66//---------------------------------------------------------------
67void ValidateTensorShapesMatch(const TensorInfo& first,
68 const TensorInfo& second,
69 std::string const& descName,
70 std::string const& firstName,
71 std::string const& secondName)
72{
73 if (first.GetShape() != second.GetShape())
74 {
75 throw InvalidArgumentException(descName + ": "
76 + firstName + " & " + secondName + " must have identical shapes");
77 }
78}
79
80//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010081void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000082{
Sadik Armaganeff363d2019-04-05 15:25:46 +010083 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000084 {
85 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010086 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000087 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
88 }
89}
90
91//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010092void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000093{
Sadik Armaganeff363d2019-04-05 15:25:46 +010094 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000095 {
96 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010097 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000098 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
99 }
100}
101
102//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100103void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000104 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100105 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000106 std::string const& tensorName)
107{
108 if (tensor.GetNumDimensions() != numDimensions)
109 {
110 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
111 to_string(tensor.GetNumDimensions()) + " dimensions for " +
112 tensorName + " tensor.");
113 }
114}
115
116//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100117void ValidateTensorNumElements(const TensorInfo& tensor,
118 std::string const& descName,
119 unsigned int numElements,
120 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100121{
122 if (tensor.GetNumElements() != numElements)
123 {
124 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100125 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100126 tensorName + " tensor.");
127 }
128}
129
130//---------------------------------------------------------------
131void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100132 unsigned int numDimension,
133 unsigned int numElements,
134 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100135{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100136 const std::string functionName{"ValidateTensorNumDimNumElem"};
137 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
138 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100139}
140
141//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000142void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
143 const std::string& descName, std::string const& tensorName)
144{
145 if (tensor.GetDataType() != dataType)
146 {
147 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
148 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
149 }
150}
151
152//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100153void ValidateTensorQuantizationSpace(const TensorInfo& first,
154 const TensorInfo& second,
155 const std::string& descName,
156 std::string const& firstName,
157 std::string const& secondName)
158{
159 if (!first.IsQuantized() ||
160 !second.IsQuantized())
161 {
162 // Not a quantized type, ignore the validation
163 return;
164 }
165
166 DataType firstDataType = first.GetDataType();
167 DataType secondDataType = second.GetDataType();
168
169 if (firstDataType != secondDataType)
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must be of the same quantized type, " +
173 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
174 secondName + " is " + GetDataTypeName(secondDataType));
175 }
176
177 if (!first.IsTypeSpaceMatch(second))
178 {
179 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
180 " must have the same quantization space, " +
181 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
182 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
183 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
184 " and scale " + to_string(second.GetQuantizationScale()));
185 }
186}
187
188//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100189void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
190 const TensorInfo& inputTensorInfo,
191 const TensorInfo& weightsTensorInfo,
192 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000193{
194 if (biasTensor.GetQuantizationOffset() != 0)
195 {
196 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
197 to_string(biasTensor.GetQuantizationOffset()));
198 }
199 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
kevmay016c46dd32018-12-17 15:32:45 +0000200 if (std::abs(biasTensor.GetQuantizationScale() - expectedScale) > 0.00000001f)
telsoa014fcda012018-03-09 14:13:49 +0000201 {
202 // Print the float values with extra precision to see very small differences
203 std::stringstream msg;
204 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
205 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
206 biasTensor.GetQuantizationScale();
207 throw InvalidArgumentException(msg.str());
208 }
209}
210
211//---------------------------------------------------------------
212void ValidateTensors(const std::vector<ITensorHandle*>& vec,
213 unsigned int numExpected,
214 const std::string& descName,
215 const std::string& varName)
216{
217 if (vec.empty() && numExpected > 0)
218 {
219 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
220 }
221
222 for (unsigned int i = 0; i < numExpected; ++i)
223 {
224 if (!vec[i])
225 {
226 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
227 }
228 }
229}
230
231//---------------------------------------------------------------
232void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
233 const TensorInfo& second,
234 const TensorInfo& output,
235 std::string const& descName,
236 std::string const& firstName,
237 std::string const& secondName)
238{
239 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
240 // broadcasted.
241 if (first.GetNumDimensions() != second.GetNumDimensions())
242 {
243 throw InvalidArgumentException(descName + ": Tensors "
244 + firstName + " & " + secondName
245 + " must have the same number of dimensions in order to be broadcasted");
246 }
247 uint32_t numDims = first.GetNumDimensions();
248 std::vector<uint32_t> outputDims(numDims, 0u);
249 for (uint32_t i = 0; i < numDims; i++)
250 {
251 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
252 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
253 if (dimsNotEqual && dimsNotOne)
254 {
255 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
256 }
257 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
258 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100259 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000260 if (broadcastShape != output.GetShape())
261 {
262 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
263 + firstName + " & " + secondName
264 + " does not match the output shape");
265 }
266}
267
268//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100269void ValidateDataTypes(const TensorInfo& info,
270 const std::vector<armnn::DataType>& supportedTypes,
271 std::string const& descName)
272{
273 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
274 if (iterator == supportedTypes.end())
275 {
276 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
277 }
278}
279
James Conroy4d1ff582019-06-10 17:06:39 +0100280//---------------------------------------------------------------
281void ValidateTensorDataTypesMatch(const TensorInfo& first,
282 const TensorInfo& second,
283 std::string const& descName,
284 std::string const& firstName,
285 std::string const& secondName)
286{
287 if (first.GetDataType() != second.GetDataType())
288 {
289 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
290 " must have identical data types.");
291 }
292}
293
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100294//---------------------------------------------------------------
295void ValidateTensorNumElementsMatch(const TensorInfo& first,
296 const TensorInfo& second,
297 std::string const& descName,
298 std::string const& firstName,
299 std::string const& secondName)
300{
301 if (first.GetNumElements() != second.GetNumElements())
302 {
303 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
304 " must have the same number of elements.");
305 }
306}
307
308} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000309
310void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
311 unsigned int numExpectedIn, unsigned int numExpectedOut) const
312{
313 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
314 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
315}
316
317//---------------------------------------------------------------
318void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
319{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100320 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000321
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100322 ValidateNumInputs(workloadInfo, descriptorName, 1);
323 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000324
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100325 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
326 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
327
328 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
329 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000330
331 if (m_Inputs.size() != m_Outputs.size())
332 {
333 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100334 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
335 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000336 }
337
338 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
339 {
340 if (!m_Inputs[i])
341 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100342 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
343 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000344 }
345
346 if (!m_Outputs[i])
347 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100348 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
349 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000350 }
351 }
352}
353
Derek Lambertif674aa02019-08-01 15:56:25 +0100354//---------------------------------------------------------------
355void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
356{
357 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
358 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
359
360 if (workloadInfo.m_InputTensorInfos.size() != 1)
361 {
362 throw InvalidArgumentException(boost::str(
363 boost::format("Number of input infos (%1%) is not 1.")
364 % workloadInfo.m_InputTensorInfos.size()));
365
366 }
367
368 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
369 {
370 throw InvalidArgumentException(boost::str(
371 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
372 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
373 }
374
375 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
376 {
377 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
378 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
379 {
380 throw InvalidArgumentException(boost::str(
381 boost::format("Number of elements for tensor input and output %1% does not match")
382 % i ));
383 }
384 }
385
386 if (m_Inputs.size() != 1)
387 {
388 throw InvalidArgumentException(boost::str(
389 boost::format("Number of inputs (%1%) is not 1.")
390 % m_Inputs.size()));
391 }
392
393 if (m_Inputs.size() != m_Outputs.size())
394 {
395 throw InvalidArgumentException(boost::str(
396 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
397 % m_Inputs.size() % m_Outputs.size()));
398 }
399
400 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
401 {
402 if (!m_Inputs[i])
403 {
404 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
405 }
406
407 if (!m_Outputs[i])
408 {
409 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
410 }
411 }
412}
413
414//---------------------------------------------------------------
415void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
416{
417 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
418 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
419
Derek Lambertif674aa02019-08-01 15:56:25 +0100420 if (m_Inputs.size() != 1)
421 {
422 throw InvalidArgumentException(boost::str(
423 boost::format("Number of inputs (%1%) is not 1.")
424 % m_Inputs.size()));
425 }
426
427 if (m_Outputs.size() != 0)
428 {
429 throw InvalidArgumentException(boost::str(
430 boost::format("Number of outputs (%1%) is not 0.")
431 % m_Inputs.size() % m_Outputs.size()));
432 }
433
434 if (!m_Inputs[0])
435 {
436 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
437 }
438}
439
440//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000441void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
442{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100443 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100444
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100445 ValidateNumInputs(workloadInfo, descriptorName, 1);
446 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100447
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100448 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
449 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100450
451 std::vector<DataType> supportedTypes =
452 {
453 DataType::Float16,
454 DataType::Float32,
455 DataType::QuantisedAsymm8,
456 DataType::QuantisedSymm16
457 };
458
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100459 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
460 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
461 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000462}
463
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100464void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
465{
466 const std::string descriptorName{"SoftmaxQueueDescriptor"};
467
468 ValidateNumInputs(workloadInfo, descriptorName, 1);
469 ValidateNumOutputs(workloadInfo, descriptorName, 1);
470
471 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
472 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
473
474 std::vector<DataType> supportedTypes =
475 {
476 DataType::Float16,
477 DataType::Float32,
478 DataType::QuantisedAsymm8,
479 DataType::QuantisedSymm16
480 };
481
482 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
483 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
484 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
485}
486
telsoa014fcda012018-03-09 14:13:49 +0000487void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
488{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100489 const std::string descriptorName{"SplitterQueueDescriptor"};
490
491 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000492
Ruomei Yan25339c32019-05-28 16:48:20 +0100493 // Check the supported data types
494 std::vector<DataType> supportedTypes =
495 {
496 DataType::Float32,
497 DataType::Float16,
498 DataType::Boolean,
499 DataType::Signed32,
500 DataType::QuantisedAsymm8,
501 DataType::QuantisedSymm16
502 };
503
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100504 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
505 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100506 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100507 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
508 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
509
510 const std::string outputName = "output_" + std::to_string(i);
511 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100512 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100513
telsoa014fcda012018-03-09 14:13:49 +0000514 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
515 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100516 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000517 }
518
519 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
520 {
521 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100522 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000523 "has to match number of workloadInfo.m_OutputTensorInfos. "
524 "Number of windows: " +
525 to_string(m_ViewOrigins.size()) +
526 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
527 }
528
telsoa01c577f2c2018-08-31 09:22:23 +0100529 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000530 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
531 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
532 {
telsoa01c577f2c2018-08-31 09:22:23 +0100533 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000534 ViewOrigin const& e = m_ViewOrigins[w];
535 if (e.m_Origin.size() != inputDims)
536 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100537 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000538 "have the same dimensionality as the input tensor. "
539 "Window origin (index: " +
540 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
541 " dimensions, the input "
542 "tensor has " +
543 to_string(inputDims) + " dimensions.");
544 }
545 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
546 {
547 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
548 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
549 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100550 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000551 "be smaller or equal than the size of the input in that coord.");
552 }
553 }
554 }
555}
556
Jim Flynne242f2d2019-05-22 14:24:13 +0100557void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000558{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100559 const std::string descriptorName{"ConcatQueueDescriptor"};
560
561 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000562
563 if (m_Inputs.size() <= 0)
564 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100565 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000566 }
567 if (m_Outputs.size() <= 0)
568 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100569 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000570 }
571
572 if (workloadInfo.m_InputTensorInfos.size() <= 0)
573 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100574 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000575 }
576 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
577 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100578 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000579 }
580
Nikhil Raj8599a412018-11-19 14:51:07 +0000581 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
582 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100583 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000584 }
585
586 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
587 {
588 return;
589 }
590
telsoa014fcda012018-03-09 14:13:49 +0000591 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
592 {
593 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100594 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000595 "has to match number of workloadInfo.m_InputTensorInfos. "
596 "Number of windows: " +
597 to_string(m_ViewOrigins.size()) +
598 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
599 }
600
telsoa01c577f2c2018-08-31 09:22:23 +0100601 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000602 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
603 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
604 {
telsoa01c577f2c2018-08-31 09:22:23 +0100605 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000606 ViewOrigin const& e = m_ViewOrigins[w];
607 if (e.m_Origin.size() != outputDims)
608 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100609 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000610 "have the same dimensionality as the output tensor. "
611 "Window origin (index: " +
612 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
613 " dimensions, the output "
614 "tensor has " +
615 to_string(outputDims) + " dimensions.");
616 }
telsoa01c577f2c2018-08-31 09:22:23 +0100617 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000618 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
619 {
620 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
621 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
622 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100623 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000624 "be smaller or equal than the size of the output in that coord.");
625 }
626 }
627 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100628
629 // Check the supported data types
630 std::vector<DataType> supportedTypes =
631 {
632 DataType::Float32,
633 DataType::Float16,
634 DataType::Boolean,
635 DataType::Signed32,
636 DataType::QuantisedAsymm8,
637 DataType::QuantisedSymm16
638 };
639
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100640 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
641 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100642 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100643 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
644 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
645
646 const std::string inputName = "input_" + std::to_string(i);
647 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100648 }
telsoa014fcda012018-03-09 14:13:49 +0000649}
650
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100651void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
652{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100653 const std::string descriptorName{"StackQueueDescriptor"};
654
655 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100656
657 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
658 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100659 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100660 }
661
662 // All inputs must have the same shape, which is defined in parameters
663 const TensorShape& inputShape = m_Parameters.m_InputShape;
664 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
665 {
666 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
667 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100668 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100669 }
670 }
671
Matthew Jacksondba634f2019-08-15 15:14:18 +0100672 if (inputShape.GetNumDimensions() > 4)
673 {
674 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
675 }
676
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100677 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
678 // since the output tensor has an additional dimension.
679 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
680 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100681 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100682 "than the number of input dimensions.");
683 }
684
685 // Output shape must be as inferred from the input shape
686 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
687 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
688 {
689 if (outputShape[i] != inputShape[i])
690 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100691 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100692 "match shape inferred from input tensor.");
693 }
694 }
695
696 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
697 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100698 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100699 "match shape inferred from input tensor.");
700 }
701
702 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
703 {
704 if (outputShape[i] != inputShape[i-1])
705 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100706 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100707 "match shape inferred from input tensor.");
708 }
709 }
710
Matthew Jacksondba634f2019-08-15 15:14:18 +0100711 if (outputShape.GetNumDimensions() > 5)
712 {
713 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
714 }
715
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100716 // Check the supported data types
717 std::vector<DataType> supportedTypes =
718 {
719 DataType::Float32,
720 DataType::Float16,
721 DataType::Boolean,
722 DataType::Signed32,
723 DataType::QuantisedAsymm8,
724 DataType::QuantisedSymm16
725 };
726
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100727 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100728
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100729 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100730 {
731 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
732 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100733 descriptorName,
734 "input_0",
735 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100736 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100737
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100738 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
739 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100740 descriptorName,
741 "input_0",
742 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100743}
744
telsoa014fcda012018-03-09 14:13:49 +0000745void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
746{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100747 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000748
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100749 ValidateNumInputs(workloadInfo, descriptorName, 1);
750 ValidateNumOutputs(workloadInfo, descriptorName, 1);
751
752 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
753 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
754
755 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
756
757 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000758 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100759 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000760 }
761
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100762 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000763
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100764 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
765 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000766
767 if (m_Parameters.m_BiasEnabled)
768 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100769 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000770
telsoa01c577f2c2018-08-31 09:22:23 +0100771 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100772 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
773 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000774
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100775 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
776 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000777 }
778
Francis Murtagh46c09d02019-05-28 08:15:28 +0100779 // Check the supported data types
780 std::vector<DataType> supportedTypes =
781 {
782 DataType::Float32,
783 DataType::Float16,
784 DataType::QuantisedAsymm8,
785 DataType::QuantisedSymm16
786 };
787
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100788 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
789 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000790}
791
telsoa014fcda012018-03-09 14:13:49 +0000792void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
793{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100794 const std::string descriptorName{"NormalizationQueueDescriptor"};
795
796 ValidateNumInputs(workloadInfo, descriptorName, 1);
797 ValidateNumOutputs(workloadInfo, descriptorName, 1);
798
799 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
800 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100801
802 // Check the supported data types
803 std::vector<DataType> supportedTypes =
804 {
805 DataType::Float16,
806 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100807 DataType::QuantisedAsymm8,
808 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100809 };
810
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100811 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100812
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100813 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100814
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100815 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000816}
817
818void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
819{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100820 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100822 ValidateNumInputs(workloadInfo, descriptorName, 2);
823 ValidateNumOutputs(workloadInfo, descriptorName, 1);
824
825 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
826 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
827 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
828
829 std::vector<DataType> supportedTypes =
830 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100831 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100832 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100833 DataType::QuantisedSymm16,
834 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100835 };
836
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100837 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
838 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
839 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100840
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
842 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100844 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
845 inputTensorInfo1,
846 outputTensorInfo,
847 descriptorName,
848 "input_0",
849 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000850}
851
telsoa014fcda012018-03-09 14:13:49 +0000852void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
853{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100854 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +0100855
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100856 ValidateNumInputs(workloadInfo, descriptorName, 2);
857 ValidateNumOutputs(workloadInfo, descriptorName, 1);
858
859 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
860 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
861 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
862
863 std::vector<DataType> supportedTypes =
864 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100865 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100866 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100867 DataType::QuantisedSymm16,
868 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100869 };
870
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100871 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
872 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
873 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100874
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100875 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
876 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100877
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100878 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
879 inputTensorInfo1,
880 outputTensorInfo,
881 descriptorName,
882 "input_0",
883 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000884}
885
886void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
887{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100888 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100889
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100890 ValidateNumInputs(workloadInfo, descriptorName, 1);
891 ValidateNumOutputs(workloadInfo, descriptorName, 1);
892
893 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
894 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100895
896 std::vector<DataType> supportedTypes =
897 {
898 DataType::Float16,
899 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100900 DataType::QuantisedAsymm8,
901 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100902 };
903
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100904 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
905 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100906
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100907 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
908 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
909 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100910
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100911 ValidatePointer(m_Mean, descriptorName, "mean");
912 ValidatePointer(m_Variance, descriptorName, "variance");
913 ValidatePointer(m_Beta, descriptorName, "beta");
914 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000915
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100916 const TensorInfo& mean = m_Mean->GetTensorInfo();
917 const TensorInfo& variance = m_Variance->GetTensorInfo();
918 const TensorInfo& beta = m_Beta->GetTensorInfo();
919 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000920
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100921 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
922 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
923 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
924 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000925
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100926 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
927 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
928 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000929}
930
931void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
932{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100933 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000934
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100935 ValidateNumInputs(workloadInfo, descriptorName, 1);
936 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000937
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100938 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
939 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +0000940
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100941 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
942 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +0000943
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100944 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000945
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100946 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
947 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000948
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100949 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
telsoa014fcda012018-03-09 14:13:49 +0000950
951 if (m_Parameters.m_BiasEnabled)
952 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100953 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000954
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100955 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
956 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
957
958 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
959 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000960 }
961
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100962 std::vector<DataType> supportedTypes =
963 {
Ruomei Yan88d44b82019-05-23 14:29:06 +0100964 DataType::Float32,
965 DataType::QuantisedAsymm8,
966 DataType::QuantisedSymm16,
967 DataType::Float16
968 };
969
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100970 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
971 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
972}
Ruomei Yan88d44b82019-05-23 14:29:06 +0100973
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100974void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
975{
976 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
977
978 ValidateNumInputs(workloadInfo, descriptorName, 1);
979 ValidateNumOutputs(workloadInfo, descriptorName, 1);
980
981 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
982 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
983
984 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
985 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
986
987 ValidatePointer(m_Weight, descriptorName, "weight");
988
989 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
990 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
991
992 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
993 {
994 throw InvalidArgumentException(
995 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
996 "cannot be smaller than 1.") % descriptorName %
997 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
998 }
999
1000 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1001
1002 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1003 // inputChannels * channelMultiplier should be equal to outputChannels.
1004 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1005 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1006 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1007 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1008 {
1009 throw InvalidArgumentException(
1010 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1011 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1012 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1013 numWeightInputChannels % numWeightChannelMultiplier));
1014 }
1015
1016 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
1017
1018 if (m_Parameters.m_BiasEnabled)
1019 {
1020 ValidatePointer(m_Bias, descriptorName, "bias");
1021
1022 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1023 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1024
1025 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1026 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1027 }
1028
1029 std::vector<DataType> supportedTypes =
1030 {
1031 DataType::Float32,
1032 DataType::QuantisedAsymm8,
1033 DataType::QuantisedSymm16,
1034 DataType::Float16
1035 };
1036
1037 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1038 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001039}
1040
1041void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1042{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001043 const std::string descriptorName{"PermuteQueueDescriptor"};
1044
1045 ValidateNumInputs(workloadInfo, descriptorName, 1);
1046 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001047
1048 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1049
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001050 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1051 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001052
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001053 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1054 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001055
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001056 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001057 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001058 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001059 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001060 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1061 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1062 "must match dst dimension " + to_string(mapping[i]) +
1063 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001064 }
1065 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001066
1067 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001068}
1069
1070void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1071{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001072 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001073
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001074 ValidateNumInputs(workloadInfo, descriptorName, 1);
1075 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1076
1077 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1078 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1079
1080 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1081 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001082
1083 std::vector<DataType> supportedTypes =
1084 {
1085 DataType::Float32,
1086 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001087 DataType::QuantisedAsymm8,
1088 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001089 };
1090
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001091 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1092 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001093}
1094
1095void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1096{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001097 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001098
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001099 ValidateNumInputs(workloadInfo, descriptorName, 1);
1100 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1101
1102 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1103 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1104
1105 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1106 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001107
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001108 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001109 {
1110 DataType::Float16,
1111 DataType::Float32,
1112 DataType::QuantisedAsymm8,
1113 DataType::QuantisedSymm16
1114 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001115
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001116 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1117 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001118
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001119 // ResizeBilinear only changes width and height: batch and channel count must match.
1120 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1121 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001122 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001123 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001124 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001125 boost::str(boost::format("%1%: Input batch size (%2%) "
1126 "does not match output batch size (%3%)") %
1127 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001128 }
1129
Teresa Charlin970f43b2019-07-01 13:51:07 +01001130 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001131 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1132 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001133 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001134 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001135 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 boost::str(boost::format("%1%: Input channel count (%2%) "
1137 "does not match output channel count (%3%)") %
1138 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001139 }
1140}
1141
1142void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1143{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001144 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001145
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001146 ValidateNumInputs(workloadInfo, descriptorName, 1);
1147 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1148
1149 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1150 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1151
1152 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1153 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001154
1155 std::vector<DataType> supportedTypes =
1156 {
1157 DataType::Float16,
1158 DataType::Float32,
1159 DataType::QuantisedAsymm8,
1160 DataType::QuantisedSymm16
1161 };
1162
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001163 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1164 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001166 // Resize only changes width and height: batch and channel count must match.
1167 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1168 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001169 if (inputBatchSize != outputBatchSize)
1170 {
1171 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001172 boost::str(boost::format("%1%: Input batch size (%2%) "
1173 "does not match output batch size (%3%)") %
1174 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001175 }
1176
1177 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001178 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1179 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001180 if (inputChannelCount != outputChannelCount)
1181 {
1182 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001183 boost::str(boost::format("%1%: Input channel count (%2%) "
1184 "does not match output channel count (%3%)") %
1185 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001186 }
1187}
1188
1189void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1190{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001191 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001193 ValidateNumInputs(workloadInfo, descriptorName, 1);
1194 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1195
1196 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1197 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1198
1199 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1200 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1201
1202 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1203
telsoa014fcda012018-03-09 14:13:49 +00001204 if (m_Parameters.m_Min > m_Parameters.m_Max)
1205 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001206 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001207 }
telsoa014fcda012018-03-09 14:13:49 +00001208}
1209
1210void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1211{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001212 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001213
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001214 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001215 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1216
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001217 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1218 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1219
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001220 if (inputTensorInfo.GetNumDimensions() > 4)
1221 {
1222 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1223 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001224
1225 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001226
1227 // Check the supported data types
1228 std::vector<DataType> supportedTypes =
1229 {
1230 DataType::Float32,
1231 DataType::Float16,
1232 DataType::QuantisedAsymm8,
1233 DataType::QuantisedSymm16
1234 };
1235
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001236 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1237 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1238
1239 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001240}
1241
1242void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1243{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001244 const std::string descriptorName{"ConstantQueueDescriptor"};
1245
1246 ValidateNumInputs(workloadInfo, descriptorName, 0);
1247 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001248
1249 if (!m_LayerOutput)
1250 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001251 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001252 }
1253
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001254 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1255 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001256
1257 // Check the supported data types
1258 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001259 {
1260 DataType::Float32,
1261 DataType::Float16,
1262 DataType::Signed32,
1263 DataType::QuantisedAsymm8,
1264 DataType::QuantisedSymm16
1265 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001266
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001267 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001268}
1269
1270void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1271{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001273
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001274 ValidateNumInputs(workloadInfo, descriptorName, 1);
1275 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1276
1277 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1278 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1279
1280 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001281
1282 // Check the supported data types
1283 std::vector<DataType> supportedTypes =
1284 {
1285 DataType::Float32,
1286 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001287 DataType::QuantisedAsymm8,
1288 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001289 };
1290
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001291 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1292 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001293}
1294
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001295void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1296{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001297 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001298
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001299 ValidateNumInputs(workloadInfo, descriptorName, 1);
1300 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1301
1302 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1303 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1304
1305 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1306 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001307
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001308 if (m_Parameters.m_BlockShape.size() != 2)
1309 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001310 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001311 }
1312
1313 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1314 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001315 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1316 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001317 }
1318
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001319 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001320
1321 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001322 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001323
Matthew Bentham8800c002018-11-19 13:19:28 +00001324 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001326 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1327 widthPad.first + widthPad.second;
1328 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1329 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001330
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001331 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1332 inputShape[dimensionIndices.GetChannelsIndex()];
1333 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001334
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001335 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001336 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001337 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001338 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001339 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001340 }
1341
1342 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001343 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001344 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1345 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001346 }
nikraj01120522a2019-05-31 11:33:07 +01001347
1348 std::vector<DataType> supportedTypes =
1349 {
1350 DataType::Float16,
1351 DataType::Float32,
1352 DataType::QuantisedAsymm8,
1353 DataType::QuantisedSymm16
1354 };
1355
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001356 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1357 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001358}
1359
Keith Davisa57eccb2019-06-14 17:33:22 +01001360void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1361{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001362 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001363
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001364 ValidateNumInputs(workloadInfo, descriptorName, 1);
1365 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001366
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001367 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1368 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1369
1370 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1371 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001372
1373 std::vector<DataType> supportedTypes =
1374 {
1375 DataType::Float32,
1376 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001377 DataType::QuantisedAsymm8,
1378 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001379 };
1380
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001381 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1382 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001383
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001384 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1385 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1386 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1387 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001388
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001389 const TensorShape& inputShape = inputTensorInfo.GetShape();
Keith Davisa57eccb2019-06-14 17:33:22 +01001390
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001391 const unsigned int numInputElements =
1392 inputShape[0] * inputShape[wIndex] * inputShape[hIndex] * inputShape[cIndex];
1393 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
1394
1395 if (numOutputElements != numInputElements)
Keith Davisa57eccb2019-06-14 17:33:22 +01001396 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001397 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1398 std::to_string(numInputElements) + " but output tensor has " +
1399 std::to_string(numOutputElements) + " elements.");
Keith Davisa57eccb2019-06-14 17:33:22 +01001400 }
1401
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001402 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001403 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001404 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1405 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001406 }
1407}
1408
telsoa014fcda012018-03-09 14:13:49 +00001409void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1410{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001411 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001412
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001413 ValidateNumInputs(workloadInfo, descriptorName, 1);
1414 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1415
1416 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1417 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001418
1419 std::vector<DataType> supportedTypes =
1420 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001421 DataType::Float32,
1422 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001423 };
1424
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001425 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001426
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001427 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001428 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001429 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001430 }
1431}
1432
telsoa01c577f2c2018-08-31 09:22:23 +01001433void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1434{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001435 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1436
1437 const std::string descriptorName{"LstmQueueDescriptor"};
1438
1439 // check dimensions of all inputs and outputs
1440 if (workloadInfo.m_InputTensorInfos.size() != 3)
1441 {
1442 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1443 }
1444 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1445 {
1446 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1447 }
1448
1449 std::vector<DataType> supportedTypes =
1450 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001451 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001452 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001453 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001454 };
1455
Jan Eilers38e05bd2019-06-26 13:10:09 +01001456 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001457 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1458
Jan Eilers38e05bd2019-06-26 13:10:09 +01001459 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001460 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001461 {
1462 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1463 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001464 descriptorName,
1465 "input_0",
1466 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001467 }
1468 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001469 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001470 {
1471 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1472 workloadInfo.m_OutputTensorInfos[i],
1473 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001474 "input_0",
1475 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001476 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001477
Jan Eilers38e05bd2019-06-26 13:10:09 +01001478 // TODO: check clipping parameter is valid
1479
1480 // Inferring batch size, number of outputs and number of cells from the inputs.
1481 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1482 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1483 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1484 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1485 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1486 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1487 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1488
Jan Eilers38e05bd2019-06-26 13:10:09 +01001489 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001490 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1491 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001492 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001493 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1494 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001495 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001496 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1497 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001498 // scratchBufferTensor
1499 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001500 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1501 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001502 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001503 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1504 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001505 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001506 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1507 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001508 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001509 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1510 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001511
1512
1513 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1514 if ( m_InputToInputWeights )
1515 {
1516 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1517 (n_cell * n_input), "InputLayerNormWeights");
1518 }
1519
1520 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1521 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1522 (n_cell * n_input), "InputToForgetWeights");
1523
1524 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1525 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1526 (n_cell * n_input), "InputToCellWeights");
1527
1528 if ( m_RecurrentToInputWeights )
1529 {
1530 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1531 (n_cell * n_output), "RecurrentToInputWeights");
1532 }
1533
1534 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1535 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1536 (n_cell * n_output), "RecurrentToForgetWeights");
1537
1538 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1539 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1540 (n_cell * n_output), "RecurrentToCellWeights");
1541
1542 // Make sure the input-gate's parameters are either both present (regular
1543 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1544 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1545 !m_Parameters.m_CifgEnabled) ||
1546 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1547 m_Parameters.m_CifgEnabled));
1548 if (!cifg_weights_all_or_none)
1549 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001550 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1551 "RecurrentToInputWeights must either both be present (regular LSTM) "
1552 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1553 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001554 }
1555
1556 if ( m_CellToInputWeights )
1557 {
1558 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1559 n_cell, "CellToInputWeights");
1560 }
1561 if ( m_CellToForgetWeights )
1562 {
1563 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1564 n_cell, "CellToForgetWeights");
1565 }
1566 if ( m_CellToOutputWeights )
1567 {
1568 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1569 n_cell, "CellToOutputWeights");
1570 }
1571
1572 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1573 bool peephole_weights_all_or_none =
1574 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1575 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1576 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1577 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1578 if (!peephole_weights_all_or_none)
1579 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001580 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001581 }
1582
1583 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1584 if (m_Parameters.m_CifgEnabled)
1585 {
1586 if (m_InputGateBias)
1587 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001588 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001589 }
1590 }
1591 else
1592 {
1593 if (!m_InputGateBias)
1594 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001595 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1596 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001597 }
1598 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1599 n_cell, "InputGateBias");
1600 }
1601
1602 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1603 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1604
1605 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1606 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1607
1608 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1609 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1610
1611 if (m_ProjectionWeights)
1612 {
1613 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1614 (n_cell * n_output), "ProjectionWeights");
1615 }
1616 if (m_ProjectionBias)
1617 {
1618 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1619 }
1620
1621 // Making sure the projection tensors are consistent:
1622 // 1) If projection weight is not present, then projection bias should not be
1623 // present.
1624 // 2) If projection weight is present, then projection bias is optional.
1625 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1626 !m_Parameters.m_ProjectionEnabled)
1627 || (m_ProjectionWeights && !m_ProjectionBias &&
1628 m_Parameters.m_ProjectionEnabled)
1629 || (m_ProjectionWeights && m_ProjectionBias &&
1630 m_Parameters.m_ProjectionEnabled));
1631 if (!projecton_tensors_consistent)
1632 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001633 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001634 }
1635
1636 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1637 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1638 // either all have values or none of them have values. Layer normalization is used when the values of all the
1639 // layer normalization weights are present
1640 if (m_InputLayerNormWeights)
1641 {
1642 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1643 }
1644 if (m_ForgetLayerNormWeights)
1645 {
1646 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1647 }
1648 if (m_CellLayerNormWeights)
1649 {
1650 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1651 }
1652 if (m_OutputLayerNormWeights)
1653 {
1654 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1655 }
1656
Jan Eilers38e05bd2019-06-26 13:10:09 +01001657 if (m_Parameters.m_LayerNormEnabled)
1658 {
1659 if (!m_Parameters.m_CifgEnabled)
1660 {
1661 if (!m_InputLayerNormWeights)
1662 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001663 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1664 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001665 }
1666 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1667 1, n_cell, "InputLayerNormWeights");
1668 }
1669 else if (m_InputLayerNormWeights)
1670 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001671 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1672 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001673 }
1674
1675 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1676 "ForgetLayerNormWeights");
1677 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1678
1679 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1680 "OutputLayerNormWeights");
1681 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1682
1683 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1684 "CellLayerNormWeights");
1685 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1686 }
1687 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1688 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001689 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1690 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001691 }
telsoa01c577f2c2018-08-31 09:22:23 +01001692}
1693
1694void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1695{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001696 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001697
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001698 ValidateNumInputs(workloadInfo, descriptorName, 1);
1699 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1700
1701 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1703
1704 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001705 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001706 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001707 }
1708
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001710 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001711 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001712 }
1713
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001714 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001715}
1716
1717void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1718{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001719 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001720
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001721 ValidateNumInputs(workloadInfo, descriptorName, 1);
1722 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1723
1724 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1725 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1726
1727 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001728 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001729 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001730 }
1731
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001732 if (outputTensorInfo.GetDataType() != DataType::Float32)
1733 {
1734 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1735 }
1736
1737 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001738}
1739
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001740void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1741{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001742 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001743
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001744 ValidateNumInputs(workloadInfo, descriptorName, 2);
1745 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1746
1747 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1748 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1749 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1750
1751 std::vector<DataType> supportedTypes =
1752 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001753 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001754 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001755 DataType::QuantisedSymm16,
1756 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001757 };
1758
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001759 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1760 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1761 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001762
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001763 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1764 inputTensorInfo1,
1765 outputTensorInfo,
1766 descriptorName,
1767 "input_0",
1768 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001769}
1770
David Beckc2044fe2018-09-05 15:00:38 +01001771void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1772{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001773 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01001774
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001775 ValidateNumInputs(workloadInfo, descriptorName, 2);
1776 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1777
1778 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1779 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1780 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1781
1782 std::vector<DataType> supportedTypes =
1783 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001784 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001785 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001786 DataType::QuantisedSymm16,
1787 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001788 };
1789
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001790 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1791 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1792 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001793
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001794 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1795 inputTensorInfo1,
1796 outputTensorInfo,
1797 descriptorName,
1798 "input_0",
1799 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01001800}
1801
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001802void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1803{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001804 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001805
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001806 ValidateNumInputs(workloadInfo, descriptorName, 2);
1807 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1808
1809 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1810 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1811 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1812
1813 std::vector<DataType> supportedTypes =
1814 {
Mike Kelly1da02362019-08-01 08:43:57 +01001815 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001816 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001817 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001818 DataType::QuantisedAsymm8,
1819 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001820 };
1821
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001822 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1823 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1824 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001825
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001826 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1827 inputTensorInfo1,
1828 outputTensorInfo,
1829 descriptorName,
1830 "input_0",
1831 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001832}
1833
narpra01a6bf9122018-09-10 09:50:09 +01001834void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1835{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001836 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01001837
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001838 ValidateNumInputs(workloadInfo, descriptorName, 1);
1839 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1840
1841 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1842 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01001843
1844 std::vector<DataType> supportedTypes =
1845 {
1846 DataType::Float32,
1847 DataType::Float16,
1848 DataType::QuantisedAsymm8,
1849 DataType::QuantisedSymm16
1850 };
narpra01eb061912018-09-10 17:35:27 +01001851
James Conroy4d1ff582019-06-10 17:06:39 +01001852 // First check if input tensor data type is supported, then
1853 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001854 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1855 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01001856
narpra0132b90462018-09-13 11:07:48 +01001857 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01001858 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001859 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01001860 }
narpra0132b90462018-09-13 11:07:48 +01001861 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01001862 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001863 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01001864 }
1865 else
1866 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001867 unsigned int outputDim =
1868 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
1869 ValidateTensorNumDimensions(outputTensorInfo,
1870 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01001871 outputDim > 0 ? outputDim : 1,
1872 "output");
1873 }
narpra01a6bf9122018-09-10 09:50:09 +01001874}
1875
jimfly012c9322a2018-09-19 10:59:49 +01001876void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1877{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001878 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01001879
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001880 ValidateNumInputs(workloadInfo, descriptorName, 1);
1881 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1882
1883 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1884 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01001885
jimfly012c9322a2018-09-19 10:59:49 +01001886 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001887 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
1888
jimfly012c9322a2018-09-19 10:59:49 +01001889 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001890 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
1891 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
1892 "as there are dimensions in the input tensor that is " +
1893 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
1894 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01001895 }
1896}
1897
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001898void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1899{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001900 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001901
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001902 ValidateNumInputs(workloadInfo, descriptorName, 1);
1903 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001904
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001905 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1906 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1907
Sadik Armagan2208b602019-07-31 16:36:27 +01001908 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001909 {
Sadik Armagan2208b602019-07-31 16:36:27 +01001910 DataType::Float32,
1911 DataType::Float16
1912 };
1913
1914 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001915
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001916 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
1917 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001918 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001919 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00001920 }
1921}
1922
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001923void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1924{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001925 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001926
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001927 ValidateNumInputs(workloadInfo, descriptorName, 1);
1928 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001929
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001930 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1931 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001932
1933 std::vector<DataType> supportedTypes =
1934 {
1935 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001936 DataType::Float16,
Francis Murtaghd0dfe172019-06-25 10:57:10 +01001937 DataType::QuantisedAsymm8,
1938 DataType::QuantisedSymm16
1939 };
1940
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001941 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1942 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00001943}
1944
Conor Kennedy430b5d82018-11-14 15:28:28 +00001945void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1946{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001947 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00001948
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001949 ValidateNumInputs(workloadInfo, descriptorName, 1);
1950 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1951
1952 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1953 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001954
1955 std::vector<DataType> supportedTypes =
1956 {
1957 DataType::Float16,
1958 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01001959 DataType::QuantisedAsymm8,
1960 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001961 };
1962
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001963 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1964 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001966 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01001967
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001968 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001969 if (rank > 4)
1970 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001971 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00001972 }
1973
Conor Kennedy430b5d82018-11-14 15:28:28 +00001974 // Begin, End & Stride length must be of rank(input0)
1975 if (m_Parameters.m_Begin.size() != rank)
1976 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001977 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001978 }
1979
1980 if (m_Parameters.m_End.size() != rank)
1981 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001982 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001983 }
1984
1985 if (m_Parameters.m_Stride.size() != rank)
1986 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001987 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00001988 }
1989
1990 // Stride entries must be non-zero
1991 for (auto& stride : m_Parameters.m_Stride)
1992 {
1993 if (stride == 0)
1994 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001995 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00001996 }
1997 }
1998}
1999
kevmay0190539692018-11-29 08:40:19 +00002000void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2001{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002002 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002003
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 ValidateNumInputs(workloadInfo, descriptorName, 2);
2005 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2006
2007 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2008 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2009 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2010
2011 std::vector<DataType> supportedTypes =
2012 {
Mike Kelly1da02362019-08-01 08:43:57 +01002013 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002014 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002015 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002016 DataType::QuantisedAsymm8,
2017 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002018 };
2019
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002020 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2021 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2022 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002023
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002024 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2025 inputTensorInfo1,
2026 outputTensorInfo,
2027 descriptorName,
2028 "input_0",
2029 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002030}
2031
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002032void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2033{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002034 const std::string descriptorName{"DebugQueueDescriptor"};
2035
2036 ValidateNumInputs(workloadInfo, descriptorName, 1);
2037 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002038}
2039
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002040void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2041{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002042 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002043
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002044 ValidateNumInputs(workloadInfo, descriptorName, 2);
2045 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002046
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002047 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2048 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2049 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2050
2051 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2052 inputTensorInfo1,
2053 outputTensorInfo,
2054 descriptorName,
2055 "input_0",
2056 "input_1");
2057
2058 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002059 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002060 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002061 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002062}
2063
FrancisMurtagh878f0232018-12-19 10:56:15 +00002064void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2065{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002066 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002067
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002068 ValidateNumInputs(workloadInfo, descriptorName, 2);
2069 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002070
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002071 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2072 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2073 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2074
2075 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2076 inputTensorInfo1,
2077 outputTensorInfo,
2078 descriptorName,
2079 "input_0",
2080 "input_1");
2081
2082 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002083 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002084 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002085 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002086}
2087
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002088void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2089{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002090 const std::string descriptorName{"RsqrtQueueDescriptor"};
2091
2092 ValidateNumInputs(workloadInfo, descriptorName, 1);
2093 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2094
2095 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2096 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2097
2098 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002099
2100 std::vector<DataType> supportedTypes =
2101 {
2102 DataType::Float16,
2103 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01002104 DataType::QuantisedAsymm8,
2105 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002106 };
2107
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002108 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2109 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002110}
2111
narpra01b89b05f2019-01-16 09:53:09 +00002112void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2113{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002114 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002115
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002116 ValidateNumInputs(workloadInfo, descriptorName, 2);
2117 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002118
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002119 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2120 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002121 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002122 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002123 }
2124
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002125 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2126 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2127
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002128 std::vector<DataType> supportedTypes =
2129 {
2130 DataType::Float16,
2131 DataType::Float32,
2132 DataType::QuantisedAsymm8,
2133 DataType::QuantisedSymm16
2134 };
2135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002136 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002137
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002138 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002139
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002140 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2141 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002142}
2143
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002144void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2145{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002146 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2147
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002148 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002149
2150 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2151 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002152 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002153 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2154 }
2155
2156 if (m_Anchors == nullptr)
2157 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002158 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002159 }
2160
2161 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002162 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2163 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2164
2165 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002166 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002167 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2168 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002169
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002170 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2171 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2172 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002173
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002174 const std::vector<DataType> supportedInputTypes =
2175 {
2176 DataType::Float32,
2177 DataType::QuantisedAsymm8,
2178 DataType::QuantisedSymm16
2179 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002180
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002181 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2182 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2183 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2184
2185 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2186 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2187 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2188 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2189
2190 // NOTE: Output is always Float32 regardless of input type
2191 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2192 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2193 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2194 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002195
2196 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2197 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002198 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002199 "must be positive and less than or equal to 1.");
2200 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002201
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002202 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2203 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002204 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002205 "should be equal to number of classes + 1.");
2206 }
2207}
2208
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002209void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2210{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002211 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002212
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002213 ValidateNumInputs(workloadInfo, descriptorName, 1);
2214 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2215
2216 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2217 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2218
2219 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2220 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002221 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002222 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002223 }
2224
Sadik Armagan2208b602019-07-31 16:36:27 +01002225 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002226 {
Sadik Armagan2208b602019-07-31 16:36:27 +01002227 DataType::Float32,
2228 DataType::Float16
2229 };
2230
2231 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002232}
2233
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002234void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2235{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002237
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002238 ValidateNumInputs(workloadInfo, descriptorName, 2);
2239 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002240
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002241 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2242 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2243 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002244
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002245 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2246 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2247
2248 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2249 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002250}
2251
Sadik Armaganeff363d2019-04-05 15:25:46 +01002252void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2253{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002254 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002255
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002256 ValidateNumInputs(workloadInfo, descriptorName, 2);
2257 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2258
2259 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2260 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2261
2262 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2263 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2264
2265 std::vector<DataType> supportedTypes =
2266 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002267 DataType::Float32,
2268 DataType::QuantisedAsymm8,
2269 DataType::QuantisedSymm16
2270 };
2271
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002272 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2273 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002274
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002275 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2276 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002277
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002278 ValidateTensorShapesMatch(inputTensorInfo0,
2279 outputTensorInfo0,
2280 descriptorName,
2281 "input_0",
2282 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 ValidateTensorShapesMatch(inputTensorInfo0,
2285 outputTensorInfo1,
2286 descriptorName,
2287 "input_0",
2288 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002289}
2290
Matteo Martincigh49124022019-01-11 13:25:59 +00002291void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2292{
2293 // This is internally generated so it should not need validation.
2294}
2295
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002296void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2297{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002298 const std::string& descriptorName{"PreluQueueDescriptor"};
2299
2300 ValidateNumInputs(workloadInfo, descriptorName, 2);
2301 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2302
2303 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2304 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2305 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002306
2307 std::vector<DataType> supportedTypes
2308 {
2309 DataType::Float16,
2310 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002311 DataType::QuantisedAsymm8,
2312 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002313 };
2314
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002315 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2316 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002317
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002318 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002319
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002320 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2321 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002322
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002323 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2324 alphaTensorInfo,
2325 outputTensorInfo,
2326 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002327 "input",
2328 "alpha");
2329}
2330
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002331void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2332{
2333 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2334
2335 ValidateNumInputs(workloadInfo, descriptorName, 1);
2336 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2337
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002338 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2339 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2340
2341 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2342 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002343
2344 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002345
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002346 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2347 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2348 ValidateTensorDataType(weightTensorInfo, inputTensorInfo.GetDataType(), descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002349
2350 if (m_Parameters.m_BiasEnabled)
2351 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002352 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002353
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002354 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
2355 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002356
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002357 ValidateTensorDataType(biasTensorInfo,
2358 GetBiasDataType(inputTensorInfo.GetDataType()),
2359 descriptorName,
2360 "bias");
2361
2362 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002363 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002364}
2365
James Conroy9c3cae82019-08-01 16:01:48 +01002366void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2367{
2368 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2369
2370 // Validate number of inputs/outputs
2371 ValidateNumInputs(workloadInfo, descriptorName, 3);
2372 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2373
2374 // Input/output tensor infos
2375 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2376 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2377 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2378
2379 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2380 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2381
2382 std::vector<DataType> inputOutputSupportedTypes =
2383 {
2384 DataType::QuantisedAsymm8
2385 };
2386
2387 std::vector<DataType> cellStateSupportedTypes =
2388 {
2389 DataType::QuantisedSymm16
2390 };
2391
2392 std::vector<DataType> weightsSupportedTypes =
2393 {
2394 DataType::QuantisedAsymm8
2395 };
2396
2397 std::vector<DataType> biasSupportedTypes =
2398 {
2399 DataType::Signed32
2400 };
2401
2402 // Validate types of input/output tensors
2403 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2404 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2405 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2406
2407 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2408 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2409
2410 // Validate matching types of input/output tensors
2411 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2412 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2413 "outputStateIn", "outputStateOut");
2414 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2415
2416 // Validate matching quantization info for input/output tensors
2417 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2418 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2419 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2420
2421 // Infer number of batches, input size and output size from tensor dimensions
2422 const uint32_t numBatches = inputInfo.GetShape()[0];
2423 const uint32_t inputSize = inputInfo.GetShape()[1];
2424 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2425
2426 // Validate number of dimensions and number of elements for input/output tensors
2427 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2428 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2429 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2430 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2431 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2432
2433 // Validate number of dimensions and number of elements for weights tensors
2434 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2435 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2436 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2437
2438 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2439 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2440 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2441
2442 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2443 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2444 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2445
2446 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2447 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2448 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2449
2450 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2451 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2452 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2453
2454 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2455 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2456 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2457 " RecurrentToForgetWeights");
2458
2459 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2460 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2461 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2462
2463 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2464 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2465 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2466
2467 // Validate data types for weights tensors (all should match each other)
2468 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2469
2470 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2471 "inputToInputWeights", "inputToForgetWeights");
2472 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2473 "inputToInputWeights", "inputToCellWeights");
2474 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2475 "inputToInputWeights", "inputToOutputWeights");
2476
2477 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2478 "inputToInputWeights", "recurrentToInputWeights");
2479 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2480 "inputToInputWeights", "recurrentToForgeteights");
2481 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2482 "inputToInputWeights", "recurrentToCellWeights");
2483 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2484 "inputToInputWeights", "recurrentToOutputWeights");
2485
2486 // Validate matching quantization info for weight tensors (all should match each other)
2487 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2488 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2489 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2490 descriptorName, "inputToInputWeights", "inputToCellWeights");
2491 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2492 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2493
2494 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2495 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2496 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2497 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2498 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2499 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2500 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2501 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2502
2503 // Validate number of dimensions and number of elements in bias tensors
2504 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2505 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2506 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2507
2508 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2509 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2510 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2511
2512 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2513 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2514 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2515
2516 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2517 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2518 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2519
2520 // Validate data types for bias tensors (all should match each other)
2521 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2522
2523 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2524 "inputGateBias", "forgetGateBias");
2525 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2526 "inputGateBias", "cellBias");
2527 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2528 "inputGateBias", "outputGateBias");
2529
2530 // Validate bias tensor quantization info
2531 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2532 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2533 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2534 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2535}
2536
Kevin May868eb142019-09-04 17:29:31 +01002537void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2538{
2539 const std::string descriptorName{"AbsQueueDescriptor"};
2540
2541 ValidateNumInputs(workloadInfo, descriptorName, 1);
2542 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2543
2544 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2545 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2546
2547 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2548
2549 std::vector<DataType> supportedTypes =
2550 {
2551 DataType::Float16,
2552 DataType::Float32,
2553 DataType::QuantisedAsymm8,
2554 DataType::QuantisedSymm16
2555 };
2556
2557 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2558 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2559}
2560
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002561} // namespace armnn