blob: 3e0c40d8909e0a9ce61cd935b64d1a3e325cfa18 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#pragma once
7
Matteo Martincigh747ef822018-12-18 09:26:39 +00008#include "CpuTensorHandle.hpp"
Kevin May665a964a2019-08-21 16:53:50 +01009#include "ITensorHandle.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010010
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <armnn/Tensor.hpp>
12
Kevin May665a964a2019-08-21 16:53:50 +010013#include <Half.hpp>
Matteo Martincigh747ef822018-12-18 09:26:39 +000014#include <Permute.hpp>
15#include <Profiling.hpp>
Matteo Martincigh747ef822018-12-18 09:26:39 +000016
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <boost/cast.hpp>
18
19namespace armnn
20{
21namespace
22{
Matteo Martincigh747ef822018-12-18 09:26:39 +000023
/// Single-target step: when @p idx is still below @p num, stores array[num - 1 - idx] into
/// @p arg and advances @p idx by one. Otherwise both @p arg and @p idx are left unchanged.
/// Repeated calls therefore read @p array from its last element backwards.
template <typename ArrayType, typename Arg>
void AssignValues(unsigned int num, unsigned int& idx, const ArrayType& array, Arg& arg)
{
    if (idx < num)
    {
        const unsigned int fromBack = (num - 1) - idx;
        arg = array[fromBack];
        ++idx;
    }
}

/// Variadic form: fills @p assignee and then each of @p args in order. It walks @p array from
/// its last element backwards, starting @p idx positions from the end. Once @p num elements
/// have been consumed, the remaining targets keep whatever value they already held.
template <typename T, typename ArrayType, typename... Args>
void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& assignee, Args&... args)
{
    unsigned int position = idx;
    // The single-target overload above consumes one element (if any) and advances position.
    AssignValues(num, position, array, assignee);
    // Continue with the remaining targets from the updated position.
    AssignValues(num, position, array, args...);
}
Matteo Martincigh747ef822018-12-18 09:26:39 +000043
Kevin May665a964a2019-08-21 16:53:50 +010044} // anonymous namespace
telsoa01c577f2c2018-08-31 09:22:23 +010045
Kevin May665a964a2019-08-21 16:53:50 +010046template <typename CopyFunc>
telsoa01c577f2c2018-08-31 09:22:23 +010047void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy)
48{
Matthew Benthamefdbca62019-09-14 23:35:28 +010049 // For ease of understanding, names are assigned to the dimensions
50 // of the tensor as if NHWC, however this routine works with any 5D tensor
Matthew Jacksondba634f2019-08-15 15:14:18 +010051 static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents");
telsoa01c577f2c2018-08-31 09:22:23 +010052
Kevin May665a964a2019-08-21 16:53:50 +010053 TensorShape srcStrides = srcTensor->GetStrides();
telsoa01c577f2c2018-08-31 09:22:23 +010054 const TensorShape& srcShape = srcTensor->GetShape();
Kevin May665a964a2019-08-21 16:53:50 +010055 TensorShape dstStrides = dstTensor->GetStrides();
telsoa01c577f2c2018-08-31 09:22:23 +010056 const TensorShape& dstShape = dstTensor->GetShape();
57
Kevin May665a964a2019-08-21 16:53:50 +010058 size_t srcDepth = 1;
59 size_t srcBatches = 1;
Kevin May665a964a2019-08-21 16:53:50 +010060 size_t srcHeight = 1;
61 size_t srcWidth = 1;
Matthew Benthamefdbca62019-09-14 23:35:28 +010062 size_t srcChannels = 1;
Kevin May665a964a2019-08-21 16:53:50 +010063 AssignValues(srcShape.GetNumDimensions(),
64 0,
65 srcShape,
Matthew Benthamefdbca62019-09-14 23:35:28 +010066 srcChannels,
telsoa01c577f2c2018-08-31 09:22:23 +010067 srcWidth,
68 srcHeight,
Matthew Jacksondba634f2019-08-15 15:14:18 +010069 srcBatches,
70 srcDepth);
telsoa01c577f2c2018-08-31 09:22:23 +010071
Kevin May665a964a2019-08-21 16:53:50 +010072 size_t srcDepthStride = 0;
73 size_t srcBatchStride = 0;
Kevin May665a964a2019-08-21 16:53:50 +010074 size_t srcHeightStride = 0;
75 size_t srcWidthStride = 0;
Matthew Benthamefdbca62019-09-14 23:35:28 +010076 size_t srcChannelStride = 0;
Kevin May665a964a2019-08-21 16:53:50 +010077 AssignValues(srcStrides.GetNumDimensions(),
78 0,
79 srcStrides,
Matthew Benthamefdbca62019-09-14 23:35:28 +010080 srcChannelStride,
telsoa01c577f2c2018-08-31 09:22:23 +010081 srcWidthStride,
82 srcHeightStride,
Matthew Jacksondba634f2019-08-15 15:14:18 +010083 srcBatchStride,
84 srcDepthStride);
telsoa01c577f2c2018-08-31 09:22:23 +010085
Kevin May665a964a2019-08-21 16:53:50 +010086 size_t dstDepth = 1;
87 size_t dstBatches = 1;
Kevin May665a964a2019-08-21 16:53:50 +010088 size_t dstHeight = 1;
89 size_t dstWidth = 1;
Matthew Benthamefdbca62019-09-14 23:35:28 +010090 size_t dstChannels = 1;
Kevin May665a964a2019-08-21 16:53:50 +010091 AssignValues(dstShape.GetNumDimensions(),
92 0,
93 dstShape,
Matthew Benthamefdbca62019-09-14 23:35:28 +010094 dstChannels,
telsoa01c577f2c2018-08-31 09:22:23 +010095 dstWidth,
96 dstHeight,
Matthew Jacksondba634f2019-08-15 15:14:18 +010097 dstBatches,
98 dstDepth);
telsoa01c577f2c2018-08-31 09:22:23 +010099
Kevin May665a964a2019-08-21 16:53:50 +0100100 size_t dstDepthStride = 0;
101 size_t dstBatchStride = 0;
Kevin May665a964a2019-08-21 16:53:50 +0100102 size_t dstHeightStride = 0;
103 size_t dstWidthStride = 0;
Matthew Benthamefdbca62019-09-14 23:35:28 +0100104 size_t dstChannelStride = 0;
Kevin May665a964a2019-08-21 16:53:50 +0100105 AssignValues(dstStrides.GetNumDimensions(),
106 0,
107 dstStrides,
Matthew Benthamefdbca62019-09-14 23:35:28 +0100108 dstChannelStride,
telsoa01c577f2c2018-08-31 09:22:23 +0100109 dstWidthStride,
110 dstHeightStride,
Matthew Jacksondba634f2019-08-15 15:14:18 +0100111 dstBatchStride,
112 dstDepthStride);
telsoa01c577f2c2018-08-31 09:22:23 +0100113
Sadik Armaganbf86d512018-12-24 09:01:31 +0000114 const unsigned char* srcData;
115 unsigned char* dstData;
116 {
117 ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Synchronize buffers");
Kevin May665a964a2019-08-21 16:53:50 +0100118 srcData = static_cast<const uint8_t*>(srcTensor->Map());
119 dstData = static_cast<uint8_t*>(dstTensor->Map());
Sadik Armaganbf86d512018-12-24 09:01:31 +0000120 }
telsoa01c577f2c2018-08-31 09:22:23 +0100121
Matthew Benthamefdbca62019-09-14 23:35:28 +0100122 size_t copyLength = std::min(srcChannels*srcChannelStride, dstChannels*dstChannelStride);
123 size_t copyWidth = std::min(srcWidth, dstWidth);
124 size_t copyHeight = std::min(srcHeight, dstHeight);
125 size_t copyBatches = std::min(srcBatches, dstBatches);
126 size_t copyDepth = std::min(srcDepth, dstDepth);
telsoa01c577f2c2018-08-31 09:22:23 +0100127
Kevin May665a964a2019-08-21 16:53:50 +0100128 for (unsigned int d = 0; d < copyDepth; ++d)
telsoa01c577f2c2018-08-31 09:22:23 +0100129 {
Matthew Jacksondba634f2019-08-15 15:14:18 +0100130 auto srcPtrDepth = srcData;
131 auto dstPtrDepth = dstData;
Kevin May665a964a2019-08-21 16:53:50 +0100132 for (unsigned int b = 0; b < copyBatches; ++b)
telsoa01c577f2c2018-08-31 09:22:23 +0100133 {
Matthew Jacksondba634f2019-08-15 15:14:18 +0100134 auto srcPtrBatch = srcData;
135 auto dstPtrBatch = dstData;
Matthew Benthamefdbca62019-09-14 23:35:28 +0100136 for (unsigned int h = 0; h < copyHeight; ++h)
telsoa01c577f2c2018-08-31 09:22:23 +0100137 {
Matthew Jacksondba634f2019-08-15 15:14:18 +0100138 auto srcPtrChannel = srcData;
139 auto dstPtrChannel = dstData;
Matthew Benthamefdbca62019-09-14 23:35:28 +0100140 for (unsigned int w = 0; w < copyWidth; ++w)
Matthew Jacksondba634f2019-08-15 15:14:18 +0100141 {
142 copy(dstData, srcData, copyLength);
Matthew Benthamefdbca62019-09-14 23:35:28 +0100143 dstData += dstWidthStride;
144 srcData += srcWidthStride;
Matthew Jacksondba634f2019-08-15 15:14:18 +0100145 }
Matthew Benthamefdbca62019-09-14 23:35:28 +0100146 dstData += (static_cast<long>(dstHeightStride) - (dstData - dstPtrChannel));
147 srcData += (static_cast<long>(srcHeightStride) - (srcData - srcPtrChannel));
telsoa01c577f2c2018-08-31 09:22:23 +0100148 }
Kevin May665a964a2019-08-21 16:53:50 +0100149 dstData += (static_cast<long>(dstBatchStride) - (dstData - dstPtrBatch));
150 srcData += (static_cast<long>(srcBatchStride) - (srcData - srcPtrBatch));
telsoa01c577f2c2018-08-31 09:22:23 +0100151 }
Kevin May665a964a2019-08-21 16:53:50 +0100152 dstData += (static_cast<long>(dstDepthStride) - (dstData - dstPtrDepth));
153 srcData += (static_cast<long>(srcDepthStride) - (srcData - srcPtrDepth));
telsoa01c577f2c2018-08-31 09:22:23 +0100154 }
155
156 srcTensor->Unmap();
157 dstTensor->Unmap();
158}
159
160template <typename SrcTensorHandleType, typename DstTensorHandleType, typename DescriptorType>
161void GatherTensorHandlePairs(const DescriptorType& descriptor,
162 std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
163{
164 const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
165 tensorHandlePairs.reserve(numInputs);
166
167 for (unsigned int i = 0; i < numInputs; ++i)
168 {
Kevin May665a964a2019-08-21 16:53:50 +0100169 SrcTensorHandleType* const srcTensorHandle =
170 boost::polymorphic_downcast<SrcTensorHandleType*>(descriptor.m_Inputs[i]);
171 DstTensorHandleType* const dstTensorHandle =
172 boost::polymorphic_downcast<DstTensorHandleType*>(descriptor.m_Outputs[i]);
telsoa01c577f2c2018-08-31 09:22:23 +0100173
174 tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
175 }
176}
177
Matteo Martincigh747ef822018-12-18 09:26:39 +0000178armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor,
179 const PermutationVector& permutationVector,
180 void* permuteBuffer);
181
182void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout);
183
184TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout);
185
186armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* weightTensor,
187 DataLayout dataLayout,
188 void* permuteBuffer);
189
Kevin May665a964a2019-08-21 16:53:50 +0100190} //namespace armnn