/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#pragma once
25
26template <unsigned int IntBy, typename TIn, typename TOut>
27struct TransposeInterleaveCommon
28{
29 // Override the moveblock_1xY methods to improve performance
30 static inline void moveblock_1x1(const TIn *&in0, TOut *out)
31 {
32 for(unsigned int i = 0; i < IntBy; i++)
33 {
34 *out++ = static_cast<TOut>(*in0++);
35 }
36 }
37
38 static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out)
39 {
40 for(unsigned int i = 0; i < IntBy; i++)
41 {
42 *out++ = static_cast<TOut>(*in0++);
43 }
44 for(unsigned int i = 0; i < IntBy; i++)
45 {
46 *out++ = static_cast<TOut>(*in1++);
47 }
48 }
49
50 static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out)
51 {
52 for(unsigned int i = 0; i < IntBy; i++)
53 {
54 *out++ = static_cast<TOut>(*in0++);
55 }
56 for(unsigned int i = 0; i < IntBy; i++)
57 {
58 *out++ = static_cast<TOut>(*in1++);
59 }
60 for(unsigned int i = 0; i < IntBy; i++)
61 {
62 *out++ = static_cast<TOut>(*in2++);
63 }
64 for(unsigned int i = 0; i < IntBy; i++)
65 {
66 *out++ = static_cast<TOut>(*in3++);
67 }
68 }
69
70 static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax)
71 {
72 const auto ldin = stride;
73
74 TOut *outarray = out;
75 const TIn *inarray = in;
76 TOut *outptr_base = outarray;
77 const TIn *inptr_base = inarray + x0 + (k0 * ldin);
78 int ldout = (kmax - k0) * IntBy;
79
80 int k = (kmax - k0);
81 for(; k > 3; k -= 4)
82 {
83 TOut *outptr = outptr_base;
84 const TIn *inptr = inptr_base;
85 const TIn *inptr1 = inptr + ldin;
86 const TIn *inptr2 = inptr1 + ldin;
87 const TIn *inptr3 = inptr2 + ldin;
88
89 prefetch_3x(inptr);
90 prefetch_3x(inptr1);
91 prefetch_3x(inptr2);
92 prefetch_3x(inptr3);
93
94 outptr_base += IntBy * 4;
95 inptr_base += ldin * 4;
96
97 for(int x = (xmax - x0) / IntBy; x > 0; x--)
98 {
99 moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr);
100 outptr += ldout;
101 }
102 }
103
104 if(k)
105 {
106 TOut *outptr = outptr_base;
107 const TIn *inptr = inptr_base;
108 const TIn *inptr1 = inptr + ldin;
109 const TIn *inptr2 = inptr1 + ldin;
110
111 prefetch_3x(inptr);
112 prefetch_3x(inptr1);
113 prefetch_3x(inptr2);
114
115 for(int x = (xmax - x0) / IntBy; x > 0; x--)
116 {
117 switch(k)
118 {
119 case 3:
120 moveblock_1x2(inptr, inptr1, outptr);
121 moveblock_1x1(inptr2, outptr + IntBy * 2);
122 break;
123
124 case 2:
125 moveblock_1x2(inptr, inptr1, outptr);
126 break;
127
128 case 1:
129 moveblock_1x1(inptr, outptr);
130 break;
131
132 default:
133 UNREACHABLE("Impossible.");
134 }
135
136 outptr += ldout;
137 }
138 }
139
140 // Cope with ragged X cases
141 const unsigned int overflow = (xmax - x0) % IntBy;
142 if(overflow)
143 {
144 const TIn *inptr_base = inarray + (xmax - overflow) + (k0 * ldin);
145 TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout;
146
147 for(int k = (kmax - k0); k > 0; k--)
148 {
149 const TIn *inptr = inptr_base;
150 inptr_base += ldin;
151
152 for(unsigned int x = 0; x < IntBy; x++)
153 {
154 TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0);
155 *outptr++ = val;
156 }
157 }
158 }
159 }
160};