blob: 25b6c85d7759157ee86ad39e275dc26dcb84cc5f [file] [log] [blame]
Samuel Yap6b478092022-07-06 15:36:03 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "Encoders.hpp"
#include "Decoders.hpp"

#include <armnn/backends/WorkloadData.hpp>

#include <string>
#include <utility>
#include <vector>
12
13namespace armnn
14{
15
16class BatchMatMul {
17public:
18 enum DataSlot
19 {
20 InputX = 0,
21 InputY = 1,
22 Output = 2
23 };
24
25 BatchMatMul(const BatchMatMulDescriptor& params,
26 const TensorInfo& inputXInfo,
27 const TensorInfo& inputYInfo,
28 const TensorInfo& outputInfo,
29 Decoder<float>& inputXDecoder,
30 Decoder<float>& inputYDecoder,
31 Encoder<float>& outputEncoder)
32 : params(params),
33 inputXInfo(inputXInfo),
34 inputYInfo(inputYInfo),
35 outputInfo(outputInfo),
36 inputXDecoder(inputXDecoder),
37 inputYDecoder(inputYDecoder),
38 outputEncoder(outputEncoder)
39 {}
40
41 void BatchMatMulImpl();
42
43 void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
44
45 // Adjusts it for when input tensors are of unequal rank
46 void AdjustAxesToMulForUnequalRanks(
47 std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
48
49 float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
50
51 void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
52
53 // Takes into account broadcasting
54 void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
55
56 unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
57
58 template <typename T>
59 std::string StringifyVec(const std::vector<T>& vec);
60
61private:
62 const BatchMatMulDescriptor& params;
63 const TensorInfo& inputXInfo;
64 const TensorInfo& inputYInfo;
65 const TensorInfo& outputInfo;
66 Decoder<float>& inputXDecoder;
67 Decoder<float>& inputYDecoder;
68 Encoder<float>& outputEncoder;
69
70 std::vector<float> inputXData;
71 std::vector<float> inputYData;
72
73};
74
75} // namespace armnn