blob: 0dbb75c33ac3b968320ad9d33656aa98fc54b791 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "TensorUtils.hpp"
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +00007
Jim Flynnf92dfce2019-05-02 11:33:25 +01008#include <backendsCommon/ITensorHandle.hpp>
Nina Drozdd41b2592018-11-19 13:03:36 +00009
Narumol Prangnawarat02807852019-09-11 16:43:09 +010010#include <boost/assert.hpp>
11#include <boost/format.hpp>
12#include <boost/numeric/conversion/cast.hpp>
13
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000014using namespace armnn;
15
Nina Drozdd41b2592018-11-19 13:03:36 +000016namespace armnnUtils
17{
18
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000019TensorShape GetTensorShape(unsigned int numberOfBatches,
Nina Drozdd41b2592018-11-19 13:03:36 +000020 unsigned int numberOfChannels,
21 unsigned int height,
22 unsigned int width,
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000023 const DataLayout dataLayout)
Nina Drozdd41b2592018-11-19 13:03:36 +000024{
25 switch (dataLayout)
26 {
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000027 case DataLayout::NCHW:
28 return TensorShape({numberOfBatches, numberOfChannels, height, width});
29 case DataLayout::NHWC:
30 return TensorShape({numberOfBatches, height, width, numberOfChannels});
Nina Drozdd41b2592018-11-19 13:03:36 +000031 default:
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000032 throw InvalidArgumentException("Unknown data layout ["
Nina Drozdd41b2592018-11-19 13:03:36 +000033 + std::to_string(static_cast<int>(dataLayout)) +
34 "]", CHECK_LOCATION());
35 }
36}
37
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000038TensorInfo GetTensorInfo(unsigned int numberOfBatches,
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000039 unsigned int numberOfChannels,
40 unsigned int height,
41 unsigned int width,
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000042 const DataLayout dataLayout,
43 const DataType dataType)
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000044{
45 switch (dataLayout)
46 {
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000047 case DataLayout::NCHW:
48 return TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
49 case DataLayout::NHWC:
50 return TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000051 default:
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000052 throw InvalidArgumentException("Unknown data layout ["
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000053 + std::to_string(static_cast<int>(dataLayout)) +
54 "]", CHECK_LOCATION());
55 }
Nina Drozdd41b2592018-11-19 13:03:36 +000056}
57
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000058std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle)
Jim Flynnf92dfce2019-05-02 11:33:25 +010059{
60 auto tensor_data = static_cast<const float *>(tensorHandle->Map(true));
61 auto tensor_size = tensorHandle->GetShape().GetNumElements();
62
63 // Set min/max initially to first value in tensor
64 float min = tensor_data[0];
65 float max = tensor_data[0];
66
67 // Loop over rest of tensor and update min/max if necessary
68 for (unsigned int val = 1; val < tensor_size; val++)
69 {
70 if (tensor_data[val] < min)
71 {
72 min = tensor_data[val];
73 }
74 else if (tensor_data[val] > max)
75 {
76 max = tensor_data[val];
77 }
78 }
79
80 tensorHandle->Unmap();
81
82 return std::make_pair(min, max);
83}
84
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000085TensorShape ExpandDims(const TensorShape& tensorShape, int axis)
Narumol Prangnawarat02807852019-09-11 16:43:09 +010086{
87 unsigned int outputDim = tensorShape.GetNumDimensions() + 1;
88
89 if (axis < -boost::numeric_cast<int>(outputDim) || axis > boost::numeric_cast<int>(tensorShape.GetNumDimensions()))
90 {
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +000091 throw InvalidArgumentException(
Narumol Prangnawarat02807852019-09-11 16:43:09 +010092 boost::str(boost::format("Invalid expansion axis %1% for %2%D input tensor. %3%") %
93 axis %
94 tensorShape.GetNumDimensions() %
95 CHECK_LOCATION().AsString()));
96 }
97
98 if (axis < 0)
99 {
100 axis = boost::numeric_cast<int>(outputDim) + axis;
101 }
102
103 std::vector<unsigned int> outputShape;
104 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
105 {
106 outputShape.push_back(tensorShape[i]);
107 }
108 outputShape.insert(outputShape.begin() + axis, 1);
109
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +0000110 return TensorShape(outputDim, outputShape.data());
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100111}
112
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +0000113unsigned int GetNumElementsBetween(const TensorShape& shape,
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +0100114 const unsigned int firstAxisInclusive,
115 const unsigned int lastAxisExclusive)
116{
117 BOOST_ASSERT(0 <= firstAxisInclusive);
118 BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
119 BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
120 unsigned int count = 1;
121 for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
122 {
123 count *= shape[i];
124 }
125 return count;
126}
127
128unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
129{
130 BOOST_ASSERT_MSG(axis < boost::numeric_cast<int>(inputDimension),
131 "Required axis index greater than number of dimensions.");
132 BOOST_ASSERT_MSG(axis >= -boost::numeric_cast<int>(inputDimension),
133 "Required axis index lower than negative of the number of dimensions");
134
135 unsigned int uAxis = axis < 0 ?
136 inputDimension - boost::numeric_cast<unsigned int>(abs(axis))
137 : boost::numeric_cast<unsigned int>(axis);
138 return uAxis;
139}
140
Matteo Martincigh9a5f9f22019-10-31 11:02:47 +0000141} // namespace armnnUtils