//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>

#include <functional>
#include <vector>

11namespace armnn
12{
13
14struct BroadcastLoop
15{
16 BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
17
josh minor4a3c6102020-01-06 16:40:46 -060018 BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape);
19
telsoa014fcda012018-03-09 14:13:49 +000020 unsigned int GetNumDimensions()
21 {
22 return static_cast<unsigned int>(m_DimData.size());
23 }
24
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010025 template <typename Func, typename DecoderOp, typename EncoderOp>
telsoa014fcda012018-03-09 14:13:49 +000026 void Unroll(Func operationFunc,
27 unsigned int dimension,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010028 DecoderOp& inData0,
29 DecoderOp& inData1,
30 EncoderOp& outData)
telsoa014fcda012018-03-09 14:13:49 +000031 {
32 if (dimension >= GetNumDimensions())
33 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010034 outData.Set(operationFunc(inData0.Get(), inData1.Get()));
telsoa014fcda012018-03-09 14:13:49 +000035 return;
36 }
37
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010038 unsigned int inData0Movement = 0;
39 unsigned int inData1Movement = 0;
40 unsigned int outDataMovement = 0;
41
telsoa014fcda012018-03-09 14:13:49 +000042 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
43 {
44 Unroll(operationFunc, dimension + 1, inData0, inData1, outData);
45
46 inData0 += m_DimData[dimension].m_Stride1;
47 inData1 += m_DimData[dimension].m_Stride2;
48 outData += m_DimData[dimension].m_StrideOut;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010049
50 inData0Movement += m_DimData[dimension].m_Stride1;
51 inData1Movement += m_DimData[dimension].m_Stride2;
52 outDataMovement += m_DimData[dimension].m_StrideOut;
telsoa014fcda012018-03-09 14:13:49 +000053 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010054
55 // move iterator back to the start
56 inData0 -= inData0Movement;
57 inData1 -= inData1Movement;
58 outData -= outDataMovement;
telsoa014fcda012018-03-09 14:13:49 +000059 }
60
josh minor4a3c6102020-01-06 16:40:46 -060061 template <typename Func, typename DecoderOp, typename EncoderOp>
62 void Unroll(Func operationFunc,
63 unsigned int dimension,
64 DecoderOp& inData,
65 EncoderOp& outData)
66 {
67 if (dimension >= GetNumDimensions())
68 {
69 outData.Set(operationFunc(inData.Get()));
70 return;
71 }
72
73 unsigned int inDataMovement = 0;
74 unsigned int outDataMovement = 0;
75
76 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
77 {
78 Unroll(operationFunc, dimension + 1, inData, outData);
79
80 inData += m_DimData[dimension].m_Stride1;
81 outData += m_DimData[dimension].m_StrideOut;
82
83 inDataMovement += m_DimData[dimension].m_Stride1;
84 outDataMovement += m_DimData[dimension].m_StrideOut;
85 }
86
87 // move iterator back to the start
88 inData -= inDataMovement;
89 outData -= outDataMovement;
90 }
91
telsoa014fcda012018-03-09 14:13:49 +000092private:
telsoa01c577f2c2018-08-31 09:22:23 +010093 // Struct to hold the dimension data.
telsoa014fcda012018-03-09 14:13:49 +000094 struct BroadcastDimensionData
95 {
96 unsigned int m_DimSize;
97 unsigned int m_StrideOut;
98 unsigned int m_Stride1;
99 unsigned int m_Stride2;
100 };
101
102 std::vector<BroadcastDimensionData> m_DimData;
103};
104
105} //namespace armnn