blob: fa387a7a0b3a8d3e159dbc02c6c361101ca3a996 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "WorkloadUtils.hpp"
7
8namespace armnn
9{
10
11armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor,
12 const PermutationVector& permutationVector,
13 void* permuteBuffer)
14{
15 BOOST_ASSERT_MSG(tensor, "Invalid input tensor");
16 BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
17
18 TensorInfo tensorInfo = tensor->GetTensorInfo();
19
20 if (permutationVector.GetSize() > 0)
21 {
22 tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector);
23 armnnUtils::Permute(tensorInfo.GetShape(), permutationVector,
24 tensor->GetConstTensor<void>(), permuteBuffer,
25 GetDataTypeSize(tensorInfo.GetDataType()));
26 }
27 else
28 {
29 ::memcpy(permuteBuffer, tensor->GetConstTensor<void>(), tensorInfo.GetNumBytes());
30 }
31
32 return ConstTensor(tensorInfo, permuteBuffer);
33}
34
35void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout)
36{
37 // Reshape the weights in-place
38 const TensorShape& weightShape = weightInfo.GetShape();
39 switch (dataLayout)
40 {
41 case DataLayout::NHWC:
42 // The data layout is NHWC, reshape from [ H, W, I, M ] to [ 1, H, W, I * M ]
43 weightInfo.SetShape({ 1,
44 weightShape[0],
45 weightShape[1],
46 weightShape[2] * weightShape[3] });
47 break;
48 case DataLayout::NCHW:
49 default:
50 // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ]
51 weightInfo.SetShape({ 1,
52 weightShape[0] * weightShape[1],
53 weightShape[2],
54 weightShape[3] });
55 break;
56 }
57}
58
59TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout)
60{
61 // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
62 // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
63
64 // 1. Permute the weights if necessary
65 // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done
66 // starting from the current shape of [ M, I, H, W ]
67 TensorInfo weightPermutedInfo(weightInfo);
68 if (dataLayout == DataLayout::NHWC)
69 {
70 // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ]
71 PermutationVector permutationVector{ 3, 2, 0, 1 };
72 weightPermutedInfo = armnnUtils::Permuted(weightInfo, permutationVector);
73 }
74
75 // 2. Reshape the weights
76 ReshapeWeightsForAcl(weightPermutedInfo, dataLayout);
77
78 // 3. Return the permuted weight info
79 return weightPermutedInfo;
80}
81
82armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* weightTensor,
83 DataLayout dataLayout,
84 void* permuteBuffer)
85{
86 BOOST_ASSERT_MSG(weightTensor, "Invalid input tensor");
87 BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
88
89 // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
90 // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
91
92 // 1. Permute the weights if necessary
93 // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done
94 // starting from the current shape of [ M, I, H, W ]
95 // If no permutation is necessary, leave the permutation vector empty
96 PermutationVector permutationVector{};
97 if (dataLayout == DataLayout::NHWC)
98 {
99 // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ]
100 permutationVector = { 3, 2, 0, 1 };
101 }
102 ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
103
104 // 2. Reshape the weights
105 ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout);
106
107 // 3. Return both the tensor and the allocated storage to ensure that the data stays alive
108 return weightPermuted;
109}
110
111} // namespace armnn