//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "BatchMatMulImpl.hpp"

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>
#include <armnnUtils/Permute.hpp>

namespace armnn
{

BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
                         const TensorInfo& inputXInfo,
                         const TensorInfo& inputYInfo,
                         const TensorInfo& outputInfo,
                         Decoder<float>& inputXDecoder,
                         Decoder<float>& inputYDecoder,
                         Encoder<float>& outputEncoder)
    : params(params),
      inputXInfo(inputXInfo),
      inputYInfo(inputYInfo),
      outputInfo(outputInfo),
      inputXDecoder(inputXDecoder),
      inputYDecoder(inputYDecoder),
      outputEncoder(outputEncoder)
{
    inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
    inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
    // At this point, we don't touch the input decoders - just the resultant vectors

    ApplyParams();

    ApplyBatchMatMul();
}

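// Computes the batched matrix multiplication: for every output index, the corresponding
// row of X and column of Y are multiplied element-wise along the shared inner dimension
// and accumulated, with RecurseTensor driving the walk over all output indices.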
void BatchMatMul::ApplyBatchMatMul()
{
    auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
                                                          inputXInfo.GetShape());
    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
                                                          inputYInfo.GetShape());
    AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);

    unsigned int inputXColDim = axesXToMul.second;
    unsigned int inputYRowDim = axesYToMul.first;

    unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];

    auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
    {
        float sum = 0.0f;

        // InputYRowSize is synonymous with inputXColSize
        for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
        {
            auto xIdx = curIdx;
            xIdx[inputXColDim] = inputYRowIdx;

            auto yIdx = curIdx;
            yIdx[inputYRowDim] = inputYRowIdx;

            sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
        }

        SetValueAt(sum, DataSlot::Output, curIdx);
    };

    auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
    RecurseTensor(outputInfo,
                  batchMatMulOperation,
                  startIdx,
                  0);
}

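// Applies the descriptor's pre-processing options to each input independently; note that
// for a given input, a transpose takes precedence over an adjoint if both flags are set.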
void BatchMatMul::ApplyParams()
{
    if(params.m_TransposeX)
    {
        Transpose(DataSlot::InputX);
    }
    else if(params.m_AdjointX)
    {
        Adjoint(DataSlot::InputX);
    }
    if(params.m_TransposeY)
    {
        Transpose(DataSlot::InputY);
    }
    else if(params.m_AdjointY)
    {
        Adjoint(DataSlot::InputY);
    }
}

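// Swaps the two matrix axes of the given input (as determined by its data layout),
// updating both the cached TensorInfo and the decoded data vector in place.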
void BatchMatMul::Transpose(DataSlot type)
{
    // AKA the permute of the tensor
    // This modifies the tensor's info.

    switch(type)
    {
        case DataSlot::InputX:
        {
            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
                                                                   inputXInfo.GetShape());
            inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
            std::vector<float> temp(inputXData.size());
            armnnUtils::Permute(inputXInfo.GetShape(),
                                permuteVec,
                                inputXData.data(),
                                temp.data(),
                                sizeof(float));
            inputXData = temp;
            break;
        }
        case DataSlot::InputY:
        {
            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
                                                                   inputYInfo.GetShape());
            inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
            std::vector<float> temp(inputYData.size());
            armnnUtils::Permute(inputYInfo.GetShape(),
                                permuteVec,
                                inputYData.data(),
                                temp.data(),
                                sizeof(float));
            inputYData = temp;
            break;
        }
        case DataSlot::Output: // We needn't transpose the output tensor
        default:
            break;
    }
}

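// For a square matrix A, the adjoint (adjugate) is the transpose of its cofactor matrix:
//     adj(A)[i][j] = cofactor(A)[j][i], where cofactor(A)[i][j] = (-1)^(i+j) * det(minor(A,i,j))
// and minor(A,i,j) is A with row i and column j removed. The determinants below are
// evaluated directly for small sub-matrices and via Gaussian elimination otherwise.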
void BatchMatMul::Adjoint(DataSlot type)
{
    // Finding the adjoint of a square matrix:
    // Calculate the cofactor of each element (using Gauss elimination here)
    // Apply a transpose to it (this also modifies the tensor's info)

    TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
    const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
    const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout, inputInfo.GetShape());

    ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
    // We grab a copy of the tensor data to prevent overwriting
    std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;

    // The sub-matrix is the resultant matrix when the row and column of the current index are removed
    unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
    std::vector<std::vector<float>> subMat(subMatAxisSize,
                                           std::vector<float>(subMatAxisSize));

    // Lambdas for each sub-step of the cofactor operation
    auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
    {
        float diff = std::fabs(a-b);
        float bound = std::fabs(a+b) * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
        return (diff <= bound) || (diff < std::numeric_limits<float>::min());
    };

    float swapMultiplier = std::numeric_limits<float>::max();
    auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
    {
        // Every row swap flips the sign of the determinant; swapMultiplier tracks this
        // (it is reset to 1 at the start of each Gaussian elimination run)
        for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
        {
            float tmp = subMat[rowIdxA][colIdx];
            subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
            subMat[rowIdxB][colIdx] = tmp;
        }
        swapMultiplier *= -1.0f;
    };

    auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
    {
        unsigned int result = std::numeric_limits<unsigned int>::max();

        // The original diagonal has been checked and is invalid
        for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
        {
            if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
            {
                result = rowIdx;
                break;
            }
        }
        return result;
    };

    auto eliminate = [&](const float& pivot, unsigned int pivotPos)
    {
        for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
        {
            float multiplierNumerator = subMat[rowIdx][pivotPos];
            if(almostEquals(multiplierNumerator, 0.0f))
            {
                continue;
            }
            float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
                                                            // Hence the almostEquals usage to counteract this
            for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
            {
                // We start at col=pivotPos as we have assumed that all elements
                // to our left have been eliminated to zero already

                // We subtract based on the element directly above us in our pivot row
                subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
            }
        }
    };

    auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
    {
        auto row = curIdx[axesToAdjoint.first];
        auto col = curIdx[axesToAdjoint.second];

        float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));

        for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
        {
            for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
            {
                unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
                unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
                auto cloneIdx = curIdx;
                cloneIdx[axesToAdjoint.first] = outerRow;
                cloneIdx[axesToAdjoint.second] = outerCol;
                subMat[subRow][subCol] = GetValueAt(type, cloneIdx, inputDataClone);
            }
        }

        float determinant = 1.0f;

        // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
        switch(subMatAxisSize)
        {
            case 0:
            {
                determinant = GetValueAt(type, curIdx, inputDataClone);
                break;
            }
            case 1:
            {
                // If the resultant sub-matrix is just one element - that's the determinant
                determinant = subMat[0][0];
                break;
            }
            case 2:
            {
                // For a 2x2 sub-matrix, the determinant is just a*d-b*c
                determinant = subMat[0][0] * subMat[1][1] -
                              subMat[0][1] * subMat[1][0];
                break;
            }
            default:
            {
                // Gaussian elimination to find the determinant of this sub-matrix
                swapMultiplier = 1.0f;
                // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
                // nearest non-zero down within the column
                for(unsigned int pivotRow = 0, pivotCol = 0;
                    pivotRow < subMatAxisSize;
                    pivotRow++, pivotCol++)
                {
                    float& pivot = subMat[pivotRow][pivotCol];

                    if(almostEquals(pivot, 0.0f))
                    {
                        unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
                        if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
                        {
                            // No valid pivot down this column, which means that this pivot remains a zero.
                            // The determinant of this entire sub-matrix is therefore zero.
                            determinant = 0.0f;
                            break;
                        }
                        swapRows(pivotRow, nextValidPivotRowIdx);
                    }
                    determinant *= pivot;
                    // The actual elimination bit (which will update/propagate to the pivots down the line)
                    eliminate(pivot, pivotRow); // Synonymous with pivotCol
                }

                determinant *= swapMultiplier;
                break;
            }
        }
        float cofactor = minorMultiplier * determinant;
        SetValueAt(cofactor, type, curIdx);
    };

    auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
    RecurseTensor(inputInfo,
                  cofactorOperation,
                  startIdx,
                  0);

    Transpose(type);
}

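// Depth-first walk over every index combination of the given tensor's shape; the supplied
// operation is invoked once per element with the full multidimensional index.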
void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
                                const std::function<void(const std::vector<unsigned int>&)>& operation,
                                std::vector<unsigned int>& curIdx,
                                unsigned int curDim)
{
    if(!(curDim < tensorInfo.GetNumDimensions()))
    {
        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
        operation(curIdx);
        return;
    }

    for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
    {
        curIdx[curDim] = i;
        RecurseTensor(tensorInfo,
                      operation,
                      curIdx,
                      curDim + 1);
    }
}

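// When the input ranks differ, the lower-ranked input is implicitly right-aligned against
// the output shape (broadcasting), so its matrix axes are shifted by the rank difference
// to express them in output-space coordinates.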
void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
                                                 std::pair<unsigned int, unsigned int>& axesYToMul)
{
    int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
                   static_cast<int>(inputYInfo.GetNumDimensions());
    if(rankDiff == 0)
    {
        return;
    }
    else if(rankDiff < 0)
    {
        // Y is the larger one
        axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
    else if(rankDiff > 0)
    {
        // X is the larger one
        axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
}

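// Reads a single value. The index is expressed in output space; AdjustToSafeIdx maps it into
// the tensor's own (possibly broadcast) index space before CalcFlatIdx flattens it. When
// customData is non-empty it is read in place of the stored input vector (as done in Adjoint).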
float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
{
    // This gets the data from the input vector that we have, not the decoder
    // But for the output, it operates on the encoder itself

    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    float value = 0.0f;
    switch(type)
    {
        case DataSlot::InputX:
            value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
            break;
        case DataSlot::InputY:
            value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx];
            value = outputEncoder.Get();
            break;
        default:
            break;
    }

    return value;
}

void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    switch(type)
    {
        case DataSlot::InputX:
            inputXData[flatIdx] = value;
            break;
        case DataSlot::InputY:
            inputYData[flatIdx] = value;
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx];
            outputEncoder.Set(value);
            break;
        default:
            break;
    }
}

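// Clamps an output-space index to a valid index for the given tensor: any dimension the
// tensor does not have, or where the index exceeds that dimension's size, is reset to 0,
// which implements the broadcast semantics.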
void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
{
    for(unsigned int dim = 0; dim < idx.size(); dim++)
    {
        switch(type)
        {
            case DataSlot::InputX:
            {
                auto xRank = inputXInfo.GetNumDimensions();
                auto xDiff = outputInfo.GetNumDimensions() - xRank;
                if (dim < xDiff ||
                    idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
                {
                    idx[dim] = 0; // Broadcasting
                }
                break;
            }
            case DataSlot::InputY:
            {
                auto yRank = inputYInfo.GetNumDimensions();
                auto yDiff = outputInfo.GetNumDimensions() - yRank;
                if (dim < yDiff ||
                    idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
                {
                    idx[dim] = 0;
                }
                break;
            }
            case DataSlot::Output:
            {
                // Our indices are based off the output
                break;
            }
            default:
                break;
        }
    }
}

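// Converts a multidimensional index into a flat, row-major offset into the relevant data
// vector. For a shape [D0, D1, D2] the offset is idx[2] + idx[1]*D2 + idx[0]*D1*D2; input
// shapes are right-aligned against the output rank via the offset variable.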
unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
    unsigned int result = idx[idx.size()-1];
    unsigned int dimMultiplier = 1;
    unsigned int offset;

    // Start at the second-to-last dim; the final dim is already accounted for in result (its stride is 1)
    for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
    {
        switch(type)
        {
            case DataSlot::InputX:
                offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
                dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::InputY:
                offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
                dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::Output:
                dimMultiplier *= outputInfo.GetShape()[i+1];
                break;
            default:
                break;
        }
        result += (idx[i] * dimMultiplier);
    }
    return result;
}

} // namespace armnn
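
// Illustrative usage (a sketch only, with placeholder names - not part of the implementation):
// given the TensorInfos plus Decoder/Encoder objects wrapping the input and output buffers,
// the whole computation runs inside the constructor, so a caller only needs to instantiate
// the class:
//
//     BatchMatMulDescriptor descriptor; // transpose/adjoint flags, data layouts, etc.
//     BatchMatMul bmm(descriptor,
//                     inputXInfo, inputYInfo, outputInfo,
//                     *inputXDecoder, *inputYDecoder, *outputEncoder);
//     // outputEncoder has now been filled with the batched product of X and Y.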