blob: ba6925518324dfdd5dace6172f599e31a3939188 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#pragma once
7
Matteo Martincigh747ef822018-12-18 09:26:39 +00008#include "CpuTensorHandle.hpp"
Kevin May665a964a2019-08-21 16:53:50 +01009#include "ITensorHandle.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010010
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <armnn/Tensor.hpp>
12
Kevin May665a964a2019-08-21 16:53:50 +010013#include <Half.hpp>
Matteo Martincigh747ef822018-12-18 09:26:39 +000014#include <Permute.hpp>
15#include <Profiling.hpp>
Matteo Martincigh747ef822018-12-18 09:26:39 +000016
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <boost/cast.hpp>
18
19namespace armnn
20{
// Helpers for CopyTensorContentsGeneric. They live in a named namespace rather than an
// unnamed one: an unnamed namespace in a header gives every translation unit its own copy,
// so the external-linkage template CopyTensorContentsGeneric would refer to different
// entities in different TUs (an ODR violation).
namespace detail
{

/// Recursion terminator for the variadic AssignValues overload: there is nothing left to assign.
template <typename ArrayType>
void AssignValues(unsigned int /*num*/, unsigned int /*idx*/, const ArrayType& /*array*/)
{
}

/// Assigns array[(num - 1) - idx] to @p arg, then increments @p idx.
/// Elements are therefore taken starting from the LAST valid element of @p array.
/// Once @p idx has reached @p num, nothing is assigned and @p arg keeps its previous value.
/// @param num   Number of valid elements in @p array.
/// @param idx   Running index, counted from the end of @p array; incremented on assignment.
/// @param array Any type indexable with operator[] by an unsigned int.
/// @param arg   Destination for the selected element.
template <typename ArrayType, typename Arg>
void AssignValues(unsigned int num, unsigned int& idx, const ArrayType& array, Arg& arg)
{
    if (idx >= num)
    {
        return;
    }

    arg = array[(num - 1) - idx];
    idx++;
}

/// Distributes the first @p num elements of @p array, in reverse order, over @p assignee, @p args...
/// Starting from @p idx, the first destination receives array[num - 1 - idx], the next one
/// array[num - 2 - idx], and so on. Destinations left over once the elements run out are not
/// touched, so callers can pre-initialise them with defaults.
template <typename T, typename ArrayType, typename... Args>
void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& assignee, Args&... args)
{
    // idx is an lvalue here, so this selects the single-destination overload, which advances it.
    AssignValues(num, idx, array, assignee);

    // Recurse on the remaining destinations; an empty pack lands on the terminator above.
    AssignValues(num, idx, array, args...);
}

} // namespace detail

// Re-export so the unqualified AssignValues calls below resolve exactly as before.
using detail::AssignValues;
telsoa01c577f2c2018-08-31 09:22:23 +010045
Kevin May665a964a2019-08-21 16:53:50 +010046template <typename CopyFunc>
telsoa01c577f2c2018-08-31 09:22:23 +010047void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy)
48{
Matthew Jacksondba634f2019-08-15 15:14:18 +010049 static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents");
telsoa01c577f2c2018-08-31 09:22:23 +010050
Kevin May665a964a2019-08-21 16:53:50 +010051 TensorShape srcStrides = srcTensor->GetStrides();
telsoa01c577f2c2018-08-31 09:22:23 +010052 const TensorShape& srcShape = srcTensor->GetShape();
Kevin May665a964a2019-08-21 16:53:50 +010053 TensorShape dstStrides = dstTensor->GetStrides();
telsoa01c577f2c2018-08-31 09:22:23 +010054 const TensorShape& dstShape = dstTensor->GetShape();
55
Kevin May665a964a2019-08-21 16:53:50 +010056 size_t srcDepth = 1;
57 size_t srcBatches = 1;
telsoa01c577f2c2018-08-31 09:22:23 +010058 size_t srcChannels = 1;
Kevin May665a964a2019-08-21 16:53:50 +010059 size_t srcHeight = 1;
60 size_t srcWidth = 1;
61 AssignValues(srcShape.GetNumDimensions(),
62 0,
63 srcShape,
telsoa01c577f2c2018-08-31 09:22:23 +010064 srcWidth,
65 srcHeight,
66 srcChannels,
Matthew Jacksondba634f2019-08-15 15:14:18 +010067 srcBatches,
68 srcDepth);
telsoa01c577f2c2018-08-31 09:22:23 +010069
Kevin May665a964a2019-08-21 16:53:50 +010070 size_t srcDepthStride = 0;
71 size_t srcBatchStride = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010072 size_t srcChannelStride = 0;
Kevin May665a964a2019-08-21 16:53:50 +010073 size_t srcHeightStride = 0;
74 size_t srcWidthStride = 0;
75 AssignValues(srcStrides.GetNumDimensions(),
76 0,
77 srcStrides,
telsoa01c577f2c2018-08-31 09:22:23 +010078 srcWidthStride,
79 srcHeightStride,
80 srcChannelStride,
Matthew Jacksondba634f2019-08-15 15:14:18 +010081 srcBatchStride,
82 srcDepthStride);
telsoa01c577f2c2018-08-31 09:22:23 +010083
Kevin May665a964a2019-08-21 16:53:50 +010084 size_t dstDepth = 1;
85 size_t dstBatches = 1;
telsoa01c577f2c2018-08-31 09:22:23 +010086 size_t dstChannels = 1;
Kevin May665a964a2019-08-21 16:53:50 +010087 size_t dstHeight = 1;
88 size_t dstWidth = 1;
89 AssignValues(dstShape.GetNumDimensions(),
90 0,
91 dstShape,
telsoa01c577f2c2018-08-31 09:22:23 +010092 dstWidth,
93 dstHeight,
94 dstChannels,
Matthew Jacksondba634f2019-08-15 15:14:18 +010095 dstBatches,
96 dstDepth);
telsoa01c577f2c2018-08-31 09:22:23 +010097
Kevin May665a964a2019-08-21 16:53:50 +010098 size_t dstDepthStride = 0;
99 size_t dstBatchStride = 0;
telsoa01c577f2c2018-08-31 09:22:23 +0100100 size_t dstChannelStride = 0;
Kevin May665a964a2019-08-21 16:53:50 +0100101 size_t dstHeightStride = 0;
102 size_t dstWidthStride = 0;
103 AssignValues(dstStrides.GetNumDimensions(),
104 0,
105 dstStrides,
telsoa01c577f2c2018-08-31 09:22:23 +0100106 dstWidthStride,
107 dstHeightStride,
108 dstChannelStride,
Matthew Jacksondba634f2019-08-15 15:14:18 +0100109 dstBatchStride,
110 dstDepthStride);
telsoa01c577f2c2018-08-31 09:22:23 +0100111
Sadik Armaganbf86d512018-12-24 09:01:31 +0000112 const unsigned char* srcData;
113 unsigned char* dstData;
114 {
115 ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Synchronize buffers");
Kevin May665a964a2019-08-21 16:53:50 +0100116 srcData = static_cast<const uint8_t*>(srcTensor->Map());
117 dstData = static_cast<uint8_t*>(dstTensor->Map());
Sadik Armaganbf86d512018-12-24 09:01:31 +0000118 }
telsoa01c577f2c2018-08-31 09:22:23 +0100119
Kevin May665a964a2019-08-21 16:53:50 +0100120 size_t copyLength = std::min(srcWidth * srcWidthStride, dstWidth * dstWidthStride);
121 size_t copyHeight = std::min(srcHeight, dstHeight);
telsoa01c577f2c2018-08-31 09:22:23 +0100122 size_t copyChannels = std::min(srcChannels, dstChannels);
Kevin May665a964a2019-08-21 16:53:50 +0100123 size_t copyBatches = std::min(srcBatches, dstBatches);
124 size_t copyDepth = std::min(srcDepth, dstDepth);
telsoa01c577f2c2018-08-31 09:22:23 +0100125
Kevin May665a964a2019-08-21 16:53:50 +0100126 for (unsigned int d = 0; d < copyDepth; ++d)
telsoa01c577f2c2018-08-31 09:22:23 +0100127 {
Matthew Jacksondba634f2019-08-15 15:14:18 +0100128 auto srcPtrDepth = srcData;
129 auto dstPtrDepth = dstData;
Kevin May665a964a2019-08-21 16:53:50 +0100130 for (unsigned int b = 0; b < copyBatches; ++b)
telsoa01c577f2c2018-08-31 09:22:23 +0100131 {
Matthew Jacksondba634f2019-08-15 15:14:18 +0100132 auto srcPtrBatch = srcData;
133 auto dstPtrBatch = dstData;
Kevin May665a964a2019-08-21 16:53:50 +0100134 for (unsigned int c = 0; c < copyChannels; ++c)
telsoa01c577f2c2018-08-31 09:22:23 +0100135 {
Matthew Jacksondba634f2019-08-15 15:14:18 +0100136 auto srcPtrChannel = srcData;
137 auto dstPtrChannel = dstData;
Kevin May665a964a2019-08-21 16:53:50 +0100138 for (unsigned int h = 0; h < copyHeight; ++h)
Matthew Jacksondba634f2019-08-15 15:14:18 +0100139 {
140 copy(dstData, srcData, copyLength);
141 dstData += dstHeightStride;
142 srcData += srcHeightStride;
143 }
144 dstData += (static_cast<long>(dstChannelStride) - (dstData - dstPtrChannel));
145 srcData += (static_cast<long>(srcChannelStride) - (srcData - srcPtrChannel));
telsoa01c577f2c2018-08-31 09:22:23 +0100146 }
Kevin May665a964a2019-08-21 16:53:50 +0100147 dstData += (static_cast<long>(dstBatchStride) - (dstData - dstPtrBatch));
148 srcData += (static_cast<long>(srcBatchStride) - (srcData - srcPtrBatch));
telsoa01c577f2c2018-08-31 09:22:23 +0100149 }
Kevin May665a964a2019-08-21 16:53:50 +0100150 dstData += (static_cast<long>(dstDepthStride) - (dstData - dstPtrDepth));
151 srcData += (static_cast<long>(srcDepthStride) - (srcData - srcPtrDepth));
telsoa01c577f2c2018-08-31 09:22:23 +0100152 }
153
154 srcTensor->Unmap();
155 dstTensor->Unmap();
156}
157
158template <typename SrcTensorHandleType, typename DstTensorHandleType, typename DescriptorType>
159void GatherTensorHandlePairs(const DescriptorType& descriptor,
160 std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
161{
162 const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
163 tensorHandlePairs.reserve(numInputs);
164
165 for (unsigned int i = 0; i < numInputs; ++i)
166 {
Kevin May665a964a2019-08-21 16:53:50 +0100167 SrcTensorHandleType* const srcTensorHandle =
168 boost::polymorphic_downcast<SrcTensorHandleType*>(descriptor.m_Inputs[i]);
169 DstTensorHandleType* const dstTensorHandle =
170 boost::polymorphic_downcast<DstTensorHandleType*>(descriptor.m_Outputs[i]);
telsoa01c577f2c2018-08-31 09:22:23 +0100171
172 tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
173 }
174}
175
Matteo Martincigh747ef822018-12-18 09:26:39 +0000176armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor,
177 const PermutationVector& permutationVector,
178 void* permuteBuffer);
179
180void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout);
181
182TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout);
183
184armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* weightTensor,
185 DataLayout dataLayout,
186 void* permuteBuffer);
187
Kevin May665a964a2019-08-21 16:53:50 +0100188} //namespace armnn