blob: 65c58eabd9793d291ac35af8d191b3021e4cbd8b [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#pragma once
7
#include "armnn/Tensor.hpp"
#include "ITensorHandle.hpp"

#include <boost/cast.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>
12
13namespace armnn
14{
15namespace
16{
17template<typename ArrayType, typename Arg>
18void AssignValues(unsigned int num, unsigned int& idx, const ArrayType& array, Arg& arg)
19{
20 if (idx >= num)
21 {
22 return;
23 }
24
25 arg = array[(num - 1) - idx];
26 idx++;
27};
28
29template<typename T, typename ArrayType, typename ...Args>
30void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& assignee, Args& ... args)
31{
32 AssignValues(num, idx, array, assignee);
33
34 AssignValues(num, idx, array, args...);
35}
36} // namespace
37
38template<typename CopyFunc>
39void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy)
40{
41 static_assert(MaxNumOfTensorDimensions == 4, "Please update CopyTensorContents");
42
43 TensorShape srcStrides = srcTensor->GetStrides();
44 const TensorShape& srcShape = srcTensor->GetShape();
45 TensorShape dstStrides = dstTensor->GetStrides();
46 const TensorShape& dstShape = dstTensor->GetShape();
47
48 size_t srcBatches = 1;
49 size_t srcChannels = 1;
50 size_t srcHeight = 1;
51 size_t srcWidth = 1;
52 AssignValues(srcShape.GetNumDimensions(),0, srcShape,
53 srcWidth,
54 srcHeight,
55 srcChannels,
56 srcBatches);
57
58 size_t srcBatchStride = 0;
59 size_t srcChannelStride = 0;
60 size_t srcHeightStride = 0;
61 size_t srcWidthStride = 0;
62 AssignValues(srcStrides.GetNumDimensions(),0, srcStrides,
63 srcWidthStride,
64 srcHeightStride,
65 srcChannelStride,
66 srcBatchStride);
67
68 size_t dstBatches = 1;
69 size_t dstChannels = 1;
70 size_t dstHeight = 1;
71 size_t dstWidth = 1;
72 AssignValues(dstShape.GetNumDimensions(),0, dstShape,
73 dstWidth,
74 dstHeight,
75 dstChannels,
76 dstBatches);
77
78 size_t dstBatchStride = 0;
79 size_t dstChannelStride = 0;
80 size_t dstHeightStride = 0;
81 size_t dstWidthStride = 0;
82 AssignValues(dstStrides.GetNumDimensions(),0, dstStrides,
83 dstWidthStride,
84 dstHeightStride,
85 dstChannelStride,
86 dstBatchStride);
87
88 auto srcData = static_cast<const uint8_t*>(srcTensor->Map());
89 auto dstData = static_cast<uint8_t*>(dstTensor->Map());
90
91 size_t copyLength = std::min(srcWidth*srcWidthStride, dstWidth*dstWidthStride);
92 size_t copyHeight = std::min(srcHeight, dstHeight);
93 size_t copyChannels = std::min(srcChannels, dstChannels);
94 size_t copyBatches = std::min(srcBatches, dstBatches);
95
96 for(unsigned int b=0; b < copyBatches; ++b)
97 {
98 auto srcPtrBatch = srcData;
99 auto dstPtrBatch = dstData;
100 for (unsigned int c=0; c< copyChannels; ++c)
101 {
102 auto srcPtrChannel = srcData;
103 auto dstPtrChannel = dstData;
104 for (unsigned int h=0; h < copyHeight; ++h)
105 {
106 copy(dstData, srcData, copyLength);
107 dstData += dstHeightStride;
108 srcData += srcHeightStride;
109 }
110 dstData += (static_cast<long>(dstChannelStride) - (dstData - dstPtrChannel));
111 srcData += (static_cast<long>(srcChannelStride) - (srcData - srcPtrChannel));
112 }
113 dstData += (static_cast<long>(dstBatchStride)-(dstData - dstPtrBatch));
114 srcData += (static_cast<long>(srcBatchStride)-(srcData - srcPtrBatch));
115 }
116
117 srcTensor->Unmap();
118 dstTensor->Unmap();
119}
120
121template <typename SrcTensorHandleType, typename DstTensorHandleType, typename DescriptorType>
122void GatherTensorHandlePairs(const DescriptorType& descriptor,
123 std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
124{
125 const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
126 tensorHandlePairs.reserve(numInputs);
127
128 for (unsigned int i = 0; i < numInputs; ++i)
129 {
130 SrcTensorHandleType* const srcTensorHandle = boost::polymorphic_downcast<SrcTensorHandleType*>(
131 descriptor.m_Inputs[i]);
132 DstTensorHandleType* const dstTensorHandle = boost::polymorphic_downcast<DstTensorHandleType*>(
133 descriptor.m_Outputs[i]);
134
135 tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
136 }
137}
138
139} //namespace armnn