blob: cfb38b48204e9dd5975827ea92572138f2a54c47 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "WorkloadData.hpp"
6
7#include "CpuTensorHandle.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
Matteo Martincigh21350152018-11-28 16:22:22 +00009#include <DataLayoutIndexed.hpp>
Matthew Bentham8800c002018-11-19 13:19:28 +000010
telsoa014fcda012018-03-09 14:13:49 +000011#include <algorithm>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000013#include <string>
14#include <sstream>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <boost/format.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010017#include <boost/numeric/conversion/cast.hpp>
James Conroyc8724c72019-10-08 15:41:34 +010018#include <TensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
Matteo Martincigh21350152018-11-28 16:22:22 +000020using namespace armnnUtils;
21
telsoa014fcda012018-03-09 14:13:49 +000022namespace armnn
23{
24
25//---------------------------------------------------------------
26DataType GetBiasDataType(DataType inputDataType)
27{
28 switch (inputDataType)
29 {
telsoa01c577f2c2018-08-31 09:22:23 +010030 case DataType::Float16:
31 return DataType::Float16;
telsoa014fcda012018-03-09 14:13:49 +000032 case DataType::Float32:
33 return DataType::Float32;
34 case DataType::QuantisedAsymm8:
35 return DataType::Signed32;
Ruomei Yan88d44b82019-05-23 14:29:06 +010036 case DataType::QuantisedSymm16:
37 return DataType::Signed32;
telsoa014fcda012018-03-09 14:13:49 +000038 default:
39 BOOST_ASSERT_MSG(false, "Invalid input data type");
40 return DataType::Float32;
41 }
42}
43
44namespace
45{
46
47//---------------------------------------------------------------
48//android ndk does not support std::to_string function.
49template <typename T>
50std::string to_string(T value)
51{
52 std::ostringstream os;
53 os << value;
54 return os.str();
55}
56
57//---------------------------------------------------------------
58void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
59{
60 if (!ptr)
61 {
62 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
63 paramName + " parameter must be set.");
64 }
65}
66
67//---------------------------------------------------------------
68void ValidateTensorShapesMatch(const TensorInfo& first,
69 const TensorInfo& second,
70 std::string const& descName,
71 std::string const& firstName,
72 std::string const& secondName)
73{
74 if (first.GetShape() != second.GetShape())
75 {
76 throw InvalidArgumentException(descName + ": "
77 + firstName + " & " + secondName + " must have identical shapes");
78 }
79}
80
81//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010082void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000083{
Sadik Armaganeff363d2019-04-05 15:25:46 +010084 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000085 {
86 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010087 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
telsoa014fcda012018-03-09 14:13:49 +000088 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
89 }
90}
91
92//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +010093void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000094{
Sadik Armaganeff363d2019-04-05 15:25:46 +010095 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
telsoa014fcda012018-03-09 14:13:49 +000096 {
97 throw InvalidArgumentException(descName +
Sadik Armaganeff363d2019-04-05 15:25:46 +010098 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
telsoa014fcda012018-03-09 14:13:49 +000099 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
100 }
101}
102
103//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100104void ValidateTensorNumDimensions(const TensorInfo& tensor,
telsoa014fcda012018-03-09 14:13:49 +0000105 std::string const& descName,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100106 unsigned int numDimensions,
telsoa014fcda012018-03-09 14:13:49 +0000107 std::string const& tensorName)
108{
109 if (tensor.GetNumDimensions() != numDimensions)
110 {
111 throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
112 to_string(tensor.GetNumDimensions()) + " dimensions for " +
113 tensorName + " tensor.");
114 }
115}
116
117//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100118void ValidateTensorNumElements(const TensorInfo& tensor,
119 std::string const& descName,
120 unsigned int numElements,
121 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100122{
123 if (tensor.GetNumElements() != numElements)
124 {
125 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
James Conroyceda7852019-08-22 11:41:07 +0100126 to_string(tensor.GetNumElements()) + " elements for " +
Jan Eilers38e05bd2019-06-26 13:10:09 +0100127 tensorName + " tensor.");
128 }
129}
130
131//---------------------------------------------------------------
132void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100133 unsigned int numDimension,
134 unsigned int numElements,
135 std::string const& tensorName)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100136{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100137 const std::string functionName{"ValidateTensorNumDimNumElem"};
138 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
139 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100140}
141
142//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000143void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
144 const std::string& descName, std::string const& tensorName)
145{
146 if (tensor.GetDataType() != dataType)
147 {
148 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
149 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
150 }
151}
152
153//---------------------------------------------------------------
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100154void ValidateTensorQuantizationSpace(const TensorInfo& first,
155 const TensorInfo& second,
156 const std::string& descName,
157 std::string const& firstName,
158 std::string const& secondName)
159{
160 if (!first.IsQuantized() ||
161 !second.IsQuantized())
162 {
163 // Not a quantized type, ignore the validation
164 return;
165 }
166
167 DataType firstDataType = first.GetDataType();
168 DataType secondDataType = second.GetDataType();
169
170 if (firstDataType != secondDataType)
171 {
172 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
173 " must be of the same quantized type, " +
174 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
175 secondName + " is " + GetDataTypeName(secondDataType));
176 }
177
178 if (!first.IsTypeSpaceMatch(second))
179 {
180 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
181 " must have the same quantization space, " +
182 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
183 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
184 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
185 " and scale " + to_string(second.GetQuantizationScale()));
186 }
187}
188
189//---------------------------------------------------------------
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100190void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
191 const TensorInfo& inputTensorInfo,
192 const TensorInfo& weightsTensorInfo,
193 const std::string& descName)
telsoa014fcda012018-03-09 14:13:49 +0000194{
195 if (biasTensor.GetQuantizationOffset() != 0)
196 {
197 throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
198 to_string(biasTensor.GetQuantizationOffset()));
199 }
200 const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
kevmay016c46dd32018-12-17 15:32:45 +0000201 if (std::abs(biasTensor.GetQuantizationScale() - expectedScale) > 0.00000001f)
telsoa014fcda012018-03-09 14:13:49 +0000202 {
203 // Print the float values with extra precision to see very small differences
204 std::stringstream msg;
205 msg << std::setprecision(10) << descName << ": Expected " << expectedScale <<
206 " quantization scale for bias tensor (the product of the input and weight scales), but got " <<
207 biasTensor.GetQuantizationScale();
208 throw InvalidArgumentException(msg.str());
209 }
210}
211
212//---------------------------------------------------------------
213void ValidateTensors(const std::vector<ITensorHandle*>& vec,
214 unsigned int numExpected,
215 const std::string& descName,
216 const std::string& varName)
217{
218 if (vec.empty() && numExpected > 0)
219 {
220 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
221 }
222
223 for (unsigned int i = 0; i < numExpected; ++i)
224 {
225 if (!vec[i])
226 {
227 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
228 }
229 }
230}
231
232//---------------------------------------------------------------
233void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
234 const TensorInfo& second,
235 const TensorInfo& output,
236 std::string const& descName,
237 std::string const& firstName,
238 std::string const& secondName)
239{
240 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
241 // broadcasted.
242 if (first.GetNumDimensions() != second.GetNumDimensions())
243 {
244 throw InvalidArgumentException(descName + ": Tensors "
245 + firstName + " & " + secondName
246 + " must have the same number of dimensions in order to be broadcasted");
247 }
248 uint32_t numDims = first.GetNumDimensions();
249 std::vector<uint32_t> outputDims(numDims, 0u);
250 for (uint32_t i = 0; i < numDims; i++)
251 {
252 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
253 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
254 if (dimsNotEqual && dimsNotOne)
255 {
256 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
257 }
258 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
259 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100260 TensorShape broadcastShape = TensorShape(boost::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
telsoa014fcda012018-03-09 14:13:49 +0000261 if (broadcastShape != output.GetShape())
262 {
263 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
264 + firstName + " & " + secondName
265 + " does not match the output shape");
266 }
267}
268
269//---------------------------------------------------------------
Sadik Armaganeff363d2019-04-05 15:25:46 +0100270void ValidateDataTypes(const TensorInfo& info,
271 const std::vector<armnn::DataType>& supportedTypes,
272 std::string const& descName)
273{
274 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
275 if (iterator == supportedTypes.end())
276 {
277 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
278 }
279}
280
James Conroy4d1ff582019-06-10 17:06:39 +0100281//---------------------------------------------------------------
282void ValidateTensorDataTypesMatch(const TensorInfo& first,
283 const TensorInfo& second,
284 std::string const& descName,
285 std::string const& firstName,
286 std::string const& secondName)
287{
288 if (first.GetDataType() != second.GetDataType())
289 {
290 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
291 " must have identical data types.");
292 }
293}
294
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100295//---------------------------------------------------------------
296void ValidateTensorNumElementsMatch(const TensorInfo& first,
297 const TensorInfo& second,
298 std::string const& descName,
299 std::string const& firstName,
300 std::string const& secondName)
301{
302 if (first.GetNumElements() != second.GetNumElements())
303 {
304 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
305 " must have the same number of elements.");
306 }
307}
308
309} // anonymous namespace
telsoa014fcda012018-03-09 14:13:49 +0000310
311void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
312 unsigned int numExpectedIn, unsigned int numExpectedOut) const
313{
314 ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
315 ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
316}
317
318//---------------------------------------------------------------
319void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
320{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100321 const std::string descriptorName{"MemCopyQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000322
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100323 ValidateNumInputs(workloadInfo, descriptorName, 1);
324 ValidateNumOutputs(workloadInfo, descriptorName , 1);
telsoa014fcda012018-03-09 14:13:49 +0000325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100326 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
327 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
328
329 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
330 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000331
332 if (m_Inputs.size() != m_Outputs.size())
333 {
334 throw InvalidArgumentException(boost::str(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100335 boost::format("%1%: Number of inputs (%2%) does not match the number of outputs (%3%).") %
336 descriptorName % m_Inputs.size() % m_Outputs.size()));
telsoa014fcda012018-03-09 14:13:49 +0000337 }
338
339 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
340 {
341 if (!m_Inputs[i])
342 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100343 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL input %2%.") %
344 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000345 }
346
347 if (!m_Outputs[i])
348 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100349 throw InvalidArgumentException(boost::str(boost::format("%1%: Invalid NULL output %2%") %
350 descriptorName % i));
telsoa014fcda012018-03-09 14:13:49 +0000351 }
352 }
353}
354
Derek Lambertif674aa02019-08-01 15:56:25 +0100355//---------------------------------------------------------------
356void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
357{
358 ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
359 ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
360
361 if (workloadInfo.m_InputTensorInfos.size() != 1)
362 {
363 throw InvalidArgumentException(boost::str(
364 boost::format("Number of input infos (%1%) is not 1.")
365 % workloadInfo.m_InputTensorInfos.size()));
366
367 }
368
369 if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
370 {
371 throw InvalidArgumentException(boost::str(
372 boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
373 % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
374 }
375
376 for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
377 {
378 if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
379 workloadInfo.m_OutputTensorInfos[i].GetNumElements())
380 {
381 throw InvalidArgumentException(boost::str(
382 boost::format("Number of elements for tensor input and output %1% does not match")
383 % i ));
384 }
385 }
386
387 if (m_Inputs.size() != 1)
388 {
389 throw InvalidArgumentException(boost::str(
390 boost::format("Number of inputs (%1%) is not 1.")
391 % m_Inputs.size()));
392 }
393
394 if (m_Inputs.size() != m_Outputs.size())
395 {
396 throw InvalidArgumentException(boost::str(
397 boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
398 % m_Inputs.size() % m_Outputs.size()));
399 }
400
401 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
402 {
403 if (!m_Inputs[i])
404 {
405 throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
406 }
407
408 if (!m_Outputs[i])
409 {
410 throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
411 }
412 }
413}
414
415//---------------------------------------------------------------
416void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
417{
418 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
419 ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
420
Derek Lambertif674aa02019-08-01 15:56:25 +0100421 if (m_Inputs.size() != 1)
422 {
423 throw InvalidArgumentException(boost::str(
424 boost::format("Number of inputs (%1%) is not 1.")
425 % m_Inputs.size()));
426 }
427
428 if (m_Outputs.size() != 0)
429 {
430 throw InvalidArgumentException(boost::str(
431 boost::format("Number of outputs (%1%) is not 0.")
432 % m_Inputs.size() % m_Outputs.size()));
433 }
434
435 if (!m_Inputs[0])
436 {
437 throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
438 }
439}
440
441//---------------------------------------------------------------
telsoa014fcda012018-03-09 14:13:49 +0000442void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
443{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100444 const std::string descriptorName{"ActivationQueueDescriptor"};
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100445
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100446 ValidateNumInputs(workloadInfo, descriptorName, 1);
447 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +0100448
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100449 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
450 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
nikraj01248683f2019-05-29 16:46:50 +0100451
452 std::vector<DataType> supportedTypes =
453 {
James Conroyd47a0642019-09-17 14:22:06 +0100454 DataType::Float16,
455 DataType::Float32,
456 DataType::QuantisedAsymm8,
457 DataType::QuantisedSymm16
nikraj01248683f2019-05-29 16:46:50 +0100458 };
459
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100460 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
461 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
462 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000463}
464
Nikhil Rajee391d52019-09-05 17:50:44 +0100465void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
466{
467 const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
468
469 ValidateNumInputs(workloadInfo, descriptorName, 1);
470 ValidateNumOutputs(workloadInfo, descriptorName, 1);
471
472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
474
Nikhil Raj68c2c902019-09-19 11:21:11 +0100475 if (outputTensorInfo.GetDataType() != DataType::Signed32)
476 {
477 throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
478 }
479
James Conroyd47a0642019-09-17 14:22:06 +0100480 std::vector<DataType> supportedInputTypes =
481 {
482 DataType::Float16,
483 DataType::Float32,
484 DataType::QuantisedAsymm8,
485 DataType::QuantisedSymm16
486 };
Nikhil Rajee391d52019-09-05 17:50:44 +0100487
James Conroyd47a0642019-09-17 14:22:06 +0100488 ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
James Conroyc8724c72019-10-08 15:41:34 +0100489
490 auto inputShape = inputTensorInfo.GetShape();
491 auto outputShape = outputTensorInfo.GetShape();
492
493 auto inputNumDimensions = inputShape.GetNumDimensions();
494 auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
495
496 const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
497
498 // 1D input shape results in scalar output shape
499 if (inputShape.GetNumDimensions() == 1)
500 {
501 if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
502 {
503 throw InvalidArgumentException(descriptorName + outputShapeError);
504 }
505 }
506 else
507 {
508 for (unsigned int i = 0; i < unsignedAxis; ++i)
509 {
510 if (outputShape[i] != inputShape[i])
511 {
512 throw InvalidArgumentException(descriptorName + outputShapeError);
513 }
514 }
515
516 for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
517 {
518 if (outputShape[i - 1] != inputShape[i])
519 {
520 throw InvalidArgumentException(descriptorName + outputShapeError);
521 }
522 }
523 }
Nikhil Rajee391d52019-09-05 17:50:44 +0100524}
525
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100526void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
527{
528 const std::string descriptorName{"SoftmaxQueueDescriptor"};
529
530 ValidateNumInputs(workloadInfo, descriptorName, 1);
531 ValidateNumOutputs(workloadInfo, descriptorName, 1);
532
533 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
534 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
535
536 std::vector<DataType> supportedTypes =
537 {
James Conroyd47a0642019-09-17 14:22:06 +0100538 DataType::Float16,
539 DataType::Float32,
540 DataType::QuantisedAsymm8,
541 DataType::QuantisedSymm16
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100542 };
543
544 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
545 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
546 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
547}
548
telsoa014fcda012018-03-09 14:13:49 +0000549void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
550{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100551 const std::string descriptorName{"SplitterQueueDescriptor"};
552
553 ValidateNumInputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000554
Ruomei Yan25339c32019-05-28 16:48:20 +0100555 // Check the supported data types
556 std::vector<DataType> supportedTypes =
557 {
James Conroyd47a0642019-09-17 14:22:06 +0100558 DataType::Float32,
559 DataType::Float16,
560 DataType::Boolean,
561 DataType::Signed32,
562 DataType::QuantisedAsymm8,
563 DataType::QuantisedSymm16
Ruomei Yan25339c32019-05-28 16:48:20 +0100564 };
565
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100566 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
567 for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Ruomei Yan25339c32019-05-28 16:48:20 +0100568 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100569 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
570 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
571
572 const std::string outputName = "output_" + std::to_string(i);
573 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
Ruomei Yan25339c32019-05-28 16:48:20 +0100574 }
Ruomei Yan25339c32019-05-28 16:48:20 +0100575
telsoa014fcda012018-03-09 14:13:49 +0000576 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
577 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100578 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000579 }
580
581 if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
582 {
583 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100584 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000585 "has to match number of workloadInfo.m_OutputTensorInfos. "
586 "Number of windows: " +
587 to_string(m_ViewOrigins.size()) +
588 ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
589 }
590
telsoa01c577f2c2018-08-31 09:22:23 +0100591 //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
telsoa014fcda012018-03-09 14:13:49 +0000592 std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
593 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
594 {
telsoa01c577f2c2018-08-31 09:22:23 +0100595 //Checks that the dimensionality of input is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000596 ViewOrigin const& e = m_ViewOrigins[w];
597 if (e.m_Origin.size() != inputDims)
598 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100599 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000600 "have the same dimensionality as the input tensor. "
601 "Window origin (index: " +
602 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
603 " dimensions, the input "
604 "tensor has " +
605 to_string(inputDims) + " dimensions.");
606 }
607 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
608 {
609 if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
610 workloadInfo.m_InputTensorInfos[0].GetShape()[i])
611 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100612 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000613 "be smaller or equal than the size of the input in that coord.");
614 }
615 }
616 }
617}
618
Jim Flynne242f2d2019-05-22 14:24:13 +0100619void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
telsoa014fcda012018-03-09 14:13:49 +0000620{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100621 const std::string descriptorName{"ConcatQueueDescriptor"};
622
623 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000624
625 if (m_Inputs.size() <= 0)
626 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100627 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000628 }
629 if (m_Outputs.size() <= 0)
630 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100631 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000632 }
633
634 if (workloadInfo.m_InputTensorInfos.size() <= 0)
635 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100636 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000637 }
638 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
639 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100640 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
telsoa014fcda012018-03-09 14:13:49 +0000641 }
642
Nikhil Raj8599a412018-11-19 14:51:07 +0000643 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
644 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100645 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
Nikhil Raj8599a412018-11-19 14:51:07 +0000646 }
647
648 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
649 {
650 return;
651 }
652
telsoa014fcda012018-03-09 14:13:49 +0000653 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
654 {
655 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100656 descriptorName + ": Number of split windows "
telsoa014fcda012018-03-09 14:13:49 +0000657 "has to match number of workloadInfo.m_InputTensorInfos. "
658 "Number of windows: " +
659 to_string(m_ViewOrigins.size()) +
660 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
661 }
662
telsoa01c577f2c2018-08-31 09:22:23 +0100663 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
telsoa014fcda012018-03-09 14:13:49 +0000664 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
665 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
666 {
telsoa01c577f2c2018-08-31 09:22:23 +0100667 //Checks that the dimensionality of output is same as the split windows.
telsoa014fcda012018-03-09 14:13:49 +0000668 ViewOrigin const& e = m_ViewOrigins[w];
669 if (e.m_Origin.size() != outputDims)
670 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100671 throw InvalidArgumentException(descriptorName + ": Window origin have to "
telsoa014fcda012018-03-09 14:13:49 +0000672 "have the same dimensionality as the output tensor. "
673 "Window origin (index: " +
674 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
675 " dimensions, the output "
676 "tensor has " +
677 to_string(outputDims) + " dimensions.");
678 }
telsoa01c577f2c2018-08-31 09:22:23 +0100679 //Checks that the merge windows are within the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000680 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
681 {
682 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
683 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
684 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100685 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
telsoa014fcda012018-03-09 14:13:49 +0000686 "be smaller or equal than the size of the output in that coord.");
687 }
688 }
689 }
Jim Flynncbb66aa2019-05-15 13:03:54 +0100690
691 // Check the supported data types
692 std::vector<DataType> supportedTypes =
693 {
James Conroyd47a0642019-09-17 14:22:06 +0100694 DataType::Float32,
695 DataType::Float16,
696 DataType::Boolean,
697 DataType::Signed32,
698 DataType::QuantisedAsymm8,
699 DataType::QuantisedSymm16
Jim Flynncbb66aa2019-05-15 13:03:54 +0100700 };
701
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100702 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
703 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jim Flynncbb66aa2019-05-15 13:03:54 +0100704 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100705 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
706 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
707
708 const std::string inputName = "input_" + std::to_string(i);
709 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
Jim Flynncbb66aa2019-05-15 13:03:54 +0100710 }
telsoa014fcda012018-03-09 14:13:49 +0000711}
712
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100713void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
714{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100715 const std::string descriptorName{"StackQueueDescriptor"};
716
717 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100718
719 if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
720 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100721 throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100722 }
723
724 // All inputs must have the same shape, which is defined in parameters
725 const TensorShape& inputShape = m_Parameters.m_InputShape;
726 for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
727 {
728 if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
729 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100730 throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100731 }
732 }
733
Matthew Jacksondba634f2019-08-15 15:14:18 +0100734 if (inputShape.GetNumDimensions() > 4)
735 {
736 throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
737 }
738
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100739 // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
740 // since the output tensor has an additional dimension.
741 if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
742 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100743 throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100744 "than the number of input dimensions.");
745 }
746
747 // Output shape must be as inferred from the input shape
748 const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
749 for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
750 {
751 if (outputShape[i] != inputShape[i])
752 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100753 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100754 "match shape inferred from input tensor.");
755 }
756 }
757
758 if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
759 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100760 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100761 "match shape inferred from input tensor.");
762 }
763
764 for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
765 {
766 if (outputShape[i] != inputShape[i-1])
767 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100768 throw InvalidArgumentException(descriptorName + ": Output tensor must "
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100769 "match shape inferred from input tensor.");
770 }
771 }
772
Matthew Jacksondba634f2019-08-15 15:14:18 +0100773 if (outputShape.GetNumDimensions() > 5)
774 {
775 throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
776 }
777
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100778 // Check the supported data types
779 std::vector<DataType> supportedTypes =
780 {
James Conroyd47a0642019-09-17 14:22:06 +0100781 DataType::Float32,
782 DataType::Float16,
783 DataType::Boolean,
784 DataType::Signed32,
785 DataType::QuantisedAsymm8,
786 DataType::QuantisedSymm16
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100787 };
788
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100789 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100790
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100791 for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100792 {
793 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
794 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100795 descriptorName,
796 "input_0",
797 "input_" + std::to_string(i));
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100798 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100799
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100800 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
801 workloadInfo.m_OutputTensorInfos[0],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100802 descriptorName,
803 "input_0",
804 "output");
Matthew Jackson2b8c1da2019-07-04 14:59:16 +0100805}
806
telsoa014fcda012018-03-09 14:13:49 +0000807void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
808{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100809 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000810
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100811 ValidateNumInputs(workloadInfo, descriptorName, 1);
812 ValidateNumOutputs(workloadInfo, descriptorName, 1);
813
814 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
815 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
816
817 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
818
819 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
telsoa014fcda012018-03-09 14:13:49 +0000820 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100821 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
telsoa014fcda012018-03-09 14:13:49 +0000822 }
823
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100824 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000825
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100826 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
827 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
telsoa014fcda012018-03-09 14:13:49 +0000828
829 if (m_Parameters.m_BiasEnabled)
830 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100831 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000832
telsoa01c577f2c2018-08-31 09:22:23 +0100833 // Validates type and quantization values.
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100834 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
835 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +0000836
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100837 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
838 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
telsoa014fcda012018-03-09 14:13:49 +0000839 }
840
Francis Murtagh46c09d02019-05-28 08:15:28 +0100841 // Check the supported data types
842 std::vector<DataType> supportedTypes =
843 {
James Conroyd47a0642019-09-17 14:22:06 +0100844 DataType::Float32,
845 DataType::Float16,
846 DataType::QuantisedAsymm8,
847 DataType::QuantisedSymm16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100848 };
849
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100850 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
851 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000852}
853
telsoa014fcda012018-03-09 14:13:49 +0000854void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
855{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100856 const std::string descriptorName{"NormalizationQueueDescriptor"};
857
858 ValidateNumInputs(workloadInfo, descriptorName, 1);
859 ValidateNumOutputs(workloadInfo, descriptorName, 1);
860
861 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
862 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100863
864 // Check the supported data types
865 std::vector<DataType> supportedTypes =
866 {
867 DataType::Float16,
868 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100869 DataType::QuantisedAsymm8,
870 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100871 };
872
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100873 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100874
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100875 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100876
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100877 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +0000878}
879
880void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
881{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100882 const std::string descriptorName{"AdditionQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000883
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100884 ValidateNumInputs(workloadInfo, descriptorName, 2);
885 ValidateNumOutputs(workloadInfo, descriptorName, 1);
886
887 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
888 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
889 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
890
891 std::vector<DataType> supportedTypes =
892 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100893 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100894 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100895 DataType::QuantisedSymm16,
896 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100897 };
898
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100899 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
900 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
901 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100902
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100903 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
904 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100905
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100906 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
907 inputTensorInfo1,
908 outputTensorInfo,
909 descriptorName,
910 "input_0",
911 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000912}
913
telsoa014fcda012018-03-09 14:13:49 +0000914void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
915{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100916 const std::string descriptorName{"MultiplicationQueueDescriptor"};
surmeh01bceff2f2018-03-29 16:29:27 +0100917
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100918 ValidateNumInputs(workloadInfo, descriptorName, 2);
919 ValidateNumOutputs(workloadInfo, descriptorName, 1);
920
921 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
922 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
923 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
924
925 std::vector<DataType> supportedTypes =
926 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100927 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100928 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +0100929 DataType::QuantisedSymm16,
930 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100931 };
932
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100933 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
934 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
935 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100936
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100937 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
938 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +0100939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100940 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
941 inputTensorInfo1,
942 outputTensorInfo,
943 descriptorName,
944 "input_0",
945 "input_1");
telsoa014fcda012018-03-09 14:13:49 +0000946}
947
948void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
949{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100950 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100952 ValidateNumInputs(workloadInfo, descriptorName, 1);
953 ValidateNumOutputs(workloadInfo, descriptorName, 1);
954
955 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
956 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100957
958 std::vector<DataType> supportedTypes =
959 {
960 DataType::Float16,
961 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100962 DataType::QuantisedAsymm8,
963 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100964 };
965
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100966 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
967 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100968
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100969 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
970 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
971 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100972
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100973 ValidatePointer(m_Mean, descriptorName, "mean");
974 ValidatePointer(m_Variance, descriptorName, "variance");
975 ValidatePointer(m_Beta, descriptorName, "beta");
976 ValidatePointer(m_Gamma, descriptorName, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000977
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100978 const TensorInfo& mean = m_Mean->GetTensorInfo();
979 const TensorInfo& variance = m_Variance->GetTensorInfo();
980 const TensorInfo& beta = m_Beta->GetTensorInfo();
981 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
telsoa014fcda012018-03-09 14:13:49 +0000982
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100983 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
984 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
985 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
986 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000987
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100988 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
989 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
990 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
telsoa014fcda012018-03-09 14:13:49 +0000991}
992
993void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
994{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100995 const std::string descriptorName{"Convolution2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +0000996
Aron Virginas-Tar84062b72019-07-19 11:37:10 +0100997 ValidateNumInputs(workloadInfo, descriptorName, 1);
998 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +0000999
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001000 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1001 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001002
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001003 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1004 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001005
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001006 ValidatePointer(m_Weight, descriptorName, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001007
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001008 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1009 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
telsoa014fcda012018-03-09 14:13:49 +00001010
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001011 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
telsoa014fcda012018-03-09 14:13:49 +00001012
1013 if (m_Parameters.m_BiasEnabled)
1014 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001015 ValidatePointer(m_Bias, descriptorName, "bias");
telsoa014fcda012018-03-09 14:13:49 +00001016
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001017 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1018 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1019
1020 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1021 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001022 }
1023
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001024 std::vector<DataType> supportedTypes =
1025 {
Ruomei Yan88d44b82019-05-23 14:29:06 +01001026 DataType::Float32,
1027 DataType::QuantisedAsymm8,
1028 DataType::QuantisedSymm16,
1029 DataType::Float16
1030 };
1031
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001032 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1033 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1034}
Ruomei Yan88d44b82019-05-23 14:29:06 +01001035
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001036void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1037{
1038 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1039
1040 ValidateNumInputs(workloadInfo, descriptorName, 1);
1041 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1042
1043 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1044 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1045
1046 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1047 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1048
1049 ValidatePointer(m_Weight, descriptorName, "weight");
1050
1051 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
1052 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1053
1054 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1055 {
1056 throw InvalidArgumentException(
1057 boost::str(boost::format("%1%: dilationX (provided %2%) and dilationY (provided %3%) "
1058 "cannot be smaller than 1.") % descriptorName %
1059 m_Parameters.m_DilationX % m_Parameters.m_DilationX));
1060 }
1061
1062 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1063
1064 // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
1065 // inputChannels * channelMultiplier should be equal to outputChannels.
1066 const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
1067 const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
1068 const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1069 if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
1070 {
1071 throw InvalidArgumentException(
1072 boost::str(boost::format("%1%: output_channels (provided %2%) should be "
1073 "equal to input_channels (provided %3%) multiplied by channel_multiplier "
1074 "(provided %4%).") % descriptorName % numWeightOutputChannels %
1075 numWeightInputChannels % numWeightChannelMultiplier));
1076 }
1077
1078 ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
1079
1080 if (m_Parameters.m_BiasEnabled)
1081 {
1082 ValidatePointer(m_Bias, descriptorName, "bias");
1083
1084 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
1085 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1086
1087 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1088 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1089 }
1090
1091 std::vector<DataType> supportedTypes =
1092 {
1093 DataType::Float32,
1094 DataType::QuantisedAsymm8,
1095 DataType::QuantisedSymm16,
1096 DataType::Float16
1097 };
1098
1099 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1100 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001101}
1102
1103void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1104{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001105 const std::string descriptorName{"PermuteQueueDescriptor"};
1106
1107 ValidateNumInputs(workloadInfo, descriptorName, 1);
1108 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001109
1110 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1111
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001112 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1113 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
telsoa014fcda012018-03-09 14:13:49 +00001114
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001115 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1116 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
telsoa014fcda012018-03-09 14:13:49 +00001117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001118 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
telsoa014fcda012018-03-09 14:13:49 +00001119 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001120 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
telsoa014fcda012018-03-09 14:13:49 +00001121 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001122 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1123 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1124 "must match dst dimension " + to_string(mapping[i]) +
1125 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
telsoa014fcda012018-03-09 14:13:49 +00001126 }
1127 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001128
1129 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001130}
1131
1132void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1133{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001134 const std::string descriptorName{"Pooling2dQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001135
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001136 ValidateNumInputs(workloadInfo, descriptorName, 1);
1137 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1138
1139 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1140 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1141
1142 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1143 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlina3b20472019-06-06 11:12:32 +01001144
1145 std::vector<DataType> supportedTypes =
1146 {
1147 DataType::Float32,
1148 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001149 DataType::QuantisedAsymm8,
1150 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001151 };
1152
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001153 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1154 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001155}
1156
1157void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1158{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001159 const std::string descriptorName{"ResizeBilinearQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001160
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001161 ValidateNumInputs(workloadInfo, descriptorName, 1);
1162 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1163
1164 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1165 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1166
1167 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1168 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
telsoa014fcda012018-03-09 14:13:49 +00001169
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001170 std::vector<DataType> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001171 {
1172 DataType::Float16,
1173 DataType::Float32,
1174 DataType::QuantisedAsymm8,
1175 DataType::QuantisedSymm16
1176 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001177
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001178 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1179 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001180
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001181 // ResizeBilinear only changes width and height: batch and channel count must match.
1182 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1183 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001184 if (inputBatchSize != outputBatchSize)
telsoa014fcda012018-03-09 14:13:49 +00001185 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001186 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001187 boost::str(boost::format("%1%: Input batch size (%2%) "
1188 "does not match output batch size (%3%)") %
1189 descriptorName % inputBatchSize % outputBatchSize));
telsoa014fcda012018-03-09 14:13:49 +00001190 }
1191
Teresa Charlin970f43b2019-07-01 13:51:07 +01001192 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001193 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1194 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001195 if (inputChannelCount != outputChannelCount)
telsoa014fcda012018-03-09 14:13:49 +00001196 {
Teresa Charlin970f43b2019-07-01 13:51:07 +01001197 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001198 boost::str(boost::format("%1%: Input channel count (%2%) "
1199 "does not match output channel count (%3%)") %
1200 descriptorName % inputChannelCount % outputChannelCount));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001201 }
1202}
1203
1204void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1205{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001206 const std::string descriptorName{"ResizeQueueDescriptor"};
Teresa Charlin970f43b2019-07-01 13:51:07 +01001207
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001208 ValidateNumInputs(workloadInfo, descriptorName, 1);
1209 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1210
1211 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1212 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1213
1214 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1215 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001216
1217 std::vector<DataType> supportedTypes =
1218 {
1219 DataType::Float16,
1220 DataType::Float32,
1221 DataType::QuantisedAsymm8,
1222 DataType::QuantisedSymm16
1223 };
1224
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001225 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1226 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Teresa Charlin970f43b2019-07-01 13:51:07 +01001227
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001228 // Resize only changes width and height: batch and channel count must match.
1229 const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
1230 const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001231 if (inputBatchSize != outputBatchSize)
1232 {
1233 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001234 boost::str(boost::format("%1%: Input batch size (%2%) "
1235 "does not match output batch size (%3%)") %
1236 descriptorName % inputBatchSize % outputBatchSize));
Teresa Charlin970f43b2019-07-01 13:51:07 +01001237 }
1238
1239 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001240 const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1241 const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
Teresa Charlin970f43b2019-07-01 13:51:07 +01001242 if (inputChannelCount != outputChannelCount)
1243 {
1244 throw InvalidArgumentException(
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001245 boost::str(boost::format("%1%: Input channel count (%2%) "
1246 "does not match output channel count (%3%)") %
1247 descriptorName % inputChannelCount % outputChannelCount));
telsoa014fcda012018-03-09 14:13:49 +00001248 }
1249}
1250
1251void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1252{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001253 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001254
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001255 ValidateNumInputs(workloadInfo, descriptorName, 1);
1256 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1257
1258 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1259 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1260
1261 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1262 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1263
1264 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1265
telsoa014fcda012018-03-09 14:13:49 +00001266 if (m_Parameters.m_Min > m_Parameters.m_Max)
1267 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001268 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
telsoa014fcda012018-03-09 14:13:49 +00001269 }
telsoa014fcda012018-03-09 14:13:49 +00001270}
1271
Kevin Mayce5045a2019-10-02 14:07:47 +01001272void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1273{
1274 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1275
1276 ValidateNumInputs(workloadInfo, descriptorName, 1);
1277 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1278
1279 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1280 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1281
1282 if (inputTensorInfo.GetNumDimensions() > 4)
1283 {
1284 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1285 }
1286
1287 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1288
1289 // Check the supported data types
1290 std::vector<DataType> supportedTypes =
1291 {
1292 DataType::Float32,
1293 DataType::Float16
1294 };
1295
1296 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Kevin Mayce5045a2019-10-02 14:07:47 +01001297 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Kevin Mayce5045a2019-10-02 14:07:47 +01001298}
1299
telsoa014fcda012018-03-09 14:13:49 +00001300void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1301{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001302 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001303
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001304 ValidateNumInputs(workloadInfo, descriptorName, 1);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001305 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1306
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001307 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1308 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1309
Matthew Jackson82b15ed2019-07-25 16:14:30 +01001310 if (inputTensorInfo.GetNumDimensions() > 4)
1311 {
1312 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1313 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001314
1315 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001316
1317 // Check the supported data types
1318 std::vector<DataType> supportedTypes =
1319 {
1320 DataType::Float32,
1321 DataType::Float16,
1322 DataType::QuantisedAsymm8,
1323 DataType::QuantisedSymm16
1324 };
1325
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001326 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001327 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1328}
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001329
Aron Virginas-Tarf982dea2019-10-11 14:07:53 +01001330void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1331{
1332 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1333
1334 ValidateNumInputs(workloadInfo, descriptorName, 1);
1335 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1336
1337 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1338 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1339
1340 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1341
1342 std::vector<DataType> supportedTypes =
1343 {
1344 DataType::Float32,
1345 DataType::Float16,
1346 };
1347
1348 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001349 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001350}
1351
1352void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1353{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001354 const std::string descriptorName{"ConstantQueueDescriptor"};
1355
1356 ValidateNumInputs(workloadInfo, descriptorName, 0);
1357 ValidateNumOutputs(workloadInfo, descriptorName, 1);
telsoa014fcda012018-03-09 14:13:49 +00001358
1359 if (!m_LayerOutput)
1360 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001361 throw InvalidArgumentException(descriptorName + ": No const input specified.");
telsoa014fcda012018-03-09 14:13:49 +00001362 }
1363
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001364 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1365 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
Nina Drozd58ef2c62019-05-16 12:09:18 +01001366
1367 // Check the supported data types
1368 std::vector<DataType> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001369 {
1370 DataType::Float32,
1371 DataType::Float16,
1372 DataType::Signed32,
1373 DataType::QuantisedAsymm8,
1374 DataType::QuantisedSymm16
1375 };
Nina Drozd58ef2c62019-05-16 12:09:18 +01001376
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001377 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001378}
1379
1380void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1381{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001382 const std::string descriptorName{"ReshapeQueueDescriptor"};
telsoa014fcda012018-03-09 14:13:49 +00001383
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001384 ValidateNumInputs(workloadInfo, descriptorName, 1);
1385 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1386
1387 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1388 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1389
1390 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nina Drozd2f2778f2019-05-27 10:37:05 +01001391
1392 // Check the supported data types
1393 std::vector<DataType> supportedTypes =
1394 {
1395 DataType::Float32,
1396 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001397 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001398 DataType::QuantisedAsymm8,
1399 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001400 };
1401
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001402 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1403 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa014fcda012018-03-09 14:13:49 +00001404}
1405
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001406void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1407{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001408 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001409
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001410 ValidateNumInputs(workloadInfo, descriptorName, 1);
1411 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1412
1413 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1414 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1415
1416 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1417 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001418
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001419 if (m_Parameters.m_BlockShape.size() != 2)
1420 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001421 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001422 }
1423
1424 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1425 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001426 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1427 "dimensions as Block Shape.");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001428 }
1429
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001430 const TensorShape& inputShape = inputTensorInfo.GetShape();
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001431
1432 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001433 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001434
Matthew Bentham8800c002018-11-19 13:19:28 +00001435 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001436
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001437 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1438 widthPad.first + widthPad.second;
1439 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1440 heightPad.first + heightPad.second;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001441
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001442 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1443 inputShape[dimensionIndices.GetChannelsIndex()];
1444 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001445
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001446 if (numOutputElements != numInputElements)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001447 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001448 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001449 to_string(numInputElements) + " after padding but output tensor has " +
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001450 to_string(numOutputElements) + " elements.");
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001451 }
1452
1453 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001454 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001455 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1456 "divisible by Block Shape in all spatial dimensions");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001457 }
nikraj01120522a2019-05-31 11:33:07 +01001458
1459 std::vector<DataType> supportedTypes =
1460 {
1461 DataType::Float16,
1462 DataType::Float32,
1463 DataType::QuantisedAsymm8,
1464 DataType::QuantisedSymm16
1465 };
1466
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001467 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1468 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +00001469}
1470
Keith Davisa57eccb2019-06-14 17:33:22 +01001471void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1472{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001473 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
Keith Davisa57eccb2019-06-14 17:33:22 +01001474
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001475 ValidateNumInputs(workloadInfo, descriptorName, 1);
1476 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Keith Davisa57eccb2019-06-14 17:33:22 +01001477
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001478 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1479 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1480
1481 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1482 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Keith Davisa57eccb2019-06-14 17:33:22 +01001483
1484 std::vector<DataType> supportedTypes =
1485 {
1486 DataType::Float32,
1487 DataType::Float16,
James Conroyd2aa85e2019-07-01 17:12:40 +01001488 DataType::QuantisedAsymm8,
1489 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001490 };
1491
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001492 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1493 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Keith Davisa57eccb2019-06-14 17:33:22 +01001494
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001495 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1496
1497 if (m_Parameters.m_BlockSize == 0)
1498 {
1499 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1500 }
1501
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001502 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1503 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1504 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1505 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
Keith Davisa57eccb2019-06-14 17:33:22 +01001506
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001507 const TensorShape& inputShape = inputTensorInfo.GetShape();
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001508 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
Keith Davisa57eccb2019-06-14 17:33:22 +01001509 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001510 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1511 "by block size in all spatial dimensions");
Keith Davisa57eccb2019-06-14 17:33:22 +01001512 }
Aron Virginas-Tar8a1b2182019-09-19 14:39:37 +01001513
1514 const TensorShape& outputShape = outputTensorInfo.GetShape();
1515 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1516 {
1517 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1518 "must be divisible by the square of block size." );
1519 }
Keith Davisa57eccb2019-06-14 17:33:22 +01001520}
1521
telsoa014fcda012018-03-09 14:13:49 +00001522void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1523{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001524 const std::string descriptorName{"FloorQueueDescriptor"};
James Conroy83735b12019-05-30 16:36:59 +01001525
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001526 ValidateNumInputs(workloadInfo, descriptorName, 1);
1527 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1528
1529 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1530 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy83735b12019-05-30 16:36:59 +01001531
1532 std::vector<DataType> supportedTypes =
1533 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001534 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001535 DataType::Float16,
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001536 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +01001537 };
1538
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001539 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
telsoa014fcda012018-03-09 14:13:49 +00001540
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001541 if (inputTensorInfo != outputTensorInfo)
telsoa014fcda012018-03-09 14:13:49 +00001542 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001543 throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
telsoa014fcda012018-03-09 14:13:49 +00001544 }
1545}
1546
telsoa01c577f2c2018-08-31 09:22:23 +01001547void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1548{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001549 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1550
1551 const std::string descriptorName{"LstmQueueDescriptor"};
1552
1553 // check dimensions of all inputs and outputs
1554 if (workloadInfo.m_InputTensorInfos.size() != 3)
1555 {
1556 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1557 }
1558 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1559 {
1560 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1561 }
1562
1563 std::vector<DataType> supportedTypes =
1564 {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001565 DataType::Float16,
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001566 DataType::Float32,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001567 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001568 };
1569
Jan Eilers38e05bd2019-06-26 13:10:09 +01001570 // check for supported type of one input and match them with all the other input and output
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001571 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1572
Jan Eilers38e05bd2019-06-26 13:10:09 +01001573 // type matches all other inputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001574 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001575 {
1576 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1577 workloadInfo.m_InputTensorInfos[i],
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001578 descriptorName,
1579 "input_0",
1580 "input_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001581 }
1582 // type matches all other outputs
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001583 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
Jan Eilers38e05bd2019-06-26 13:10:09 +01001584 {
1585 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1586 workloadInfo.m_OutputTensorInfos[i],
1587 "LstmQueueDescriptor",
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001588 "input_0",
1589 "output_" + std::to_string(i));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001590 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001591
Jan Eilers38e05bd2019-06-26 13:10:09 +01001592 // TODO: check clipping parameter is valid
1593
1594 // Inferring batch size, number of outputs and number of cells from the inputs.
1595 // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
1596 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
1597 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
1598 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
1599 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
1600 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
1601 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
1602
Jan Eilers38e05bd2019-06-26 13:10:09 +01001603 // input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001604 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
1605 descriptorName + " input_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001606 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001607 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
1608 descriptorName + " input_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001609 // outputStateInTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001610 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
1611 descriptorName + " input_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001612 // scratchBufferTensor
1613 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001614 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
1615 descriptorName + " output_0");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001616 // outputStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001617 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
1618 descriptorName + " output_1");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001619 // cellStateOutTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001620 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
1621 descriptorName + " output_2");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001622 // outputTensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001623 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
1624 descriptorName + " output_3");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001625
1626
1627 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
1628 if ( m_InputToInputWeights )
1629 {
1630 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
1631 (n_cell * n_input), "InputLayerNormWeights");
1632 }
1633
1634 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
1635 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
1636 (n_cell * n_input), "InputToForgetWeights");
1637
1638 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
1639 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
1640 (n_cell * n_input), "InputToCellWeights");
1641
1642 if ( m_RecurrentToInputWeights )
1643 {
1644 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
1645 (n_cell * n_output), "RecurrentToInputWeights");
1646 }
1647
1648 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
1649 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
1650 (n_cell * n_output), "RecurrentToForgetWeights");
1651
1652 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
1653 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
1654 (n_cell * n_output), "RecurrentToCellWeights");
1655
1656 // Make sure the input-gate's parameters are either both present (regular
1657 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
1658 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
1659 !m_Parameters.m_CifgEnabled) ||
1660 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
1661 m_Parameters.m_CifgEnabled));
1662 if (!cifg_weights_all_or_none)
1663 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001664 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
1665 "RecurrentToInputWeights must either both be present (regular LSTM) "
1666 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
1667 "accordingly.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001668 }
1669
1670 if ( m_CellToInputWeights )
1671 {
1672 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
1673 n_cell, "CellToInputWeights");
1674 }
1675 if ( m_CellToForgetWeights )
1676 {
1677 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
1678 n_cell, "CellToForgetWeights");
1679 }
1680 if ( m_CellToOutputWeights )
1681 {
1682 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
1683 n_cell, "CellToOutputWeights");
1684 }
1685
1686 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
1687 bool peephole_weights_all_or_none =
1688 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
1689 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
1690 || ( !m_CellToInputWeights && !m_CellToForgetWeights
1691 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
1692 if (!peephole_weights_all_or_none)
1693 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001694 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001695 }
1696
1697 // Make sure the input gate bias is present only when not a CIFG-LSTM.
1698 if (m_Parameters.m_CifgEnabled)
1699 {
1700 if (m_InputGateBias)
1701 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001702 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001703 }
1704 }
1705 else
1706 {
1707 if (!m_InputGateBias)
1708 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001709 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
1710 "must be present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001711 }
1712 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
1713 n_cell, "InputGateBias");
1714 }
1715
1716 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
1717 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
1718
1719 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
1720 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
1721
1722 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
1723 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
1724
1725 if (m_ProjectionWeights)
1726 {
1727 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
1728 (n_cell * n_output), "ProjectionWeights");
1729 }
1730 if (m_ProjectionBias)
1731 {
1732 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
1733 }
1734
1735 // Making sure the projection tensors are consistent:
1736 // 1) If projection weight is not present, then projection bias should not be
1737 // present.
1738 // 2) If projection weight is present, then projection bias is optional.
1739 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
1740 !m_Parameters.m_ProjectionEnabled)
1741 || (m_ProjectionWeights && !m_ProjectionBias &&
1742 m_Parameters.m_ProjectionEnabled)
1743 || (m_ProjectionWeights && m_ProjectionBias &&
1744 m_Parameters.m_ProjectionEnabled));
1745 if (!projecton_tensors_consistent)
1746 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001747 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001748 }
1749
1750 // The four layer normalization weights either all have values or none of them have values. Additionally, if
1751 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
1752 // either all have values or none of them have values. Layer normalization is used when the values of all the
1753 // layer normalization weights are present
1754 if (m_InputLayerNormWeights)
1755 {
1756 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
1757 }
1758 if (m_ForgetLayerNormWeights)
1759 {
1760 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1761 }
1762 if (m_CellLayerNormWeights)
1763 {
1764 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1765 }
1766 if (m_OutputLayerNormWeights)
1767 {
1768 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1769 }
1770
Jan Eilers38e05bd2019-06-26 13:10:09 +01001771 if (m_Parameters.m_LayerNormEnabled)
1772 {
1773 if (!m_Parameters.m_CifgEnabled)
1774 {
1775 if (!m_InputLayerNormWeights)
1776 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001777 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
1778 "disabled but InputLayerNormWeights are not present");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001779 }
1780 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
1781 1, n_cell, "InputLayerNormWeights");
1782 }
1783 else if (m_InputLayerNormWeights)
1784 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001785 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
1786 "enabled");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001787 }
1788
1789 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
1790 "ForgetLayerNormWeights");
1791 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
1792
1793 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
1794 "OutputLayerNormWeights");
1795 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
1796
1797 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
1798 "CellLayerNormWeights");
1799 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
1800 }
1801 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
1802 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001803 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
1804 "normalisation weights are present.");
Jan Eilers38e05bd2019-06-26 13:10:09 +01001805 }
telsoa01c577f2c2018-08-31 09:22:23 +01001806}
1807
1808void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1809{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001810 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001811
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001812 ValidateNumInputs(workloadInfo, descriptorName, 1);
1813 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1814
1815 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1816 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1817
1818 if (inputTensorInfo.GetDataType() != DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001819 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001820 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
telsoa01c577f2c2018-08-31 09:22:23 +01001821 }
1822
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001823 if (outputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001824 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001825 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001826 }
1827
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001828 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001829}
1830
1831void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1832{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001833 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
telsoa01c577f2c2018-08-31 09:22:23 +01001834
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001835 ValidateNumInputs(workloadInfo, descriptorName, 1);
1836 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1837
1838 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1839 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1840
1841 if (inputTensorInfo.GetDataType() != DataType::Float16)
telsoa01c577f2c2018-08-31 09:22:23 +01001842 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001843 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
telsoa01c577f2c2018-08-31 09:22:23 +01001844 }
1845
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001846 if (outputTensorInfo.GetDataType() != DataType::Float32)
1847 {
1848 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
1849 }
1850
1851 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
telsoa01c577f2c2018-08-31 09:22:23 +01001852}
1853
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001854void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1855{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001856 const std::string descriptorName{"DivisionQueueDescriptor"};
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001857
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001858 ValidateNumInputs(workloadInfo, descriptorName, 2);
1859 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1860
1861 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1862 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1863 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1864
1865 std::vector<DataType> supportedTypes =
1866 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001867 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001868 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001869 DataType::QuantisedSymm16,
1870 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001871 };
1872
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001873 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1874 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1875 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001876
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001877 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1878 inputTensorInfo1,
1879 outputTensorInfo,
1880 descriptorName,
1881 "input_0",
1882 "input_1");
Francis Murtaghe7a86a42018-08-29 12:42:10 +01001883}
1884
David Beckc2044fe2018-09-05 15:00:38 +01001885void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1886{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001887 const std::string descriptorName{"SubtractionQueueDescriptor"};
David Beckc2044fe2018-09-05 15:00:38 +01001888
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001889 ValidateNumInputs(workloadInfo, descriptorName, 2);
1890 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1891
1892 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1893 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1894 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1895
1896 std::vector<DataType> supportedTypes =
1897 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001898 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001899 DataType::QuantisedAsymm8,
Jim Flynn82fbe7c2019-04-02 15:19:08 +01001900 DataType::QuantisedSymm16,
1901 DataType::Float16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001902 };
1903
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001904 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1905 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1906 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001907
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001908 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1909 inputTensorInfo1,
1910 outputTensorInfo,
1911 descriptorName,
1912 "input_0",
1913 "input_1");
David Beckc2044fe2018-09-05 15:00:38 +01001914}
1915
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001916void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1917{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001918 const std::string descriptorName{"MaximumQueueDescriptor"};
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001919
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001920 ValidateNumInputs(workloadInfo, descriptorName, 2);
1921 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1922
1923 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1924 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1925 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1926
1927 std::vector<DataType> supportedTypes =
1928 {
Mike Kelly1da02362019-08-01 08:43:57 +01001929 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001930 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001931 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001932 DataType::QuantisedAsymm8,
1933 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001934 };
1935
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001936 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1937 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1938 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01001939
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001940 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1941 inputTensorInfo1,
1942 outputTensorInfo,
1943 descriptorName,
1944 "input_0",
1945 "input_1");
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +00001946}
1947
narpra01a6bf9122018-09-10 09:50:09 +01001948void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1949{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001950 const std::string descriptorName{"MeanQueueDescriptor"};
James Conroy4d1ff582019-06-10 17:06:39 +01001951
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001952 ValidateNumInputs(workloadInfo, descriptorName, 1);
1953 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1954
1955 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1956 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
James Conroy4d1ff582019-06-10 17:06:39 +01001957
1958 std::vector<DataType> supportedTypes =
1959 {
1960 DataType::Float32,
1961 DataType::Float16,
1962 DataType::QuantisedAsymm8,
1963 DataType::QuantisedSymm16
1964 };
narpra01eb061912018-09-10 17:35:27 +01001965
James Conroy4d1ff582019-06-10 17:06:39 +01001966 // First check if input tensor data type is supported, then
1967 // check if this data type matches the output tensor data type
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001968 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1969 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
James Conroy4d1ff582019-06-10 17:06:39 +01001970
narpra0132b90462018-09-13 11:07:48 +01001971 if (m_Parameters.m_KeepDims)
narpra01eb061912018-09-10 17:35:27 +01001972 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001973 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
narpra01eb061912018-09-10 17:35:27 +01001974 }
narpra0132b90462018-09-13 11:07:48 +01001975 else if (m_Parameters.m_Axis.empty())
narpra01eb061912018-09-10 17:35:27 +01001976 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001977 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
narpra01eb061912018-09-10 17:35:27 +01001978 }
1979 else
1980 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001981 unsigned int outputDim =
1982 inputTensorInfo.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
1983 ValidateTensorNumDimensions(outputTensorInfo,
1984 descriptorName,
narpra01eb061912018-09-10 17:35:27 +01001985 outputDim > 0 ? outputDim : 1,
1986 "output");
1987 }
narpra01a6bf9122018-09-10 09:50:09 +01001988}
1989
jimfly012c9322a2018-09-19 10:59:49 +01001990void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1991{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001992 const std::string descriptorName{"PadQueueDescriptor"};
jimfly012c9322a2018-09-19 10:59:49 +01001993
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01001994 ValidateNumInputs(workloadInfo, descriptorName, 1);
1995 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1996
1997 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1998 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nina Drozd661dfa72018-10-02 11:14:17 +01001999
jimfly012c9322a2018-09-19 10:59:49 +01002000 // input and output should have the same number of dimensions
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002001 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2002
jimfly012c9322a2018-09-19 10:59:49 +01002003 // there should be entry in the pad list for each dimension in the input tensor
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002004 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2005 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2006 "as there are dimensions in the input tensor that is " +
2007 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2008 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
jimfly012c9322a2018-09-19 10:59:49 +01002009 }
2010}
2011
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002012void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2013{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002014 const std::string descriptorName{"QuantizeQueueDescriptor"};
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002015
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002016 ValidateNumInputs(workloadInfo, descriptorName, 1);
2017 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002018
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002019 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2020 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2021
Sadik Armagan2208b602019-07-31 16:36:27 +01002022 std::vector<DataType> supportedTypes =
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002023 {
James Conroyd47a0642019-09-17 14:22:06 +01002024 DataType::Float32,
2025 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002026 };
2027
2028 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002029
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002030 if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2031 outputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002032 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002033 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
Derek Lambertia9cca6a2019-03-25 15:41:58 +00002034 }
2035}
2036
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002037void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2038{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002039 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002040
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002041 ValidateNumInputs(workloadInfo, descriptorName, 1);
2042 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002043
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002044 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2045 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002046
2047 std::vector<DataType> supportedTypes =
2048 {
James Conroyd47a0642019-09-17 14:22:06 +01002049 DataType::Float32,
2050 DataType::Float16,
2051 DataType::QuantisedAsymm8,
2052 DataType::QuantisedSymm16
Francis Murtaghd0dfe172019-06-25 10:57:10 +01002053 };
2054
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002055 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2056 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +00002057}
2058
Conor Kennedy430b5d82018-11-14 15:28:28 +00002059void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2060{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002061 const std::string descriptorName{"StridedSliceQueueDescriptor"};
Conor Kennedy430b5d82018-11-14 15:28:28 +00002062
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002063 ValidateNumInputs(workloadInfo, descriptorName, 1);
2064 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2065
2066 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2067 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002068
2069 std::vector<DataType> supportedTypes =
2070 {
2071 DataType::Float16,
2072 DataType::Float32,
Matteo Martincigh42666a12019-05-29 08:53:41 +01002073 DataType::QuantisedAsymm8,
2074 DataType::QuantisedSymm16
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002075 };
2076
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002077 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2078 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002079
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002080 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Matteo Martincighe851b3d2019-05-28 14:31:20 +01002081
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002082 const uint32_t rank = inputTensorInfo.GetNumDimensions();
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002083 if (rank > 4)
2084 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002085 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
Nattapat Chaimanowonga0d28442018-11-21 16:48:17 +00002086 }
2087
Conor Kennedy430b5d82018-11-14 15:28:28 +00002088 // Begin, End & Stride length must be of rank(input0)
2089 if (m_Parameters.m_Begin.size() != rank)
2090 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002091 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002092 }
2093
2094 if (m_Parameters.m_End.size() != rank)
2095 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002096 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002097 }
2098
2099 if (m_Parameters.m_Stride.size() != rank)
2100 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002101 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
Conor Kennedy430b5d82018-11-14 15:28:28 +00002102 }
2103
2104 // Stride entries must be non-zero
2105 for (auto& stride : m_Parameters.m_Stride)
2106 {
2107 if (stride == 0)
2108 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002109 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
Conor Kennedy430b5d82018-11-14 15:28:28 +00002110 }
2111 }
2112}
2113
kevmay0190539692018-11-29 08:40:19 +00002114void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2115{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002116 const std::string descriptorName{"MinimumQueueDescriptor"};
kevmay0190539692018-11-29 08:40:19 +00002117
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002118 ValidateNumInputs(workloadInfo, descriptorName, 2);
2119 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2120
2121 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2122 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2123 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2124
2125 std::vector<DataType> supportedTypes =
2126 {
Mike Kelly1da02362019-08-01 08:43:57 +01002127 DataType::Float16,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002128 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01002129 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01002130 DataType::QuantisedAsymm8,
2131 DataType::QuantisedSymm16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002132 };
2133
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002134 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2135 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2136 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +01002137
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002138 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2139 inputTensorInfo1,
2140 outputTensorInfo,
2141 descriptorName,
2142 "input_0",
2143 "input_1");
kevmay0190539692018-11-29 08:40:19 +00002144}
2145
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002146void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2147{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002148 const std::string descriptorName{"DebugQueueDescriptor"};
2149
2150 ValidateNumInputs(workloadInfo, descriptorName, 1);
2151 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowonga9a1cf12018-12-03 16:06:49 +00002152}
2153
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002154void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2155{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002156 const std::string descriptorName{"EqualQueueDescriptor"};
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002157
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002158 ValidateNumInputs(workloadInfo, descriptorName, 2);
2159 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002160
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002161 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2162 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2163 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2164
2165 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2166 inputTensorInfo1,
2167 outputTensorInfo,
2168 descriptorName,
2169 "input_0",
2170 "input_1");
2171
2172 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002173 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002174 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002175 }
FrancisMurtagh30cdfca2018-12-18 12:57:35 +00002176}
2177
FrancisMurtagh878f0232018-12-19 10:56:15 +00002178void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2179{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002180 const std::string descriptorName{"GreaterQueueDescriptor"};
FrancisMurtagh878f0232018-12-19 10:56:15 +00002181
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002182 ValidateNumInputs(workloadInfo, descriptorName, 2);
2183 ValidateNumOutputs(workloadInfo, descriptorName, 1);
kevmay012b4d88e2019-01-24 14:05:09 +00002184
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002185 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2186 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2187 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2188
2189 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2190 inputTensorInfo1,
2191 outputTensorInfo,
2192 descriptorName,
2193 "input_0",
2194 "input_1");
2195
2196 if (outputTensorInfo.GetDataType() != DataType::Boolean)
kevmay012b4d88e2019-01-24 14:05:09 +00002197 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002198 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
kevmay012b4d88e2019-01-24 14:05:09 +00002199 }
FrancisMurtagh878f0232018-12-19 10:56:15 +00002200}
2201
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002202void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2203{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002204 const std::string descriptorName{"RsqrtQueueDescriptor"};
2205
2206 ValidateNumInputs(workloadInfo, descriptorName, 1);
2207 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2208
2209 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2210 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2211
2212 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
nikraj010421e7f2019-06-14 09:40:34 +01002213
2214 std::vector<DataType> supportedTypes =
2215 {
James Conroyd47a0642019-09-17 14:22:06 +01002216 DataType::Float16,
2217 DataType::Float32,
2218 DataType::QuantisedAsymm8,
2219 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01002220 };
2221
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002222 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2223 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00002224}
2225
narpra01b89b05f2019-01-16 09:53:09 +00002226void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2227{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002228 const std::string descriptorName{"GatherQueueDescriptor"};
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002229
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002230 ValidateNumInputs(workloadInfo, descriptorName, 2);
2231 ValidateNumOutputs(workloadInfo, descriptorName, 1);
narpra014951d842019-01-18 16:53:53 +00002232
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002233 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2234 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
narpra014951d842019-01-18 16:53:53 +00002235 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002236 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
narpra014951d842019-01-18 16:53:53 +00002237 }
2238
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002239 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2240 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2241
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002242 std::vector<DataType> supportedTypes =
2243 {
James Conroyd47a0642019-09-17 14:22:06 +01002244 DataType::Float16,
2245 DataType::Float32,
2246 DataType::QuantisedAsymm8,
2247 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002248 };
2249
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002250 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002251
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002252 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01002253
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002254 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2255 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
narpra01b89b05f2019-01-16 09:53:09 +00002256}
2257
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002258void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2259{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002260 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2261
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002262 ValidateNumInputs(workloadInfo, descriptorName, 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002263
2264 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2265 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002266 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002267 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2268 }
2269
2270 if (m_Anchors == nullptr)
2271 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002272 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002273 }
2274
2275 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002276 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2277 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2278
2279 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +00002280 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002281 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2282 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002283
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002284 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2285 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2286 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002287
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002288 const std::vector<DataType> supportedInputTypes =
2289 {
2290 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002291 DataType::Float16,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002292 DataType::QuantisedAsymm8,
2293 DataType::QuantisedSymm16
2294 };
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002295
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002296 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2297 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2298 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2299
2300 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2301 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2302 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2303 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2304
2305 // NOTE: Output is always Float32 regardless of input type
2306 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2307 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2308 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2309 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002310
2311 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2312 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002313 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002314 "must be positive and less than or equal to 1.");
2315 }
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002316
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002317 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2318 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01002319 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00002320 "should be equal to number of classes + 1.");
2321 }
2322}
2323
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002324void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2325{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002326 const std::string& descriptorName{"DequantizeQueueDescriptor"};
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002327
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002328 ValidateNumInputs(workloadInfo, descriptorName, 1);
2329 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2330
2331 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2332 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2333
2334 if (inputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 &&
2335 inputTensorInfo.GetDataType() != DataType::QuantisedSymm16)
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002336 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002337 throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type.");
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002338 }
2339
Sadik Armagan2208b602019-07-31 16:36:27 +01002340 std::vector<DataType> supportedTypes =
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002341 {
James Conroyd47a0642019-09-17 14:22:06 +01002342 DataType::Float32,
2343 DataType::Float16
Sadik Armagan2208b602019-07-31 16:36:27 +01002344 };
2345
2346 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +00002347}
2348
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002349void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2350{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002351 const std::string& descriptorName{"MergeQueueDescriptor"};
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002352
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002353 ValidateNumInputs(workloadInfo, descriptorName, 2);
2354 ValidateNumOutputs(workloadInfo, descriptorName, 1);
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002355
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002356 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2357 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2358 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002359
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002360 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2361 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2362
2363 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2364 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +01002365}
2366
Sadik Armaganeff363d2019-04-05 15:25:46 +01002367void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2368{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002369 const std::string& descriptorName{"SwitchQueueDescriptor"};
Sadik Armaganeff363d2019-04-05 15:25:46 +01002370
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002371 ValidateNumInputs(workloadInfo, descriptorName, 2);
2372 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2373
2374 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2375 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2376
2377 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2378 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2379
2380 std::vector<DataType> supportedTypes =
2381 {
Sadik Armaganeff363d2019-04-05 15:25:46 +01002382 DataType::Float32,
2383 DataType::QuantisedAsymm8,
2384 DataType::QuantisedSymm16
2385 };
2386
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002387 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2388 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002389
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002390 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2391 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
Sadik Armaganeff363d2019-04-05 15:25:46 +01002392
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002393 ValidateTensorShapesMatch(inputTensorInfo0,
2394 outputTensorInfo0,
2395 descriptorName,
2396 "input_0",
2397 "output_0");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002398
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002399 ValidateTensorShapesMatch(inputTensorInfo0,
2400 outputTensorInfo1,
2401 descriptorName,
2402 "input_0",
2403 "output_1");
Sadik Armaganeff363d2019-04-05 15:25:46 +01002404}
2405
Matteo Martincigh49124022019-01-11 13:25:59 +00002406void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2407{
2408 // This is internally generated so it should not need validation.
2409}
2410
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002411void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2412{
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002413 const std::string& descriptorName{"PreluQueueDescriptor"};
2414
2415 ValidateNumInputs(workloadInfo, descriptorName, 2);
2416 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2417
2418 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2419 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2420 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002421
2422 std::vector<DataType> supportedTypes
2423 {
2424 DataType::Float16,
2425 DataType::Float32,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002426 DataType::QuantisedAsymm8,
2427 DataType::QuantisedSymm16
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002428 };
2429
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002430 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2431 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002432
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002433 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002434
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002435 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2436 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002437
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002438 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2439 alphaTensorInfo,
2440 outputTensorInfo,
2441 descriptorName,
Matteo Martincigh0e406ee2019-06-12 15:42:18 +01002442 "input",
2443 "alpha");
2444}
2445
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002446void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2447{
2448 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2449
2450 ValidateNumInputs(workloadInfo, descriptorName, 1);
2451 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2452
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002453 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2454 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2455
2456 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2457 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002458
2459 ValidatePointer(m_Weight, descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002460
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002461 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2462 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2463 ValidateTensorDataType(weightTensorInfo, inputTensorInfo.GetDataType(), descriptorName, "weight");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002464
2465 if (m_Parameters.m_BiasEnabled)
2466 {
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002467 ValidatePointer(m_Bias, descriptorName, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002468
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002469 const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
2470 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002471
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002472 ValidateTensorDataType(biasTensorInfo,
2473 GetBiasDataType(inputTensorInfo.GetDataType()),
2474 descriptorName,
2475 "bias");
2476
2477 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002478 }
Aron Virginas-Tar639fb042019-06-20 14:28:19 +01002479}
2480
James Conroy9c3cae82019-08-01 16:01:48 +01002481void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2482{
2483 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2484
2485 // Validate number of inputs/outputs
2486 ValidateNumInputs(workloadInfo, descriptorName, 3);
2487 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2488
2489 // Input/output tensor infos
2490 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2491 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2492 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2493
2494 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2495 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2496
2497 std::vector<DataType> inputOutputSupportedTypes =
2498 {
2499 DataType::QuantisedAsymm8
2500 };
2501
2502 std::vector<DataType> cellStateSupportedTypes =
2503 {
2504 DataType::QuantisedSymm16
2505 };
2506
2507 std::vector<DataType> weightsSupportedTypes =
2508 {
2509 DataType::QuantisedAsymm8
2510 };
2511
2512 std::vector<DataType> biasSupportedTypes =
2513 {
2514 DataType::Signed32
2515 };
2516
2517 // Validate types of input/output tensors
2518 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2519 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2520 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2521
2522 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2523 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2524
2525 // Validate matching types of input/output tensors
2526 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2527 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2528 "outputStateIn", "outputStateOut");
2529 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2530
2531 // Validate matching quantization info for input/output tensors
2532 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2533 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2534 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002535
James Conroy9c3cae82019-08-01 16:01:48 +01002536 // Infer number of batches, input size and output size from tensor dimensions
2537 const uint32_t numBatches = inputInfo.GetShape()[0];
2538 const uint32_t inputSize = inputInfo.GetShape()[1];
2539 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2540
2541 // Validate number of dimensions and number of elements for input/output tensors
2542 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2543 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2544 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2545 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2546 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2547
2548 // Validate number of dimensions and number of elements for weights tensors
2549 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2550 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2551 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2552
2553 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2554 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2555 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2556
2557 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2558 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2559 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2560
2561 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2562 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2563 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2564
2565 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2566 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2567 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2568
2569 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2570 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2571 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2572 " RecurrentToForgetWeights");
2573
2574 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2575 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2576 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2577
2578 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2579 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2580 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2581
2582 // Validate data types for weights tensors (all should match each other)
2583 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2584
2585 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2586 "inputToInputWeights", "inputToForgetWeights");
2587 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2588 "inputToInputWeights", "inputToCellWeights");
2589 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2590 "inputToInputWeights", "inputToOutputWeights");
2591
2592 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2593 "inputToInputWeights", "recurrentToInputWeights");
2594 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2595 "inputToInputWeights", "recurrentToForgeteights");
2596 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2597 "inputToInputWeights", "recurrentToCellWeights");
2598 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2599 "inputToInputWeights", "recurrentToOutputWeights");
2600
2601 // Validate matching quantization info for weight tensors (all should match each other)
2602 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2603 descriptorName, "inputToInputWeights", "inputToForgetWeights");
2604 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2605 descriptorName, "inputToInputWeights", "inputToCellWeights");
2606 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2607 descriptorName, "inputToInputWeights", "inputToOutputWeights");
2608
2609 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2610 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2611 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2612 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2613 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2614 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2615 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2616 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2617
2618 // Validate number of dimensions and number of elements in bias tensors
2619 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2620 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2621 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2622
2623 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2624 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2625 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2626
2627 ValidatePointer(m_CellBias, descriptorName, "CellBias");
2628 auto cellBiasInfo = m_CellBias->GetTensorInfo();
2629 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2630
2631 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2632 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2633 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2634
2635 // Validate data types for bias tensors (all should match each other)
2636 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2637
2638 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2639 "inputGateBias", "forgetGateBias");
2640 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2641 "inputGateBias", "cellBias");
2642 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2643 "inputGateBias", "outputGateBias");
2644
2645 // Validate bias tensor quantization info
2646 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2647 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2648 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2649 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2650}
2651
Kevin May868eb142019-09-04 17:29:31 +01002652void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2653{
2654 const std::string descriptorName{"AbsQueueDescriptor"};
2655
2656 ValidateNumInputs(workloadInfo, descriptorName, 1);
2657 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2658
2659 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2660 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2661
2662 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2663
2664 std::vector<DataType> supportedTypes =
James Conroyd47a0642019-09-17 14:22:06 +01002665 {
2666 DataType::Float16,
2667 DataType::Float32,
2668 DataType::QuantisedAsymm8,
2669 DataType::QuantisedSymm16
2670 };
Kevin May868eb142019-09-04 17:29:31 +01002671
2672 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2673 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2674}
2675
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002676void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2677{
2678 const std::string descriptorName{"SliceQueueDescriptor"};
2679
2680 ValidateNumInputs(workloadInfo, descriptorName, 1);
2681 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2682
2683 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2684 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2685
2686 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2687
2688 const unsigned int rank = inputTensorInfo.GetNumDimensions();
2689 if (rank > 4)
2690 {
2691 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2692 }
2693
2694 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
2695
2696 // Check if m_Begin and m_Size have the expected length
2697 if (m_Parameters.m_Begin.size() != rank)
2698 {
2699 throw InvalidArgumentException(descriptorName +
2700 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
2701 }
2702 if (m_Parameters.m_Size.size() != rank)
2703 {
2704 throw InvalidArgumentException(descriptorName +
2705 ": Length of size descriptor must equal rank " + std::to_string(rank));
2706 }
2707
2708 // Check if the shape of the output tensor matches m_Size
2709 const TensorShape& outputShape = outputTensorInfo.GetShape();
2710 for (unsigned int i = 0u; i < rank; ++i)
2711 {
2712 if (m_Parameters.m_Size[i] != outputShape[i])
2713 {
2714 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
2715 }
2716 }
2717
2718 // Check if the sum of begin offset and size in a given dimension
2719 // does not exceed the size of corresponding input
2720 const TensorShape& inputShape = inputTensorInfo.GetShape();
2721 for(unsigned int i = 0u; i < rank; ++i)
2722 {
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002723 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
Aron Virginas-Tar636ab402019-09-16 14:27:45 +01002724 {
2725 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
2726 std::to_string(i) + " exceeds input size.");
2727 }
2728 }
2729}
2730
Aron Virginas-Tardd6247f2019-09-19 14:31:17 +01002731void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2732{
2733 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
2734
2735 ValidateNumInputs(workloadInfo, descriptorName, 1);
2736 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2737
2738 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
2739 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
2740
2741 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
2742 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
2743
2744 std::vector<DataType> supportedTypes =
2745 {
2746 DataType::Float32,
2747 DataType::Float16,
2748 DataType::QuantisedAsymm8,
2749 DataType::QuantisedSymm16
2750 };
2751
2752 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
2753 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
2754
2755 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
2756
2757 if (m_Parameters.m_BlockSize == 0)
2758 {
2759 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
2760 }
2761
2762 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
2763 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
2764 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
2765 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
2766
2767 const TensorShape& outputShape = outputInfo.GetShape();
2768 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
2769 {
2770 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
2771 "must be divisible by block size.");
2772 }
2773
2774 const TensorShape& inputShape = inputInfo.GetShape();
2775 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
2776 {
2777 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
2778 "must be divisible by the square of block size." );
2779 }
2780}
2781
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01002782void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2783{
2784 const std::string descriptorName{"ComparisonQueueDescriptor"};
2785
2786 ValidateNumInputs(workloadInfo, descriptorName, 2);
2787 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2788
2789 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2790 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2791 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2792
2793 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2794 inputTensorInfo1,
2795 outputTensorInfo,
2796 descriptorName,
2797 "input_0",
2798 "input_1");
2799
2800 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2801 {
2802 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2803 }
2804}
2805
Aron Virginas-Tar84062b72019-07-19 11:37:10 +01002806} // namespace armnn