//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/TypesUtils.hpp>

#include <boost/assert.hpp>

#include <utility>
#include <vector>

Nina Drozdd41b2592018-11-19 13:03:36 +000012namespace armnnUtils
13{
14armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
15 unsigned int numberOfChannels,
16 unsigned int height,
17 unsigned int width,
18 const armnn::DataLayout dataLayout);
19
Nina Drozdd41b2592018-11-19 13:03:36 +000020armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
21 unsigned int numberOfChannels,
22 unsigned int height,
23 unsigned int width,
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000024 const armnn::DataLayout dataLayout,
25 const armnn::DataType dataType);
26
Jim Flynnf92dfce2019-05-02 11:33:25 +010027std::pair<float, float> FindMinMax(armnn::ITensorHandle* tensorHandle);
28
Narumol Prangnawarat02807852019-09-11 16:43:09 +010029armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis);
30
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +010031unsigned int GetNumElementsBetween(const armnn::TensorShape& shape,
32 unsigned int firstAxisInclusive,
33 unsigned int lastAxisExclusive);
34
35unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis);
36
Keith Davis5236e1d2019-11-04 08:58:33 +000037inline unsigned int GetNumElementsAfter(const armnn::TensorShape& shape,
38 unsigned int axis)
39{
40 unsigned int numDim = shape.GetNumDimensions();
41 BOOST_ASSERT(0 >= axis);
42 BOOST_ASSERT(axis < numDim - 1);
43 unsigned int count = 1;
44 for (unsigned int i = axis; i < numDim; i++)
45 {
46 count *= shape[i];
47 }
48 return count;
49}
50
51inline std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::TensorInfo& info)
52{
53 const std::vector<float>& scales = info.GetQuantizationScales();
54 armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim();
55 if (scales.size() < 1 || !quantizationDim.has_value())
56 {
57 throw armnn::InvalidArgumentException(
58 "We currently support only per-axis symmetric quantization for QuantizedSymm8.");
59 }
60 unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value());
61
62 return {axisFactor, scales};
63}
64
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000065} // namespace armnnUtils