Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "BatchMatMulImpl.hpp" |
| 7 | |
| 8 | #include <armnn/backends/WorkloadData.hpp> |
| 9 | #include <armnn/Logging.hpp> |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 10 | #include <armnnUtils/Permute.hpp> |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 11 | |
| 12 | namespace armnn |
| 13 | { |
| 14 | |
// Reference implementation of batched matrix multiplication.
// Decodes both inputs up front, applies the descriptor's per-input
// pre-processing (transpose or adjoint), then writes X·Y to the output encoder.
BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
                         const TensorInfo& inputXInfo,
                         const TensorInfo& inputYInfo,
                         const TensorInfo& outputInfo,
                         Decoder<float>& inputXDecoder,
                         Decoder<float>& inputYDecoder,
                         Encoder<float>& outputEncoder)
    : params(params),
      inputXInfo(inputXInfo),
      inputYInfo(inputYInfo),
      outputInfo(outputInfo),
      inputXDecoder(inputXDecoder),
      inputYDecoder(inputYDecoder),
      outputEncoder(outputEncoder)
{
    // Materialise both inputs as plain float vectors; everything downstream
    // operates on these vectors (plus the output encoder), never the decoders.
    inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
    inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
    // At this point, we don't touch the input decoders - just the resultant vectors

    // Pre-process inputs per the descriptor (may permute data and shapes in place)
    ApplyParams();

    // Perform the multiplication and write results through the output encoder
    ApplyBatchMatMul();
}
| 38 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 39 | void BatchMatMul::ApplyBatchMatMul() |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 40 | { |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 41 | auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX, |
| 42 | inputXInfo.GetShape()); |
| 43 | auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY, |
| 44 | inputYInfo.GetShape()); |
| 45 | AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 46 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 47 | unsigned int inputXColDim = axesXToMul.second; |
| 48 | unsigned int inputYRowDim = axesYToMul.first; |
| 49 | |
| 50 | unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim]; |
| 51 | |
| 52 | auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx) |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 53 | { |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 54 | float sum = 0.0f; |
| 55 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 56 | // InputYRowSize is synonymous with inputXColSize |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 57 | for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) { |
| 58 | auto xIdx = curIdx; |
| 59 | xIdx[inputXColDim] = inputYRowIdx; |
| 60 | |
| 61 | auto yIdx = curIdx; |
| 62 | yIdx[inputYRowDim] = inputYRowIdx; |
| 63 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 64 | sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx)); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 65 | } |
| 66 | |
| 67 | SetValueAt(sum, DataSlot::Output, curIdx); |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 68 | }; |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 69 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 70 | auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0); |
| 71 | RecurseTensor(outputInfo, |
| 72 | batchMatMulOperation, |
| 73 | startIdx, |
| 74 | 0); |
| 75 | } |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 76 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 77 | void BatchMatMul::ApplyParams() |
| 78 | { |
| 79 | if(params.m_TransposeX) |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 80 | { |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 81 | Transpose(DataSlot::InputX); |
| 82 | } |
| 83 | else if(params.m_AdjointX) |
| 84 | { |
| 85 | Adjoint(DataSlot::InputX); |
| 86 | } |
| 87 | if(params.m_TransposeY) |
| 88 | { |
| 89 | Transpose(DataSlot::InputY); |
| 90 | } |
| 91 | else if(params.m_AdjointY) |
| 92 | { |
| 93 | Adjoint(DataSlot::InputY); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 94 | } |
| 95 | } |
| 96 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 97 | void BatchMatMul::Transpose(DataSlot type) |
| 98 | { |
| 99 | // AKA the permute of the tensor |
| 100 | // This modifies the tensor's info. |
| 101 | |
| 102 | switch(type) |
| 103 | { |
| 104 | case DataSlot::InputX: |
| 105 | { |
| 106 | auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX, |
| 107 | inputXInfo.GetShape()); |
| 108 | inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec); |
| 109 | std::vector<float> temp(inputXData.size()); |
| 110 | armnnUtils::Permute(inputXInfo.GetShape(), |
| 111 | permuteVec, |
| 112 | inputXData.data(), |
| 113 | temp.data(), |
| 114 | sizeof(float)); |
| 115 | inputXData = temp; |
| 116 | break; |
| 117 | } |
| 118 | case DataSlot::InputY: |
| 119 | { |
| 120 | auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY, |
| 121 | inputYInfo.GetShape()); |
| 122 | inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec); |
| 123 | std::vector<float> temp(inputYData.size()); |
| 124 | armnnUtils::Permute(inputYInfo.GetShape(), |
| 125 | permuteVec, |
| 126 | inputYData.data(), |
| 127 | temp.data(), |
| 128 | sizeof(float)); |
| 129 | inputYData = temp; |
| 130 | break; |
| 131 | } |
| 132 | case DataSlot::Output: // We needn't transpose the output tensor |
| 133 | default: |
| 134 | break; |
| 135 | } |
| 136 | } |
| 137 | |
// Replaces the given input slot (in place) with its adjoint (adjugate):
// the transpose of its cofactor matrix. Only defined for square matrices.
// Cofactors are computed per element via the determinant of the minor
// sub-matrix, using Gauss elimination for sub-matrices larger than 2x2.
void BatchMatMul::Adjoint(DataSlot type)
{
    // Finding the adjoint of a square matrix:
    // Calculate the cofactor of each element (using Gauss elimination here)
    // Apply a transpose to it (this also modifies the tensor's info)

    TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
    const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
    const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout,inputInfo.GetShape());

    // Adjoint requires a square matrix (rows == cols on the mul axes)
    ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
    // We grab a copy of the tensor data to prevent overwriting
    std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;

    // The sub-matrix is the resultant matrix when the row and column of the current index is removed
    unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
    std::vector<std::vector<float>> subMat(subMatAxisSize,
                                           std::vector<float>(subMatAxisSize));

    // Lambdas for each sub-step of the cofactor operation

    // NOTE(review): the bound here is scaled by diff itself rather than by
    // max(|a|,|b|), so for any diff > 0 the first test can never hold and the
    // comparison reduces to "diff is below the smallest normal float".
    // Confirm whether a magnitude-based ULP bound was intended.
    auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
    {
        float diff = std::fabs(a-b);
        float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
        return (diff <= bound) || (diff < std::numeric_limits<float>::min());
    };

    // Sign accumulated from row swaps during elimination; reset to 1 at the
    // start of each Gauss-elimination run below.
    float swapMultiplier = std::numeric_limits<float>::max();
    auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
    {
        // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
        for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
        {
            float tmp = subMat[rowIdxA][colIdx];
            subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
            subMat[rowIdxB][colIdx] = tmp;
        }
        swapMultiplier *= -1.0f;
    };

    // Returns the first row below the diagonal with a usable (non-zero) pivot
    // in this column, or unsigned max if the whole column below is zero.
    auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
    {
        unsigned int result = std::numeric_limits<unsigned int>::max();

        // The original diagonal has been checked and is invalid
        for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
        {
            if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
            {
                result = rowIdx;
                break;
            }
        }
        return result;
    };

    // Zeroes the column below the pivot by subtracting scaled pivot-row copies.
    auto eliminate = [&](const float& pivot, unsigned int pivotPos)
    {
        for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
        {
            float multiplierNumerator = subMat[rowIdx][pivotPos];
            if(almostEquals(multiplierNumerator, 0.0f))
            {
                continue;
            }
            float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
                                                            // Hence the almostEquals usage to counteract this
            for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
            {
                // We start at col=pivotPos as we have assumed that all elements
                // to our left have been eliminated to zero already

                // We subtract based on the element directly above us in our pivot row
                subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
            }
        }
    };

    // Computes the cofactor of the element at curIdx and stores it in place
    // (reads come from inputDataClone, so in-place writes are safe).
    auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
    {
        auto row = curIdx[axesToAdjoint.first];
        auto col = curIdx[axesToAdjoint.second];

        // (-1)^(row+col) sign term, using 1-based row/col as in the textbook formula
        float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));

        // Build the minor sub-matrix by skipping the current row and column
        for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
        {
            for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
            {
                unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
                unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
                auto cloneIdx = curIdx;
                cloneIdx[axesToAdjoint.first] = outerRow;
                cloneIdx[axesToAdjoint.second] = outerCol;
                subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
            }
        }

        float determinant = 1.0f;

        // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
        switch(subMatAxisSize)
        {
            case 0:
            {
                // 1x1 source matrix: the sole element is its own determinant
                determinant = GetValueAt(type, curIdx, inputDataClone);
                break;
            }
            case 1:
            {
                // If the resultant sub-matrix is just one element - that's the determinant
                determinant = subMat[0][0];
                break;
            }
            case 2:
            {
                // For a 2x2 sub-matrix, the determinant is just a*d-b*c
                determinant = subMat[0][0] * subMat[1][1] -
                              subMat[0][1] * subMat[1][0];
                break;
            }
            default:
            {
                // Gaussian elimination to find the determinant of this sub-matrix
                swapMultiplier = 1.0f;
                // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
                // nearest non-zero down within the column
                for(unsigned int pivotRow = 0, pivotCol = 0;
                    pivotRow < subMatAxisSize;
                    pivotRow++, pivotCol++)
                {
                    float& pivot = subMat[pivotRow][pivotCol];

                    if(almostEquals(pivot, 0.0f))
                    {
                        unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
                        if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
                        {
                            // No valid pivot down this column, which means that this pivot remains a zero.
                            // This results in the determinant for this entire sub-matrix to just be zero.
                            determinant = 0.0f;
                            break;
                        }
                        swapRows(pivotRow, nextValidPivotRowIdx);
                    }
                    determinant *= pivot;
                    // The actual elimination bit (which will update/propagate to the pivots down the line)
                    eliminate(pivot, pivotRow); // Synonymous with pivotCol
                }

                // Each row swap negated the determinant; undo that here
                determinant *= swapMultiplier;
                break;
            }
        }
        float cofactor = minorMultiplier * determinant;
        SetValueAt(cofactor, type, curIdx);
    };

    auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
    RecurseTensor(inputInfo,
                  cofactorOperation,
                  startIdx,
                  0);

    // The adjoint is the transpose of the cofactor matrix just computed
    Transpose(type);
}
| 304 | |
| 305 | void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo, |
| 306 | const std::function<void(const std::vector<unsigned int>&)>& operation, |
| 307 | std::vector<unsigned int>& curIdx, |
| 308 | unsigned int curDim) |
| 309 | { |
| 310 | if(!(curDim < tensorInfo.GetNumDimensions())) |
| 311 | { |
| 312 | // We're at the leaf level of this call tree, so we operate here (each leaf is a data point) |
| 313 | operation(curIdx); |
| 314 | return; |
| 315 | } |
| 316 | |
| 317 | for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++) |
| 318 | { |
| 319 | curIdx[curDim] = i; |
| 320 | RecurseTensor(tensorInfo, |
| 321 | operation, |
| 322 | curIdx, |
| 323 | curDim + 1); |
| 324 | } |
| 325 | } |
| 326 | |
| 327 | void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul, |
| 328 | std::pair<unsigned int, unsigned int>& axesYToMul) |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 329 | { |
Samuel Yap | ca7bbd6 | 2022-08-02 09:12:02 +0100 | [diff] [blame] | 330 | int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) - |
| 331 | static_cast<int>(inputYInfo.GetNumDimensions()); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 332 | if(rankDiff == 0) |
| 333 | { |
| 334 | return; |
| 335 | } |
| 336 | else if(rankDiff < 0) |
| 337 | { |
| 338 | // Y is the larger one |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 339 | axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); |
| 340 | axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 341 | } |
| 342 | else if(rankDiff > 0) |
| 343 | { |
| 344 | // X is the larger one |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 345 | axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); |
| 346 | axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 347 | } |
| 348 | } |
| 349 | |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 350 | float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData) |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 351 | { |
| 352 | // This gets the data from the input vector that we have, Not the decoder |
| 353 | // But for the output, it is operating on the encoder itself |
| 354 | |
| 355 | AdjustToSafeIdx(type, idx); |
| 356 | unsigned int flatIdx = CalcFlatIdx(type, idx); |
| 357 | float value = 0.0f; |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 358 | switch(type) |
| 359 | { |
| 360 | case DataSlot::InputX: |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 361 | value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx]; |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 362 | break; |
| 363 | case DataSlot::InputY: |
Samuel Yap | dc8ed9d | 2022-08-08 14:07:42 +0100 | [diff] [blame] | 364 | value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx]; |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 365 | break; |
| 366 | case DataSlot::Output: |
| 367 | outputEncoder[flatIdx]; |
| 368 | value = outputEncoder.Get(); |
| 369 | break; |
| 370 | default: |
| 371 | break; |
| 372 | } |
| 373 | |
| 374 | return value; |
| 375 | } |
| 376 | |
| 377 | void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx) |
| 378 | { |
| 379 | AdjustToSafeIdx(type, idx); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 380 | unsigned int flatIdx = CalcFlatIdx(type, idx); |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 381 | switch(type) |
| 382 | { |
| 383 | case DataSlot::InputX: |
| 384 | inputXData[flatIdx] = value; |
| 385 | break; |
| 386 | case DataSlot::InputY: |
| 387 | inputYData[flatIdx] = value; |
| 388 | break; |
| 389 | case DataSlot::Output: |
| 390 | outputEncoder[flatIdx]; |
| 391 | outputEncoder.Set(value); |
| 392 | break; |
| 393 | default: |
| 394 | break; |
| 395 | } |
| 396 | } |
| 397 | |
| 398 | void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx) |
| 399 | { |
| 400 | for(unsigned int dim = 0; dim < idx.size(); dim++) |
| 401 | { |
| 402 | switch(type) |
| 403 | { |
| 404 | case DataSlot::InputX: |
| 405 | { |
| 406 | auto xRank = inputXInfo.GetNumDimensions(); |
| 407 | auto xDiff = outputInfo.GetNumDimensions() - xRank; |
| 408 | if (dim < xDiff || |
| 409 | idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1) |
| 410 | { |
| 411 | idx[dim] = 0; // Broadcasting |
| 412 | } |
| 413 | break; |
| 414 | } |
| 415 | case DataSlot::InputY: |
| 416 | { |
| 417 | auto yRank = inputYInfo.GetNumDimensions(); |
| 418 | auto yDiff = outputInfo.GetNumDimensions() - yRank; |
| 419 | if (dim < yDiff || |
| 420 | idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1) |
| 421 | { |
| 422 | idx[dim] = 0; |
| 423 | } |
| 424 | break; |
| 425 | } |
| 426 | case DataSlot::Output: |
| 427 | { |
| 428 | // Our indices are based off the output |
| 429 | break; |
| 430 | } |
| 431 | default: |
| 432 | break; |
| 433 | } |
| 434 | } |
| 435 | } |
| 436 | |
| 437 | unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx) |
| 438 | { |
| 439 | unsigned int result = idx[idx.size()-1]; |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 440 | unsigned int dimMultiplier = 1; |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 441 | unsigned int offset; |
| 442 | |
| 443 | // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x) |
| 444 | for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--) |
| 445 | { |
| 446 | switch(type) |
| 447 | { |
| 448 | case DataSlot::InputX: |
| 449 | offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions(); |
| 450 | dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset]; |
| 451 | break; |
| 452 | case DataSlot::InputY: |
| 453 | offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions(); |
| 454 | dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset]; |
| 455 | break; |
| 456 | case DataSlot::Output: |
| 457 | dimMultiplier *= outputInfo.GetShape()[i+1]; |
| 458 | break; |
| 459 | default: |
| 460 | break; |
| 461 | } |
| 462 | result += (idx[i] * dimMultiplier); |
| 463 | } |
| 464 | return result; |
| 465 | } |
| 466 | |
Samuel Yap | 6b47809 | 2022-07-06 15:36:03 +0100 | [diff] [blame] | 467 | } // namespace armnn |