//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "BatchMatMulImpl.hpp"

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>

namespace armnn
{

void BatchMatMul::BatchMatMulImpl()
{
    inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
    inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
    // From this point on we don't touch the input decoders, just the resultant data vectors

    // Pre-transpose and pre-adjoint here if their vectors aren't empty,
    // also handling DataLayouts, which may change with permutations/adjoints

    // TODO: Update input validation and inferred output shapes to accommodate these pre-permutes

    auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
    RecurseBMM(idx, 0);
}

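// Recursively walks every index of the output tensor; at the innermost recursion level it
// computes the dot product of the corresponding row of X and column of Y for that element
// and writes the result to the output.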
void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
{
    // We're working off of the indexes of the output tensor (the max possible shape)

    if(!(curDim < outputInfo.GetNumDimensions()))
    {
        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)

        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
                                                             inputXInfo.GetShape(),
                                                             inputYInfo.GetShape());
        AdjustAxesToMulForUnequalRanks(axesToMul);

        unsigned int inputXColDim = axesToMul.first.second;
        unsigned int inputYRowDim = axesToMul.second.first;

        unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];

        float sum = 0.0f;

        // You could also use inputXColSize
        for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
        {
            auto xIdx = curIdx;
            xIdx[inputXColDim] = inputYRowIdx;

            auto yIdx = curIdx;
            yIdx[inputYRowDim] = inputYRowIdx;

            sum += (GetValueAt(DataSlot::InputX, xIdx)
                  * GetValueAt(DataSlot::InputY, yIdx));
        }

        SetValueAt(sum, DataSlot::Output, curIdx);

        return;
    }

    for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
    {
        curIdx[curDim] = i;
        RecurseBMM(curIdx, curDim+1);
    }
}

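// When the inputs have different ranks, shifts the multiplication axes of the lower-ranked
// input up by the rank difference, so that both axis pairs line up with the output-space
// indices used in RecurseBMM.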
void BatchMatMul::AdjustAxesToMulForUnequalRanks(
    std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
{
    int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
                   static_cast<int>(inputYInfo.GetNumDimensions());
    if(rankDiff == 0)
    {
        return;
    }
    else if(rankDiff < 0)
    {
        // Y is the larger one
        axesToMul.first.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesToMul.first.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
    else if(rankDiff > 0)
    {
        // X is the larger one
        axesToMul.second.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesToMul.second.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
}

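// Reads a single element at the given output-space index: input values come from the decoded
// data vectors, while output values are read back through the encoder itself.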
float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
{
    // For the inputs this reads from the decoded data vectors we already hold, not the decoders;
    // for the output it operates on the encoder itself

    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    float value = 0.0f;

    switch(type)
    {
        case DataSlot::InputX:
            value = inputXData[flatIdx];
            break;
        case DataSlot::InputY:
            value = inputYData[flatIdx];
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx]; // Position the encoder at flatIdx before reading
            value = outputEncoder.Get();
            break;
        default:
            break;
    }

    return value;
}

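// Writes a single element at the given output-space index: input writes go to the decoded
// data vectors, output writes go through the encoder.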
void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
    AdjustToSafeIdx(type, idx);

    unsigned int flatIdx = CalcFlatIdx(type, idx);

    switch(type)
    {
        case DataSlot::InputX:
            inputXData[flatIdx] = value;
            break;
        case DataSlot::InputY:
            inputYData[flatIdx] = value;
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx]; // Position the encoder at flatIdx before writing
            outputEncoder.Set(value);
            break;
        default:
            break;
    }
}

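// Clamps an output-space index to one that is valid for the given tensor, implementing
// broadcasting: dimensions the input doesn't have, or where its extent is exceeded, are reset to 0.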
void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
{
    for(unsigned int dim = 0; dim < idx.size(); dim++)
    {
        switch(type)
        {
            case DataSlot::InputX:
            {
                auto xRank = inputXInfo.GetNumDimensions();
                auto xDiff = outputInfo.GetNumDimensions() - xRank;
                if (dim < xDiff ||
                    idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
                {
                    idx[dim] = 0; // Broadcasting
                }
                break;
            }
            case DataSlot::InputY:
            {
                auto yRank = inputYInfo.GetNumDimensions();
                auto yDiff = outputInfo.GetNumDimensions() - yRank;
                if (dim < yDiff ||
                    idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
                {
                    idx[dim] = 0; // Broadcasting
                }
                break;
            }
            case DataSlot::Output:
            {
                // Our indices are based off the output
                break;
            }
            default:
                break;
        }
    }
}

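// Converts a multidimensional output-space index into a flat index into the given tensor's data,
// using row-major strides; for the inputs, the index is shifted by the rank difference to the
// output so that the trailing dimensions line up.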
unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
    unsigned int result = idx[idx.size()-1];

    unsigned int dimMultiplier = 1;

    unsigned int offset = 0;

    // -2 because the final dim is already accounted for in the initial value of result
    // (the last dim has a stride of 1)
    for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
    {
        switch(type)
        {
            case DataSlot::InputX:
                offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
                if(i + 1 < offset)
                {
                    // This output dim has no counterpart in X (idx[i] is already 0 here),
                    // so don't index X's shape out of range
                    break;
                }
                dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::InputY:
                offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
                if(i + 1 < offset)
                {
                    // This output dim has no counterpart in Y (idx[i] is already 0 here),
                    // so don't index Y's shape out of range
                    break;
                }
                dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::Output:
                dimMultiplier *= outputInfo.GetShape()[i+1];
                break;
            default:
                break;
        }
        result += (idx[i] * dimMultiplier);
    }
    return result;
}

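// Helper that formats a vector as a string, e.g. "{ 1 2 3 }".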
template <typename T>
std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
{
    std::string res = "{ ";
    for(auto x : vec)
    {
        res += std::to_string(x);
        res += " ";
    }
    res += "}";
    return res;
}

} // namespace armnn