//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "BatchMatMulImpl.hpp"

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>

namespace armnn
{

void BatchMatMul::BatchMatMulImpl()
{
    inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
    inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
    // From this point on we only work with the decoded vectors, not the input decoders themselves

    // Pre-transpose and pre-adjoint here if their parameter vectors aren't empty,
    // and also handle DataLayouts, which may change with permutations/adjoints

    // TODO: Update input validation and inferred output shapes to accommodate these pre-permutes

    auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
    RecurseBMM(idx, 0);
}

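// Illustrative sketch (assuming no transpose/adjoint parameters are set): the recursion below
// visits the output coordinates in lexicographic order, e.g. for an output shape of { 2, 2 } it
// visits (0,0), (0,1), (1,0), (1,1). Each fully-specified coordinate is a leaf of the call tree,
// where a dot product is accumulated along the contracted axes; for inputs of shape { B, M, K }
// and { B, K, N } each leaf computes:
//   output[b][m][n] = sum_k( inputX[b][m][k] * inputY[b][k][n] )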
void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
{
    // We work from the indices of the output tensor (the maximal possible shape)

    if(!(curDim < outputInfo.GetNumDimensions()))
    {
        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)

        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
                                                             inputXInfo.GetShape(),
                                                             inputYInfo.GetShape());
        AdjustAxesToMulForUnequalRanks(axesToMul);

        unsigned int inputXColDim = axesToMul.first.second;
        unsigned int inputYRowDim = axesToMul.second.first;

        unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];

        float sum = 0.0f;

        // inputXColSize could equally be used here; the contracted dimensions have the same size
        for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
        {
            auto xIdx = curIdx;
            xIdx[inputXColDim] = inputYRowIdx;

            auto yIdx = curIdx;
            yIdx[inputYRowDim] = inputYRowIdx;

            sum += (GetValueAt(DataSlot::InputX, xIdx)
                  * GetValueAt(DataSlot::InputY, yIdx));
        }

        SetValueAt(sum, DataSlot::Output, curIdx);

        return;
    }

    for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
    {
        curIdx[curDim] = i;
        RecurseBMM(curIdx, curDim+1);
    }
}

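// Worked example (assuming GetAxesToMul returns each input's (row, col) axes relative to that
// input's own rank): for an inputX of shape { 2, 2, 3, 4 } and an inputY of shape { 4, 5 }, the
// raw axes are X: (2, 3) and Y: (0, 1). X has the larger rank (rankDiff = 2), so Y's axes are
// shifted up by 2 to (2, 3), placing both pairs in the output coordinate space that RecurseBMM
// indexes with.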
void BatchMatMul::AdjustAxesToMulForUnequalRanks(
    std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
{
    long rankDiff = static_cast<long>(inputXInfo.GetNumDimensions()) - inputYInfo.GetNumDimensions();
    if(rankDiff == 0)
    {
        return;
    }
    else if(rankDiff < 0)
    {
        // Y is the larger one
        axesToMul.first.first += static_cast<unsigned int>(std::abs(rankDiff));
        axesToMul.first.second += static_cast<unsigned int>(std::abs(rankDiff));
    }
    else if(rankDiff > 0)
    {
        // X is the larger one
        axesToMul.second.first += static_cast<unsigned int>(std::abs(rankDiff));
        axesToMul.second.second += static_cast<unsigned int>(std::abs(rankDiff));
    }
}

float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
{
    // For the inputs this reads from the decoded data vectors we already hold, not the decoders;
    // for the output it reads through the encoder itself

    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    float value = 0.0f;

    switch(type)
    {
        case DataSlot::InputX:
            value = inputXData[flatIdx];
            break;
        case DataSlot::InputY:
            value = inputYData[flatIdx];
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx]; // Position the encoder at the flat index, then read from it
            value = outputEncoder.Get();
            break;
        default:
            break;
    }

    return value;
}

void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
    AdjustToSafeIdx(type, idx);

    unsigned int flatIdx = CalcFlatIdx(type, idx);

    switch(type)
    {
        case DataSlot::InputX:
            inputXData[flatIdx] = value;
            break;
        case DataSlot::InputY:
            inputYData[flatIdx] = value;
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx]; // Position the encoder at the flat index, then write through it
            outputEncoder.Set(value);
            break;
        default:
            break;
    }
}

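// Illustrative example: with an output of shape { 5, 2, 3, 4 } and an inputX of shape { 3, 4 },
// xDiff is 2, so positions 0 and 1 of the output index are zeroed for X (X has no such
// dimensions). Likewise, wherever the input's extent is exceeded (e.g. a broadcast dimension of
// size 1), the index is clamped to 0.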
void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
{
    for(unsigned int dim = 0; dim < idx.size(); dim++)
    {
        switch(type)
        {
            case DataSlot::InputX:
            {
                auto xRank = inputXInfo.GetNumDimensions();
                auto xDiff = outputInfo.GetNumDimensions() - xRank;
                if (dim < xDiff ||
                    idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
                {
                    idx[dim] = 0; // Broadcasting
                }
                break;
            }
            case DataSlot::InputY:
            {
                auto yRank = inputYInfo.GetNumDimensions();
                auto yDiff = outputInfo.GetNumDimensions() - yRank;
                if (dim < yDiff ||
                    idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
                {
                    idx[dim] = 0; // Broadcasting
                }
                break;
            }
            case DataSlot::Output:
            {
                // Our indices are based on the output, so no adjustment is needed
                break;
            }
            default:
                break;
        }
    }
}

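// Worked example (row-major flattening): for an output shape of { 2, 3, 4 } and an index of
// { 1, 2, 3 }, the flat index is 3 + 2*4 + 1*(4*3) = 23. For the inputs, the offset skips any
// leading output dimensions that the (possibly lower-rank) input does not have.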
unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
    unsigned int result = idx[idx.size()-1];

    unsigned int dimMultiplier = 1;

    unsigned int offset;

    // Start at size-2 because the final dimension is already accounted for (its multiplier is 1)
    for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
    {
        switch(type)
        {
            case DataSlot::InputX:
                offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
                // Skip dimensions that don't exist in the lower-rank input; their indices were
                // zeroed for broadcasting, so they contribute nothing to the flat index
                if (i + 1 >= offset)
                {
                    dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
                }
                break;
            case DataSlot::InputY:
                offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
                if (i + 1 >= offset)
                {
                    dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
                }
                break;
            case DataSlot::Output:
                dimMultiplier *= outputInfo.GetShape()[i+1];
                break;
            default:
                break;
        }
        result += (idx[i] * dimMultiplier);
    }
    return result;
}

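// Example usage (debugging aid): StringifyVec(std::vector<unsigned int>{ 2, 3, 4 }) returns "{ 2 3 4 }"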
template <typename T>
std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
{
    std::string res = "{ ";
    for(auto x : vec)
    {
        res += std::to_string(x);
        res += " ";
    }
    res += "}";
    return res;
}

} // namespace armnn