blob: e1a369af7c76790ac94613bfce49f8e93ab2c471 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
James Conroyc8724c72019-10-08 15:41:34 +010018#include <TensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
34 case DataType::QuantisedAsymm8:
35 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010036 case DataType::QuantisedSymm16:
37 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000038 default:
39 BOOST_ASSERT_MSG(false, "Invalid input data type");
40 return DataType::Float32;
41 }
42}
43
44namespace
45{
46
47//---------------------------------------------------------------
/// Converts any streamable value to its textual form via operator<<.
/// Used instead of std::to_string because the Android NDK does not provide std::to_string.
template <typename T>
std::string to_string(T value)
{
    std::ostringstream buffer;
    buffer << value;
    const std::string text{buffer.str()};
    return text;
}
56
57//---------------------------------------------------------------
58void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
59{
60 if (!ptr)
61 {
62 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
63 paramName + " parameter must be set.");
64 }
65}
66
67//---------------------------------------------------------------
68void ValidateTensorShapesMatch(const TensorInfo& first,
69 const TensorInfo& second,
70 std::string const& descName,
71 std::string const& firstName,
72 std::string const& secondName)
73{
74 if (first.GetShape() != second.GetShape())
75 {
76 throw InvalidArgumentException(descName + ": "
77 + firstName + " & " + secondName + " must have identical shapes");
78 }
79}
80
81//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010082void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000083{
Sadik Armaganeff363d2019-04-05 15:25:46 +010084 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000085 {
86 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010087 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000088 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
89 }
90}
91
92//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010093void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000094{
Sadik Armaganeff363d2019-04-05 15:25:46 +010095 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000096 {
97 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010098 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000099 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
100 }
101}
102
103//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100104void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000105 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100106 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000107 std::string const& tensorName)
108{
109 if (tensor.GetNumDimensions() != numDimensions)
110 {
111 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
112 to_string(tensor.GetNumDimensions()) + " dimensions for " +
113 tensorName + " tensor.");
114 }
115}
116
117//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100118void ValidateTensorNumElements(const TensorInfo& tensor,
119 std::string const& descName,
120 unsigned int numElements,
121 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100122{
123 if (tensor.GetNumElements() != numElements)
124 {
125 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100126 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127 tensorName + " tensor.");
128 }
129}
130
131//---------------------------------------------------------------
132void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100133 unsigned int numDimension,
134 unsigned int numElements,
135 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100136{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 const std::string functionName{"ValidateTensorNumDimNumElem"};
138 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
139 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140}
141
142//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000143void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
144 const std::string& descName, std::string const& tensorName)
145{
146 if (tensor.GetDataType() != dataType)
147 {
148 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
149 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
150 }
151}
152
153//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100154void ValidateTensorQuantizationSpace(const TensorInfo& first,
155 const TensorInfo& second,
156 const std::string& descName,
157 std::string const& firstName,
158 std::string const& secondName)
159{
160 if (!first.IsQuantized() ||
161 !second.IsQuantized())
162 {
163 // Not a quantized type, ignore the validation
164 return;
165 }
166
167 DataType firstDataType = first.GetDataType();
168 DataType secondDataType = second.GetDataType();
169
170 if (firstDataType != secondDataType)
171 {
172 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
173 " must be of the same quantized type, " +
174 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
175 secondName + " is " + GetDataTypeName(secondDataType));
176 }
177
178 if (!first.IsTypeSpaceMatch(second))
179 {
180 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
181 " must have the same quantization space, " +
182 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
183 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
184 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
185 " and scale " + to_string(second.GetQuantizationScale()));
186 }
187}
188
189//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100190void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
191 const TensorInfo& inputTensorInfo,
192 const TensorInfo& weightsTensorInfo,
193 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000194{
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000195 // Helper lambda function to validate a single bias quantization scale value
196 auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
197 {
ricbur013f4d7102019-10-31 16:22:18 +0000198 constexpr float tolerance = 0.000001f;
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000199 if (std::abs(biasScale - expectedScale) > tolerance)
200 {
201 // Print the float values with extra precision to see very small differences
202 std::stringstream msg;
203 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
204 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
205 biasScale;
206 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
207 }
208 };
209
telsoa014fcda012018-03-09 14:13:49 +0000210 if (biasTensor.GetQuantizationOffset() != 0)
211 {
212 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
213 to_string(biasTensor.GetQuantizationOffset()));
214 }
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000215
216 if (biasTensor.HasMultipleQuantizationScales())
telsoa014fcda012018-03-09 14:13:49 +0000217 {
Aron Virginas-Tard9053072019-10-30 16:03:19 +0000218 // Validate per-axis quantization scales
219 const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
220 const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();
221
222 if (weightScales.size() != biasScales.size())
223 {
224 std::stringstream msg;
225 msg << descName << ": Expected matchhing number of per-axis quantization scales, but got different "
226 << "values: weights=" << weightScales.size() << ", biases=" << biasScales.size();
227 throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
228 }
229
230 for (size_t i = 0ul; i < biasScales.size(); ++i)
231 {
232 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
233 VerifyBiasQuantizationScale(biasScales[i], expectedScale);
234 }
235 }
236 else
237 {
238 // Validate per-tensor quantization scale
239 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
240 VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
telsoa014fcda012018-03-09 14:13:49 +0000241 }
242}
243
244//---------------------------------------------------------------
245void ValidateTensors(const std::vector<ITensorHandle*>& vec,
246 unsigned int numExpected,
247 const std::string& descName,
248 const std::string& varName)
249{
250 if (vec.empty() && numExpected > 0)
251 {
252 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
253 }
254
255 for (unsigned int i = 0; i < numExpected; ++i)
256 {
257 if (!vec[i])
258 {
259 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
260 }
261 }
262}
263
264//---------------------------------------------------------------
265void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
266 const TensorInfo& second,
267 const TensorInfo& output,
268 std::string const& descName,
269 std::string const& firstName,
270 std::string const& secondName)
271{
272 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
273 // broadcasted.
274 if (first.GetNumDimensions() != second.GetNumDimensions())
275 {
276 throw InvalidArgumentException(descName + ": Tensors "
277 + firstName + " & " + secondName
278 + " must have the same number of dimensions in order to be broadcasted");
279 }
280 uint32_t numDims = first.GetNumDimensions();
281 std::vector<uint32_t> outputDims(numDims, 0u);
282 for (uint32_t i = 0; i < numDims; i++)
283 {
284 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
285 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
286 if (dimsNotEqual && dimsNotOne)
287 {
288 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
289 }
290 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
291 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100292 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000293 if (broadcastShape != output.GetShape())
294 {
295 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
296 + firstName + " & " + secondName
297 + " does not match the output shape");
298 }
299}
300
301//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100302void ValidateDataTypes(const TensorInfo& info,
303 const std::vector<armnn::DataType>& supportedTypes,
304 std::string const& descName)
305{
306 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
307 if (iterator == supportedTypes.end())
308 {
309 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
310 }
311}
312
James Conroy4d1ff582019-06-10 17:06:39 +0100313//---------------------------------------------------------------
314void ValidateTensorDataTypesMatch(const TensorInfo& first,
315 const TensorInfo& second,
316 std::string const& descName,
317 std::string const& firstName,
318 std::string const& secondName)
319{
320 if (first.GetDataType() != second.GetDataType())
321 {
322 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
323 " must have identical data types.");
324 }
325}
326
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100327//---------------------------------------------------------------
328void ValidateTensorNumElementsMatch(const TensorInfo& first,
329 const TensorInfo& second,
330 std::string const& descName,
331 std::string const& firstName,
332 std::string const& secondName)
333{
334 if (first.GetNumElements() != second.GetNumElements())
335 {
336 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
337 " must have the same number of elements.");
338 }
339}
340
341} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000342
343void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
344 unsigned int numExpectedIn, unsigned int numExpectedOut) const
345{
346 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
347 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
348}
349
350//---------------------------------------------------------------
351void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
352{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100353 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000354
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100355 ValidateNumInputs(workloadInfo, descriptorName, 1);
356 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000357
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100358 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
359 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
360
361 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
362 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000363
364 if (m_Inputs.size() != m_Outputs.size())
365 {
366 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100367 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
368 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000369 }
370
371 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
372 {
373 if (!m_Inputs[i])
374 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100375 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
376 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000377 }
378
379 if (!m_Outputs[i])
380 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100381 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
382 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000383 }
384 }
385}
386
Derek Lambertif674aa02019-08-01 15:56:25 +0100387//---------------------------------------------------------------
388void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
389{
390 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
391 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
392
393 if (workloadInfo.m_InputTensorInfos.size() != 1)
394 {
395 throw InvalidArgumentException(boost::str(
396 boost::format("Number of input infos (%1%) is not 1.")
397 % workloadInfo.m_InputTensorInfos.size()));
398
399 }
400
401 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
402 {
403 throw InvalidArgumentException(boost::str(
404 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
405 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
406 }
407
408 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
409 {
410 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
411 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
412 {
413 throw InvalidArgumentException(boost::str(
414 boost::format("Number of elements for tensor input and output %1% does not match")
415 % i ));
416 }
417 }
418
419 if (m_Inputs.size() != 1)
420 {
421 throw InvalidArgumentException(boost::str(
422 boost::format("Number of inputs (%1%) is not 1.")
423 % m_Inputs.size()));
424 }
425
426 if (m_Inputs.size() != m_Outputs.size())
427 {
428 throw InvalidArgumentException(boost::str(
429 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
430 % m_Inputs.size() % m_Outputs.size()));
431 }
432
433 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
434 {
435 if (!m_Inputs[i])
436 {
437 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
438 }
439
440 if (!m_Outputs[i])
441 {
442 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
443 }
444 }
445}
446
447//---------------------------------------------------------------
448void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
449{
450 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
451 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
452
Derek Lambertif674aa02019-08-01 15:56:25 +0100453 if (m_Inputs.size() != 1)
454 {
455 throw InvalidArgumentException(boost::str(
456 boost::format("Number of inputs (%1%) is not 1.")
457 % m_Inputs.size()));
458 }
459
460 if (m_Outputs.size() != 0)
461 {
462 throw InvalidArgumentException(boost::str(
463 boost::format("Number of outputs (%1%) is not 0.")
464 % m_Inputs.size() % m_Outputs.size()));
465 }
466
467 if (!m_Inputs[0])
468 {
469 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
470 }
471}
472
473//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000474void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
475{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100476 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100478 ValidateNumInputs(workloadInfo, descriptorName, 1);
479 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100480
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100481 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
482 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100483
484 std::vector<DataType> supportedTypes =
485 {
James Conroyd47a0642019-09-17 14:22:06 +0100486 DataType::Float16,
487 DataType::Float32,
488 DataType::QuantisedAsymm8,
489 DataType::QuantisedSymm16
nikraj01248683f2019-05-29 16:46:50 +0100490 };
491
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100492 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
493 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
494 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000495}
496
Nikhil Rajee391d52019-09-05 17:50:44 +0100497void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
498{
499 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
500
501 ValidateNumInputs(workloadInfo, descriptorName, 1);
502 ValidateNumOutputs(workloadInfo, descriptorName, 1);
503
504 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
505 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
506
Nikhil Raj68c2c902019-09-19 11:21:11 +0100507 if (outputTensorInfo.GetDataType() != DataType::Signed32)
508 {
509 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
510 }
511
James Conroyd47a0642019-09-17 14:22:06 +0100512 std::vector<DataType> supportedInputTypes =
513 {
514 DataType::Float16,
515 DataType::Float32,
516 DataType::QuantisedAsymm8,
517 DataType::QuantisedSymm16
518 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100519
James Conroyd47a0642019-09-17 14:22:06 +0100520 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100521
522 auto inputShape = inputTensorInfo.GetShape();
523 auto outputShape = outputTensorInfo.GetShape();
524
525 auto inputNumDimensions = inputShape.GetNumDimensions();
526 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
527
528 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
529
530 // 1D input shape results in scalar output shape
531 if (inputShape.GetNumDimensions() == 1)
532 {
533 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
534 {
535 throw InvalidArgumentException(descriptorName + outputShapeError);
536 }
537 }
538 else
539 {
540 for (unsigned int i = 0; i < unsignedAxis; ++i)
541 {
542 if (outputShape[i] != inputShape[i])
543 {
544 throw InvalidArgumentException(descriptorName + outputShapeError);
545 }
546 }
547
548 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
549 {
550 if (outputShape[i - 1] != inputShape[i])
551 {
552 throw InvalidArgumentException(descriptorName + outputShapeError);
553 }
554 }
555 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100556}
557
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100558void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
559{
560 const std::string descriptorName{"SoftmaxQueueDescriptor"};
561
562 ValidateNumInputs(workloadInfo, descriptorName, 1);
563 ValidateNumOutputs(workloadInfo, descriptorName, 1);
564
565 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
566 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
567
568 std::vector<DataType> supportedTypes =
569 {
James Conroyd47a0642019-09-17 14:22:06 +0100570 DataType::Float16,
571 DataType::Float32,
572 DataType::QuantisedAsymm8,
573 DataType::QuantisedSymm16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100574 };
575
576 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
577 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
578 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
579}
580
telsoa014fcda012018-03-09 14:13:49 +0000581void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
582{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100583 const std::string descriptorName{"SplitterQueueDescriptor"};
584
585 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000586
Ruomei Yan25339c32019-05-28 16:48:20 +0100587 // Check the supported data types
588 std::vector<DataType> supportedTypes =
589 {
James Conroyd47a0642019-09-17 14:22:06 +0100590 DataType::Float32,
591 DataType::Float16,
592 DataType::Boolean,
593 DataType::Signed32,
594 DataType::QuantisedAsymm8,
595 DataType::QuantisedSymm16
Ruomei Yan25339c32019-05-28 16:48:20 +0100596 };
597
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100598 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
599 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100600 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100601 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
602 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
603
604 const std::string outputName = "output_" + std::to_string(i);
605 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100606 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100607
telsoa014fcda012018-03-09 14:13:49 +0000608 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
609 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100610 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000611 }
612
613 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
614 {
615 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100616 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000617 "has to match number of workloadInfo.m_OutputTensorInfos. "
618 "Number of windows: " +
619 to_string(m_ViewOrigins.size()) +
620 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
621 }
622
telsoa01c577f2c2018-08-31 09:22:23 +0100623 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000624 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
625 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
626 {
telsoa01c577f2c2018-08-31 09:22:23 +0100627 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000628 ViewOrigin const& e = m_ViewOrigins[w];
629 if (e.m_Origin.size() != inputDims)
630 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100631 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000632 "have the same dimensionality as the input tensor. "
633 "Window origin (index: " +
634 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
635 " dimensions, the input "
636 "tensor has " +
637 to_string(inputDims) + " dimensions.");
638 }
639 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
640 {
641 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
642 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
643 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100644 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000645 "be smaller or equal than the size of the input in that coord.");
646 }
647 }
648 }
649}
650
Jim Flynne242f2d2019-05-22 14:24:13 +0100651void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000652{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100653 const std::string descriptorName{"ConcatQueueDescriptor"};
654
655 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000656
657 if (m_Inputs.size() <= 0)
658 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100659 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000660 }
661 if (m_Outputs.size() <= 0)
662 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100663 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000664 }
665
666 if (workloadInfo.m_InputTensorInfos.size() <= 0)
667 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100668 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000669 }
670 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
671 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100672 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000673 }
674
Nikhil Raj8599a412018-11-19 14:51:07 +0000675 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
676 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100677 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000678 }
679
680 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
681 {
682 return;
683 }
684
telsoa014fcda012018-03-09 14:13:49 +0000685 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
686 {
687 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100688 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000689 "has to match number of workloadInfo.m_InputTensorInfos. "
690 "Number of windows: " +
691 to_string(m_ViewOrigins.size()) +
692 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
693 }
694
telsoa01c577f2c2018-08-31 09:22:23 +0100695 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000696 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
697 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
698 {
telsoa01c577f2c2018-08-31 09:22:23 +0100699 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000700 ViewOrigin const& e = m_ViewOrigins[w];
701 if (e.m_Origin.size() != outputDims)
702 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100703 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000704 "have the same dimensionality as the output tensor. "
705 "Window origin (index: " +
706 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
707 " dimensions, the output "
708 "tensor has " +
709 to_string(outputDims) + " dimensions.");
710 }
telsoa01c577f2c2018-08-31 09:22:23 +0100711 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000712 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
713 {
714 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
715 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
716 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100717 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000718 "be smaller or equal than the size of the output in that coord.");
719 }
720 }
721 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100722
723 // Check the supported data types
724 std::vector<DataType> supportedTypes =
725 {
James Conroyd47a0642019-09-17 14:22:06 +0100726 DataType::Float32,
727 DataType::Float16,
728 DataType::Boolean,
729 DataType::Signed32,
730 DataType::QuantisedAsymm8,
731 DataType::QuantisedSymm16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100732 };
733
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100734 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
735 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100736 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100737 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
738 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
739
740 const std::string inputName = "input_" + std::to_string(i);
741 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100742 }
telsoa014fcda012018-03-09 14:13:49 +0000743}
744
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100745void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
746{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100747 const std::string descriptorName{"StackQueueDescriptor"};
748
749 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100750
751 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
752 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100753 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100754 }
755
756 // All inputs must have the same shape, which is defined in parameters
757 const TensorShape& inputShape = m_Parameters.m_InputShape;
758 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
759 {
760 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
761 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100762 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100763 }
764 }
765
Matthew Jacksondba634f2019-08-15 15:14:18 +0100766 if (inputShape.GetNumDimensions() > 4)
767 {
768 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
769 }
770
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100771 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
772 // since the output tensor has an additional dimension.
773 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
774 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100775 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100776 "than the number of input dimensions.");
777 }
778
779 // Output shape must be as inferred from the input shape
780 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
781 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
782 {
783 if (outputShape[i] != inputShape[i])
784 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100785 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100786 "match shape inferred from input tensor.");
787 }
788 }
789
790 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
791 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100792 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100793 "match shape inferred from input tensor.");
794 }
795
796 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
797 {
798 if (outputShape[i] != inputShape[i-1])
799 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100800 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100801 "match shape inferred from input tensor.");
802 }
803 }
804
Matthew Jacksondba634f2019-08-15 15:14:18 +0100805 if (outputShape.GetNumDimensions() > 5)
806 {
807 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
808 }
809
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100810 // Check the supported data types
811 std::vector<DataType> supportedTypes =
812 {
James Conroyd47a0642019-09-17 14:22:06 +0100813 DataType::Float32,
814 DataType::Float16,
815 DataType::Boolean,
816 DataType::Signed32,
817 DataType::QuantisedAsymm8,
818 DataType::QuantisedSymm16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100819 };
820
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100821 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100822
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100823 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100824 {
825 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
826 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100827 descriptorName,
828 "input_0",
829 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100830 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100831
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100832 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
833 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100834 descriptorName,
835 "input_0",
836 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100837}
838
telsoa014fcda012018-03-09 14:13:49 +0000839void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
840{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100841 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000842
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100843 ValidateNumInputs(workloadInfo, descriptorName, 1);
844 ValidateNumOutputs(workloadInfo, descriptorName, 1);
845
846 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
847 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
848
849 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
850
851 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000852 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100853 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000854 }
855
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100856 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000857
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100858 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
859 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000860
861 if (m_Parameters.m_BiasEnabled)
862 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100863 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000864
telsoa01c577f2c2018-08-31 09:22:23 +0100865 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100866 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
867 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000868
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100869 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
870 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000871 }
872
Francis Murtagh46c09d02019-05-28 08:15:28 +0100873 // Check the supported data types
874 std::vector<DataType> supportedTypes =
875 {
James Conroyd47a0642019-09-17 14:22:06 +0100876 DataType::Float32,
877 DataType::Float16,
878 DataType::QuantisedAsymm8,
879 DataType::QuantisedSymm16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100880 };
881
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100882 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
883 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000884}
885
telsoa014fcda012018-03-09 14:13:49 +0000886void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
887{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100888 const std::string descriptorName{"NormalizationQueueDescriptor"};
889
890 ValidateNumInputs(workloadInfo, descriptorName, 1);
891 ValidateNumOutputs(workloadInfo, descriptorName, 1);
892
893 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
894 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100895
896 // Check the supported data types
897 std::vector<DataType> supportedTypes =
898 {
899 DataType::Float16,
900 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100901 DataType::QuantisedAsymm8,
902 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100903 };
904
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100905 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100906
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100907 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100908
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100909 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000910}
911
912void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
913{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100914 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000915
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100916 ValidateNumInputs(workloadInfo, descriptorName, 2);
917 ValidateNumOutputs(workloadInfo, descriptorName, 1);
918
919 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
920 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
921 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
922
923 std::vector<DataType> supportedTypes =
924 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100925 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100926 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100927 DataType::QuantisedSymm16,
928 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100929 };
930
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100931 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
932 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
933 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100934
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100935 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
936 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100937
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100938 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
939 inputTensorInfo1,
940 outputTensorInfo,
941 descriptorName,
942 "input_0",
943 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000944}
945
telsoa014fcda012018-03-09 14:13:49 +0000946void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
947{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100948 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +0100949
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950 ValidateNumInputs(workloadInfo, descriptorName, 2);
951 ValidateNumOutputs(workloadInfo, descriptorName, 1);
952
953 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
954 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
955 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
956
957 std::vector<DataType> supportedTypes =
958 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100959 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100960 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100961 DataType::QuantisedSymm16,
962 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100963 };
964
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100965 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
966 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
967 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100968
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
970 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100972 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
973 inputTensorInfo1,
974 outputTensorInfo,
975 descriptorName,
976 "input_0",
977 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000978}
979
980void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
981{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100982 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100983
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100984 ValidateNumInputs(workloadInfo, descriptorName, 1);
985 ValidateNumOutputs(workloadInfo, descriptorName, 1);
986
987 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
988 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100989
990 std::vector<DataType> supportedTypes =
991 {
992 DataType::Float16,
993 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100994 DataType::QuantisedAsymm8,
995 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100996 };
997
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100998 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
999 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001000
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001001 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1002 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1003 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001004
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001005 ValidatePointer(m_Mean, descriptorName, "mean");
1006 ValidatePointer(m_Variance, descriptorName, "variance");
1007 ValidatePointer(m_Beta, descriptorName, "beta");
1008 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001009
Matteo Martincigh3122bd52019-06-03 16:54:25 +01001010 const TensorInfo& mean = m_Mean->GetTensorInfo();
1011 const TensorInfo& variance = m_Variance->GetTensorInfo();
1012 const TensorInfo& beta = m_Beta->GetTensorInfo();
1013 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +00001014
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001015 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1016 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1017 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1018 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001019
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001020 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1021 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1022 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +00001023}
1024
1025void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1026{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001027 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001028
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001029 ValidateNumInputs(workloadInfo, descriptorName, 1);
1030 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001031
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001032 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1033 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001034
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001035 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1036 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001037
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001038 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001039
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001040 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1041 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001042
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001043 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
telsoa014fcda012018-03-09 14:13:49 +00001044
1045 if (m_Parameters.m_BiasEnabled)
1046 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001047 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001048
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001049 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1050 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1051
1052 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1053 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001054 }
1055
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001056 std::vector<DataType> supportedTypes =
1057 {
Ruomei Yan88d44b82019-05-23 14:29:06 +01001058 DataType::Float32,
1059 DataType::QuantisedAsymm8,
1060 DataType::QuantisedSymm16,
1061 DataType::Float16
1062 };
1063
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001064 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1065 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1066}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001067
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001068void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1069{
1070 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1071
1072 ValidateNumInputs(workloadInfo, descriptorName, 1);
1073 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1074
1075 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1076 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1077
1078 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1079 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1080
1081 ValidatePointer(m_Weight, descriptorName, "weight");
1082
1083 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1084 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1085
1086 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1087 {
1088 throw InvalidArgumentException(
1089 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1090 "cannot be smaller than 1.") % descriptorName %
1091 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1092 }
1093
1094 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1095
1096 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1097 // inputChannels * channelMultiplier should be equal to outputChannels.
1098 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1099 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1100 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1101 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1102 {
1103 throw InvalidArgumentException(
1104 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1105 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1106 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1107 numWeightInputChannels % numWeightChannelMultiplier));
1108 }
1109
1110 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
1111
1112 if (m_Parameters.m_BiasEnabled)
1113 {
1114 ValidatePointer(m_Bias, descriptorName, "bias");
1115
1116 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1117 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1118
1119 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1120 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1121 }
1122
1123 std::vector<DataType> supportedTypes =
1124 {
1125 DataType::Float32,
1126 DataType::QuantisedAsymm8,
1127 DataType::QuantisedSymm16,
1128 DataType::Float16
1129 };
1130
1131 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1132 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001133}
1134
1135void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1136{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001137 const std::string descriptorName{"PermuteQueueDescriptor"};
1138
1139 ValidateNumInputs(workloadInfo, descriptorName, 1);
1140 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001141
1142 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1143
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001144 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1145 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001146
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001147 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1148 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001150 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001151 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001152 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001153 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001154 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1155 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1156 "must match dst dimension " + to_string(mapping[i]) +
1157 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001158 }
1159 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001160
1161 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001162}
1163
1164void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1165{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001166 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001167
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001168 ValidateNumInputs(workloadInfo, descriptorName, 1);
1169 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1170
1171 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1172 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1173
1174 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1175 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001176
1177 std::vector<DataType> supportedTypes =
1178 {
1179 DataType::Float32,
1180 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001181 DataType::QuantisedAsymm8,
1182 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001183 };
1184
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001185 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1186 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001187}
1188
1189void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1190{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001191 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001193 ValidateNumInputs(workloadInfo, descriptorName, 1);
1194 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1195
1196 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1197 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1198
1199 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1200 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001201
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001202 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001203 {
1204 DataType::Float16,
1205 DataType::Float32,
1206 DataType::QuantisedAsymm8,
1207 DataType::QuantisedSymm16
1208 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001209
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001210 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1211 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001212
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001213 // ResizeBilinear only changes width and height: batch and channel count must match.
1214 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1215 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001216 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001217 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001218 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001219 boost::str(boost::format("%1%: Input batch size (%2%) "
1220 "does not match output batch size (%3%)") %
1221 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001222 }
1223
Teresa Charlin970f43b2019-07-01 13:51:07 +01001224 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1226 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001227 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001228 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001229 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001230 boost::str(boost::format("%1%: Input channel count (%2%) "
1231 "does not match output channel count (%3%)") %
1232 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001233 }
1234}
1235
1236void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1237{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001238 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001239
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 ValidateNumInputs(workloadInfo, descriptorName, 1);
1241 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1242
1243 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1244 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1245
1246 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1247 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001248
1249 std::vector<DataType> supportedTypes =
1250 {
1251 DataType::Float16,
1252 DataType::Float32,
1253 DataType::QuantisedAsymm8,
1254 DataType::QuantisedSymm16
1255 };
1256
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001257 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1258 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001259
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001260 // Resize only changes width and height: batch and channel count must match.
1261 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1262 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001263 if (inputBatchSize != outputBatchSize)
1264 {
1265 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001266 boost::str(boost::format("%1%: Input batch size (%2%) "
1267 "does not match output batch size (%3%)") %
1268 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001269 }
1270
1271 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001272 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1273 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001274 if (inputChannelCount != outputChannelCount)
1275 {
1276 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001277 boost::str(boost::format("%1%: Input channel count (%2%) "
1278 "does not match output channel count (%3%)") %
1279 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001280 }
1281}
1282
1283void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1284{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001285 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001286
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001287 ValidateNumInputs(workloadInfo, descriptorName, 1);
1288 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1289
1290 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1291 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1292
1293 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1294 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1295
1296 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1297
telsoa014fcda012018-03-09 14:13:49 +00001298 if (m_Parameters.m_Min > m_Parameters.m_Max)
1299 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001300 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001301 }
telsoa014fcda012018-03-09 14:13:49 +00001302}
1303
Kevin Mayce5045a2019-10-02 14:07:47 +01001304void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1305{
1306 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1307
1308 ValidateNumInputs(workloadInfo, descriptorName, 1);
1309 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1310
1311 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1312 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1313
1314 if (inputTensorInfo.GetNumDimensions() > 4)
1315 {
1316 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1317 }
1318
1319 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1320
1321 // Check the supported data types
1322 std::vector<DataType> supportedTypes =
1323 {
1324 DataType::Float32,
1325 DataType::Float16
1326 };
1327
1328 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001329 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001330}
1331
telsoa014fcda012018-03-09 14:13:49 +00001332void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1333{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001334 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001335
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001336 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001337 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1338
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001339 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1340 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1341
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001342 if (inputTensorInfo.GetNumDimensions() > 4)
1343 {
1344 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1345 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001346
1347 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001348
1349 // Check the supported data types
1350 std::vector<DataType> supportedTypes =
1351 {
1352 DataType::Float32,
1353 DataType::Float16,
1354 DataType::QuantisedAsymm8,
1355 DataType::QuantisedSymm16
1356 };
1357
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001358 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001359 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1360}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001361
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001362void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1363{
1364 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1365
1366 ValidateNumInputs(workloadInfo, descriptorName, 1);
1367 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1368
1369 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1370 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1371
1372 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1373
1374 std::vector<DataType> supportedTypes =
1375 {
1376 DataType::Float32,
1377 DataType::Float16,
1378 };
1379
1380 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001381 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001382}
1383
1384void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1385{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001386 const std::string descriptorName{"ConstantQueueDescriptor"};
1387
1388 ValidateNumInputs(workloadInfo, descriptorName, 0);
1389 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001390
1391 if (!m_LayerOutput)
1392 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001393 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001394 }
1395
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001396 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1397 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001398
1399 // Check the supported data types
1400 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001401 {
1402 DataType::Float32,
1403 DataType::Float16,
1404 DataType::Signed32,
1405 DataType::QuantisedAsymm8,
1406 DataType::QuantisedSymm16
1407 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001408
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001409 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001410}
1411
1412void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1413{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001414 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001415
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001416 ValidateNumInputs(workloadInfo, descriptorName, 1);
1417 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1418
1419 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1420 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1421
1422 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001423
1424 // Check the supported data types
1425 std::vector<DataType> supportedTypes =
1426 {
1427 DataType::Float32,
1428 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001429 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001430 DataType::QuantisedAsymm8,
1431 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001432 };
1433
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001434 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1435 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001436}
1437
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001438void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1439{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001440 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001441
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 ValidateNumInputs(workloadInfo, descriptorName, 1);
1443 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1444
1445 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1446 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1447
1448 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1449 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001450
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001451 if (m_Parameters.m_BlockShape.size() != 2)
1452 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001453 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001454 }
1455
1456 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1457 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001458 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1459 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001460 }
1461
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001462 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001463
1464 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001465 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001466
Matthew Bentham8800c002018-11-19 13:19:28 +00001467 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001469 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1470 widthPad.first + widthPad.second;
1471 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1472 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001473
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001474 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1475 inputShape[dimensionIndices.GetChannelsIndex()];
1476 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001478 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001479 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001480 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001481 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001482 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001483 }
1484
1485 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001486 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001487 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1488 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001489 }
nikraj01120522a2019-05-31 11:33:07 +01001490
1491 std::vector<DataType> supportedTypes =
1492 {
1493 DataType::Float16,
1494 DataType::Float32,
1495 DataType::QuantisedAsymm8,
1496 DataType::QuantisedSymm16
1497 };
1498
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001499 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1500 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001501}
1502
Keith Davisa57eccb2019-06-14 17:33:22 +01001503void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1504{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001505 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001506
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001507 ValidateNumInputs(workloadInfo, descriptorName, 1);
1508 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001509
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1511 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1512
1513 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1514 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001515
1516 std::vector<DataType> supportedTypes =
1517 {
1518 DataType::Float32,
1519 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001520 DataType::QuantisedAsymm8,
1521 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001522 };
1523
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001524 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1525 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001526
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001527 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1528
1529 if (m_Parameters.m_BlockSize == 0)
1530 {
1531 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1532 }
1533
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001534 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1535 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1536 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1537 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001538
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001539 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001540 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001541 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001542 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1543 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001544 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001545
1546 const TensorShape& outputShape = outputTensorInfo.GetShape();
1547 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1548 {
1549 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1550 "must be divisible by the square of block size." );
1551 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001552}
1553
telsoa014fcda012018-03-09 14:13:49 +00001554void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1555{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001556 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001557
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001558 ValidateNumInputs(workloadInfo, descriptorName, 1);
1559 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1560
1561 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1562 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001563
1564 std::vector<DataType> supportedTypes =
1565 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001566 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001567 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001568 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001569 };
1570
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001571 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001572
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001573 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001574 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001575 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001576 }
1577}
1578
telsoa01c577f2c2018-08-31 09:22:23 +01001579void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1580{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001581 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1582
1583 const std::string descriptorName{"LstmQueueDescriptor"};
1584
1585 // check dimensions of all inputs and outputs
1586 if (workloadInfo.m_InputTensorInfos.size() != 3)
1587 {
1588 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1589 }
1590 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1591 {
1592 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1593 }
1594
1595 std::vector<DataType> supportedTypes =
1596 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001597 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001598 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001599 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001600 };
1601
Jan Eilers38e05bd2019-06-26 13:10:09 +01001602 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001603 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1604
Jan Eilers38e05bd2019-06-26 13:10:09 +01001605 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001606 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001607 {
1608 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1609 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001610 descriptorName,
1611 "input_0",
1612 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001613 }
1614 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001615 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001616 {
1617 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1618 workloadInfo.m_OutputTensorInfos[i],
1619 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001620 "input_0",
1621 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001622 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001623
Jan Eilers38e05bd2019-06-26 13:10:09 +01001624 // TODO: check clipping parameter is valid
1625
1626 // Inferring batch size, number of outputs and number of cells from the inputs.
1627 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1628 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1629 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1630 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1631 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1632 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1633 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1634
Jan Eilers38e05bd2019-06-26 13:10:09 +01001635 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001636 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1637 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001638 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001639 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1640 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001641 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001642 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1643 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001644 // scratchBufferTensor
1645 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001646 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1647 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001648 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001649 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1650 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001651 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001652 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1653 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001654 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001655 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1656 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001657
1658
1659 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1660 if ( m_InputToInputWeights )
1661 {
1662 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1663 (n_cell * n_input), "InputLayerNormWeights");
1664 }
1665
1666 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1667 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1668 (n_cell * n_input), "InputToForgetWeights");
1669
1670 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1671 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1672 (n_cell * n_input), "InputToCellWeights");
1673
1674 if ( m_RecurrentToInputWeights )
1675 {
1676 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1677 (n_cell * n_output), "RecurrentToInputWeights");
1678 }
1679
1680 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1681 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1682 (n_cell * n_output), "RecurrentToForgetWeights");
1683
1684 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1685 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1686 (n_cell * n_output), "RecurrentToCellWeights");
1687
1688 // Make sure the input-gate's parameters are either both present (regular
1689 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1690 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1691 !m_Parameters.m_CifgEnabled) ||
1692 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1693 m_Parameters.m_CifgEnabled));
1694 if (!cifg_weights_all_or_none)
1695 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001696 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1697 "RecurrentToInputWeights must either both be present (regular LSTM) "
1698 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1699 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001700 }
1701
1702 if ( m_CellToInputWeights )
1703 {
1704 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1705 n_cell, "CellToInputWeights");
1706 }
1707 if ( m_CellToForgetWeights )
1708 {
1709 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1710 n_cell, "CellToForgetWeights");
1711 }
1712 if ( m_CellToOutputWeights )
1713 {
1714 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1715 n_cell, "CellToOutputWeights");
1716 }
1717
1718 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1719 bool peephole_weights_all_or_none =
1720 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1721 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1722 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1723 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1724 if (!peephole_weights_all_or_none)
1725 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001726 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001727 }
1728
1729 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1730 if (m_Parameters.m_CifgEnabled)
1731 {
1732 if (m_InputGateBias)
1733 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001734 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001735 }
1736 }
1737 else
1738 {
1739 if (!m_InputGateBias)
1740 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001741 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1742 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001743 }
1744 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1745 n_cell, "InputGateBias");
1746 }
1747
1748 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1749 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1750
1751 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1752 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1753
1754 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1755 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1756
1757 if (m_ProjectionWeights)
1758 {
1759 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1760 (n_cell * n_output), "ProjectionWeights");
1761 }
1762 if (m_ProjectionBias)
1763 {
1764 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1765 }
1766
1767 // Making sure the projection tensors are consistent:
1768 // 1) If projection weight is not present, then projection bias should not be
1769 // present.
1770 // 2) If projection weight is present, then projection bias is optional.
1771 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1772 !m_Parameters.m_ProjectionEnabled)
1773 || (m_ProjectionWeights && !m_ProjectionBias &&
1774 m_Parameters.m_ProjectionEnabled)
1775 || (m_ProjectionWeights && m_ProjectionBias &&
1776 m_Parameters.m_ProjectionEnabled));
1777 if (!projecton_tensors_consistent)
1778 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001779 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001780 }
1781
1782 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1783 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1784 // either all have values or none of them have values. Layer normalization is used when the values of all the
1785 // layer normalization weights are present
1786 if (m_InputLayerNormWeights)
1787 {
1788 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1789 }
1790 if (m_ForgetLayerNormWeights)
1791 {
1792 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1793 }
1794 if (m_CellLayerNormWeights)
1795 {
1796 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1797 }
1798 if (m_OutputLayerNormWeights)
1799 {
1800 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1801 }
1802
Jan Eilers38e05bd2019-06-26 13:10:09 +01001803 if (m_Parameters.m_LayerNormEnabled)
1804 {
1805 if (!m_Parameters.m_CifgEnabled)
1806 {
1807 if (!m_InputLayerNormWeights)
1808 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001809 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1810 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001811 }
1812 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1813 1, n_cell, "InputLayerNormWeights");
1814 }
1815 else if (m_InputLayerNormWeights)
1816 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001817 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1818 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001819 }
1820
1821 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1822 "ForgetLayerNormWeights");
1823 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1824
1825 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1826 "OutputLayerNormWeights");
1827 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1828
1829 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1830 "CellLayerNormWeights");
1831 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1832 }
1833 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1834 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001835 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1836 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001837 }
telsoa01c577f2c2018-08-31 09:22:23 +01001838}
1839
1840void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1841{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001842 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001843
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001844 ValidateNumInputs(workloadInfo, descriptorName, 1);
1845 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1846
1847 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1848 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1849
1850 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001851 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001852 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001853 }
1854
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001855 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001856 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001857 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001858 }
1859
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001860 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001861}
1862
1863void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1864{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001865 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001866
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001867 ValidateNumInputs(workloadInfo, descriptorName, 1);
1868 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1869
1870 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1871 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1872
1873 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001874 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001875 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001876 }
1877
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001878 if (outputTensorInfo.GetDataType() != DataType::Float32)
1879 {
1880 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1881 }
1882
1883 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001884}
1885
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001886void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1887{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001888 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001889
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001890 ValidateNumInputs(workloadInfo, descriptorName, 2);
1891 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1892
1893 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1894 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1895 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1896
1897 std::vector<DataType> supportedTypes =
1898 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001899 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001900 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001901 DataType::QuantisedSymm16,
1902 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001903 };
1904
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001905 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1906 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1907 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001908
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001909 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1910 inputTensorInfo1,
1911 outputTensorInfo,
1912 descriptorName,
1913 "input_0",
1914 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001915}
1916
David Beckc2044fe2018-09-05 15:00:38 +01001917void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1918{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001919 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01001920
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001921 ValidateNumInputs(workloadInfo, descriptorName, 2);
1922 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1923
1924 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1925 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1926 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1927
1928 std::vector<DataType> supportedTypes =
1929 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001930 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001931 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001932 DataType::QuantisedSymm16,
1933 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001934 };
1935
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001936 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1937 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1938 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001940 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1941 inputTensorInfo1,
1942 outputTensorInfo,
1943 descriptorName,
1944 "input_0",
1945 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01001946}
1947
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001948void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1949{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001950 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001952 ValidateNumInputs(workloadInfo, descriptorName, 2);
1953 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1954
1955 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1956 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1957 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1958
1959 std::vector<DataType> supportedTypes =
1960 {
Mike Kelly1da02362019-08-01 08:43:57 +01001961 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001962 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001963 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001964 DataType::QuantisedAsymm8,
1965 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001966 };
1967
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001968 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1969 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1970 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001971
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001972 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1973 inputTensorInfo1,
1974 outputTensorInfo,
1975 descriptorName,
1976 "input_0",
1977 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001978}
1979
narpra01a6bf9122018-09-10 09:50:09 +01001980void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1981{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001982 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01001983
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001984 ValidateNumInputs(workloadInfo, descriptorName, 1);
1985 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1986
1987 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1988 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01001989
1990 std::vector<DataType> supportedTypes =
1991 {
1992 DataType::Float32,
1993 DataType::Float16,
1994 DataType::QuantisedAsymm8,
1995 DataType::QuantisedSymm16
1996 };
narpra01eb061912018-09-10 17:35:27 +01001997
James Conroy4d1ff582019-06-10 17:06:39 +01001998 // First check if input tensor data type is supported, then
1999 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002000 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2001 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01002002
narpra0132b90462018-09-13 11:07:48 +01002003 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01002004 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002005 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01002006 }
narpra0132b90462018-09-13 11:07:48 +01002007 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01002008 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002009 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01002010 }
2011 else
2012 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002013 unsigned int outputDim =
2014 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2015 ValidateTensorNumDimensions(outputTensorInfo,
2016 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01002017 outputDim > 0 ? outputDim : 1,
2018 "output");
2019 }
narpra01a6bf9122018-09-10 09:50:09 +01002020}
2021
jimfly012c9322a2018-09-19 10:59:49 +01002022void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2023{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002024 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01002025
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002026 ValidateNumInputs(workloadInfo, descriptorName, 1);
2027 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2028
2029 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2030 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01002031
jimfly012c9322a2018-09-19 10:59:49 +01002032 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002033 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2034
jimfly012c9322a2018-09-19 10:59:49 +01002035 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002036 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2037 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2038 "as there are dimensions in the input tensor that is " +
2039 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2040 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002041 }
2042}
2043
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002044void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2045{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002046 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002047
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002048 ValidateNumInputs(workloadInfo, descriptorName, 1);
2049 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002050
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002051 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2052 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2053
Sadik Armagan2208b602019-07-31 16:36:27 +01002054 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002055 {
James Conroyd47a0642019-09-17 14:22:06 +01002056 DataType::Float32,
2057 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002058 };
2059
2060 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002061
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002062 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2063 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002064 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002065 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002066 }
2067}
2068
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002069void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2070{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002071 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002072
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002073 ValidateNumInputs(workloadInfo, descriptorName, 1);
2074 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002075
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002076 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2077 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002078
2079 std::vector<DataType> supportedTypes =
2080 {
James Conroyd47a0642019-09-17 14:22:06 +01002081 DataType::Float32,
2082 DataType::Float16,
2083 DataType::QuantisedAsymm8,
2084 DataType::QuantisedSymm16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002085 };
2086
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002087 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2088 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002089}
2090
Conor Kennedy430b5d82018-11-14 15:28:28 +00002091void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2092{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002093 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002094
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002095 ValidateNumInputs(workloadInfo, descriptorName, 1);
2096 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2097
2098 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2099 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002100
2101 std::vector<DataType> supportedTypes =
2102 {
2103 DataType::Float16,
2104 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01002105 DataType::QuantisedAsymm8,
2106 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002107 };
2108
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002109 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2110 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002111
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002112 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002113
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002114 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002115 if (rank > 4)
2116 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002117 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002118 }
2119
Conor Kennedy430b5d82018-11-14 15:28:28 +00002120 // Begin, End & Stride length must be of rank(input0)
2121 if (m_Parameters.m_Begin.size() != rank)
2122 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002123 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002124 }
2125
2126 if (m_Parameters.m_End.size() != rank)
2127 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002128 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002129 }
2130
2131 if (m_Parameters.m_Stride.size() != rank)
2132 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002133 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002134 }
2135
2136 // Stride entries must be non-zero
2137 for (auto& stride : m_Parameters.m_Stride)
2138 {
2139 if (stride == 0)
2140 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002141 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002142 }
2143 }
2144}
2145
kevmay0190539692018-11-29 08:40:19 +00002146void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2147{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002148 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002149
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002150 ValidateNumInputs(workloadInfo, descriptorName, 2);
2151 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2152
2153 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2154 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2155 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2156
2157 std::vector<DataType> supportedTypes =
2158 {
Mike Kelly1da02362019-08-01 08:43:57 +01002159 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002160 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002161 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002162 DataType::QuantisedAsymm8,
2163 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002164 };
2165
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002166 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2167 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2168 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002169
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002170 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2171 inputTensorInfo1,
2172 outputTensorInfo,
2173 descriptorName,
2174 "input_0",
2175 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002176}
2177
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002178void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2179{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002180 const std::string descriptorName{"DebugQueueDescriptor"};
2181
2182 ValidateNumInputs(workloadInfo, descriptorName, 1);
2183 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002184}
2185
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002186void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2187{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002188 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002189
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002190 ValidateNumInputs(workloadInfo, descriptorName, 2);
2191 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002192
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002193 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2194 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2195 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2196
2197 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2198 inputTensorInfo1,
2199 outputTensorInfo,
2200 descriptorName,
2201 "input_0",
2202 "input_1");
2203
2204 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002205 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002206 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002207 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002208}
2209
FrancisMurtagh878f0232018-12-19 10:56:15 +00002210void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2211{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002212 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002213
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002214 ValidateNumInputs(workloadInfo, descriptorName, 2);
2215 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002216
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002217 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2218 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2219 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2220
2221 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2222 inputTensorInfo1,
2223 outputTensorInfo,
2224 descriptorName,
2225 "input_0",
2226 "input_1");
2227
2228 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002229 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002230 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002231 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002232}
2233
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002234void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2235{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 const std::string descriptorName{"RsqrtQueueDescriptor"};
2237
2238 ValidateNumInputs(workloadInfo, descriptorName, 1);
2239 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2240
2241 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2242 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2243
2244 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002245
2246 std::vector<DataType> supportedTypes =
2247 {
James Conroyd47a0642019-09-17 14:22:06 +01002248 DataType::Float16,
2249 DataType::Float32,
2250 DataType::QuantisedAsymm8,
2251 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002252 };
2253
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002254 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2255 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002256}
2257
narpra01b89b05f2019-01-16 09:53:09 +00002258void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2259{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002261
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002262 ValidateNumInputs(workloadInfo, descriptorName, 2);
2263 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002264
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002265 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2266 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002267 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002268 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002269 }
2270
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002271 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2272 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2273
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002274 std::vector<DataType> supportedTypes =
2275 {
James Conroyd47a0642019-09-17 14:22:06 +01002276 DataType::Float16,
2277 DataType::Float32,
2278 DataType::QuantisedAsymm8,
2279 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002280 };
2281
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002282 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002283
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002284 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002285
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002286 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2287 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002288}
2289
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002290void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2291{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002292 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2293
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002294 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002295
2296 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2297 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002298 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002299 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2300 }
2301
2302 if (m_Anchors == nullptr)
2303 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002304 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002305 }
2306
2307 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002308 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2309 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2310
2311 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002312 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002313 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2314 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002315
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002316 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2317 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2318 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002319
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002320 const std::vector<DataType> supportedInputTypes =
2321 {
2322 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002323 DataType::Float16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002324 DataType::QuantisedAsymm8,
2325 DataType::QuantisedSymm16
2326 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002327
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002328 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2329 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2330 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2331
2332 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2333 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2334 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2335 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2336
2337 // NOTE: Output is always Float32 regardless of input type
2338 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2339 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2340 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2341 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002342
2343 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2344 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002345 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002346 "must be positive and less than or equal to 1.");
2347 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002348
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002349 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2350 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002351 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002352 "should be equal to number of classes + 1.");
2353 }
2354}
2355
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002356void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2357{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002358 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002359
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002360 ValidateNumInputs(workloadInfo, descriptorName, 1);
2361 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2362
2363 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2364 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2365
2366 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2367 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002368 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002369 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002370 }
2371
Sadik Armagan2208b602019-07-31 16:36:27 +01002372 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002373 {
James Conroyd47a0642019-09-17 14:22:06 +01002374 DataType::Float32,
2375 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002376 };
2377
2378 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002379}
2380
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002381void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2382{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002383 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002384
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002385 ValidateNumInputs(workloadInfo, descriptorName, 2);
2386 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002387
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002388 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2389 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2390 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002391
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002392 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2393 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2394
2395 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2396 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002397}
2398
Sadik Armaganeff363d2019-04-05 15:25:46 +01002399void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2400{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002401 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002402
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002403 ValidateNumInputs(workloadInfo, descriptorName, 2);
2404 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2405
2406 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2407 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2408
2409 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2410 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2411
2412 std::vector<DataType> supportedTypes =
2413 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002414 DataType::Float32,
2415 DataType::QuantisedAsymm8,
2416 DataType::QuantisedSymm16
2417 };
2418
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002419 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2420 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002421
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002422 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2423 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002424
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002425 ValidateTensorShapesMatch(inputTensorInfo0,
2426 outputTensorInfo0,
2427 descriptorName,
2428 "input_0",
2429 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002430
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002431 ValidateTensorShapesMatch(inputTensorInfo0,
2432 outputTensorInfo1,
2433 descriptorName,
2434 "input_0",
2435 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002436}
2437
Matteo Martincigh49124022019-01-11 13:25:59 +00002438void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2439{
2440 // This is internally generated so it should not need validation.
2441}
2442
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002443void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2444{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002445 const std::string& descriptorName{"PreluQueueDescriptor"};
2446
2447 ValidateNumInputs(workloadInfo, descriptorName, 2);
2448 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2449
2450 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2451 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2452 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002453
2454 std::vector<DataType> supportedTypes
2455 {
2456 DataType::Float16,
2457 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002458 DataType::QuantisedAsymm8,
2459 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002460 };
2461
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002462 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2463 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002464
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002465 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002466
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2468 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002469
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002470 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2471 alphaTensorInfo,
2472 outputTensorInfo,
2473 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002474 "input",
2475 "alpha");
2476}
2477
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002478void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2479{
2480 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2481
2482 ValidateNumInputs(workloadInfo, descriptorName, 1);
2483 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2484
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002485 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2486 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2487
2488 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2489 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002490
2491 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002492
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002493 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2494 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2495 ValidateTensorDataType(weightTensorInfo, inputTensorInfo.GetDataType(), descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002496
2497 if (m_Parameters.m_BiasEnabled)
2498 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002499 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002500
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002501 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
2502 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002503
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002504 ValidateTensorDataType(biasTensorInfo,
2505 GetBiasDataType(inputTensorInfo.GetDataType()),
2506 descriptorName,
2507 "bias");
2508
2509 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002510 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002511}
2512
James Conroy9c3cae82019-08-01 16:01:48 +01002513void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2514{
2515 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2516
2517 // Validate number of inputs/outputs
2518 ValidateNumInputs(workloadInfo, descriptorName, 3);
2519 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2520
2521 // Input/output tensor infos
2522 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2523 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2524 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2525
2526 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2527 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2528
2529 std::vector<DataType> inputOutputSupportedTypes =
2530 {
2531 DataType::QuantisedAsymm8
2532 };
2533
2534 std::vector<DataType> cellStateSupportedTypes =
2535 {
2536 DataType::QuantisedSymm16
2537 };
2538
2539 std::vector<DataType> weightsSupportedTypes =
2540 {
2541 DataType::QuantisedAsymm8
2542 };
2543
2544 std::vector<DataType> biasSupportedTypes =
2545 {
2546 DataType::Signed32
2547 };
2548
2549 // Validate types of input/output tensors
2550 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2551 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2552 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2553
2554 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2555 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2556
2557 // Validate matching types of input/output tensors
2558 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2559 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2560 "outputStateIn", "outputStateOut");
2561 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2562
2563 // Validate matching quantization info for input/output tensors
2564 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2565 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2566 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002567
James Conroy9c3cae82019-08-01 16:01:48 +01002568 // Infer number of batches, input size and output size from tensor dimensions
2569 const uint32_t numBatches = inputInfo.GetShape()[0];
2570 const uint32_t inputSize = inputInfo.GetShape()[1];
2571 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2572
2573 // Validate number of dimensions and number of elements for input/output tensors
2574 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2575 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2576 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2577 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2578 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2579
2580 // Validate number of dimensions and number of elements for weights tensors
2581 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2582 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2583 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2584
2585 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2586 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2587 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2588
2589 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2590 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2591 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2592
2593 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2594 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2595 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2596
2597 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2598 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2599 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2600
2601 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2602 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2603 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2604 " RecurrentToForgetWeights");
2605
2606 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2607 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2608 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2609
2610 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2611 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2612 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2613
2614 // Validate data types for weights tensors (all should match each other)
2615 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2616
2617 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2618 "inputToInputWeights", "inputToForgetWeights");
2619 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2620 "inputToInputWeights", "inputToCellWeights");
2621 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2622 "inputToInputWeights", "inputToOutputWeights");
2623
2624 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2625 "inputToInputWeights", "recurrentToInputWeights");
2626 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2627 "inputToInputWeights", "recurrentToForgeteights");
2628 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2629 "inputToInputWeights", "recurrentToCellWeights");
2630 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2631 "inputToInputWeights", "recurrentToOutputWeights");
2632
2633 // Validate matching quantization info for weight tensors (all should match each other)
2634 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2635 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2636 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2637 descriptorName, "inputToInputWeights", "inputToCellWeights");
2638 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2639 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2640
2641 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2642 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2643 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2644 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2645 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2646 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2647 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2648 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2649
2650 // Validate number of dimensions and number of elements in bias tensors
2651 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2652 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2653 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2654
2655 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2656 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2657 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2658
2659 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2660 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2661 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2662
2663 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2664 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2665 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2666
2667 // Validate data types for bias tensors (all should match each other)
2668 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2669
2670 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2671 "inputGateBias", "forgetGateBias");
2672 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2673 "inputGateBias", "cellBias");
2674 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2675 "inputGateBias", "outputGateBias");
2676
2677 // Validate bias tensor quantization info
2678 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2679 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2680 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2681 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2682}
2683
Kevin May868eb142019-09-04 17:29:31 +01002684void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2685{
2686 const std::string descriptorName{"AbsQueueDescriptor"};
2687
2688 ValidateNumInputs(workloadInfo, descriptorName, 1);
2689 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2690
2691 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2692 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2693
2694 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2695
2696 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002697 {
2698 DataType::Float16,
2699 DataType::Float32,
2700 DataType::QuantisedAsymm8,
2701 DataType::QuantisedSymm16
2702 };
Kevin May868eb142019-09-04 17:29:31 +01002703
2704 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2705 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2706}
2707
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002708void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2709{
2710 const std::string descriptorName{"SliceQueueDescriptor"};
2711
2712 ValidateNumInputs(workloadInfo, descriptorName, 1);
2713 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2714
2715 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2716 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2717
2718 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2719
2720 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2721 if (rank > 4)
2722 {
2723 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2724 }
2725
2726 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2727
2728 // Check if m_Begin and m_Size have the expected length
2729 if (m_Parameters.m_Begin.size() != rank)
2730 {
2731 throw InvalidArgumentException(descriptorName +
2732 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2733 }
2734 if (m_Parameters.m_Size.size() != rank)
2735 {
2736 throw InvalidArgumentException(descriptorName +
2737 ": Length of size descriptor must equal rank " + std::to_string(rank));
2738 }
2739
2740 // Check if the shape of the output tensor matches m_Size
2741 const TensorShape& outputShape = outputTensorInfo.GetShape();
2742 for (unsigned int i = 0u; i < rank; ++i)
2743 {
2744 if (m_Parameters.m_Size[i] != outputShape[i])
2745 {
2746 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2747 }
2748 }
2749
2750 // Check if the sum of begin offset and size in a given dimension
2751 // does not exceed the size of corresponding input
2752 const TensorShape& inputShape = inputTensorInfo.GetShape();
2753 for(unsigned int i = 0u; i < rank; ++i)
2754 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002755 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002756 {
2757 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2758 std::to_string(i) + " exceeds input size.");
2759 }
2760 }
2761}
2762
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002763void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2764{
2765 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2766
2767 ValidateNumInputs(workloadInfo, descriptorName, 1);
2768 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2769
2770 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2771 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2772
2773 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2774 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2775
2776 std::vector<DataType> supportedTypes =
2777 {
2778 DataType::Float32,
2779 DataType::Float16,
2780 DataType::QuantisedAsymm8,
2781 DataType::QuantisedSymm16
2782 };
2783
2784 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2785 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2786
2787 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2788
2789 if (m_Parameters.m_BlockSize == 0)
2790 {
2791 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2792 }
2793
2794 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2795 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2796 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2797 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2798
2799 const TensorShape& outputShape = outputInfo.GetShape();
2800 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
2801 {
2802 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
2803 "must be divisible by block size.");
2804 }
2805
2806 const TensorShape& inputShape = inputInfo.GetShape();
2807 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
2808 {
2809 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
2810 "must be divisible by the square of block size." );
2811 }
2812}
2813
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01002814void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2815{
2816 const std::string descriptorName{"ComparisonQueueDescriptor"};
2817
2818 ValidateNumInputs(workloadInfo, descriptorName, 2);
2819 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2820
2821 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2822 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2823 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2824
2825 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2826 inputTensorInfo1,
2827 outputTensorInfo,
2828 descriptorName,
2829 "input_0",
2830 "input_1");
2831
2832 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2833 {
2834 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2835 }
2836}
2837
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002838} // namespace armnn