blob: 19971a4af329aaeb98422904c06b9d17f4d44190 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Encoders.hpp"
#include "Decoders.hpp"

#include <armnn/backends/WorkloadData.hpp>

#include <functional>
#include <utility>
#include <vector>

13namespace armnn
14{
15
16class BatchMatMul {
17public:
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010018 BatchMatMul(const BatchMatMulDescriptor& params,
19 const TensorInfo& inputXInfo,
20 const TensorInfo& inputYInfo,
21 const TensorInfo& outputInfo,
22 Decoder<float>& inputXDecoder,
23 Decoder<float>& inputYDecoder,
24 Encoder<float>& outputEncoder);
25
26private:
Samuel Yap6b478092022-07-06 15:36:03 +010027 enum DataSlot
28 {
29 InputX = 0,
30 InputY = 1,
31 Output = 2
32 };
33
Samuel Yap6b478092022-07-06 15:36:03 +010034 const BatchMatMulDescriptor& params;
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010035 TensorInfo inputXInfo;
36 TensorInfo inputYInfo;
37 TensorInfo outputInfo;
Samuel Yap6b478092022-07-06 15:36:03 +010038 Decoder<float>& inputXDecoder;
39 Decoder<float>& inputYDecoder;
40 Encoder<float>& outputEncoder;
41
42 std::vector<float> inputXData;
43 std::vector<float> inputYData;
44
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010045 void ApplyBatchMatMul();
46
47 void ApplyParams();
48
49 void Transpose(DataSlot type);
50
51 void Adjoint(DataSlot type);
52
53 void RecurseTensor(const TensorInfo& tensorInfo,
54 std::function<void(const std::vector<unsigned int>&)> const& operation,
55 std::vector<unsigned int>& curIdx,
56 unsigned int curDim);
57
58 // Adjusts it for when input tensors are of unequal rank
59 void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
60 std::pair<unsigned int, unsigned int>& axesYToMul);
61
62 float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
63
64 void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
65
66 // Takes into account broadcasting
67 void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
68
69 unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
Samuel Yap6b478092022-07-06 15:36:03 +010070};
71
72} // namespace armnn