blob: 5bf6be8939c27dd422a0bfb242a1e61c2d055952 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
#pragma once

#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>

#include <functional>
#include <vector>
10
11namespace armnn
12{
13
14struct BroadcastLoop
15{
16 BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
17
18 unsigned int GetNumDimensions()
19 {
20 return static_cast<unsigned int>(m_DimData.size());
21 }
22
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010023 template <typename Func, typename DecoderOp, typename EncoderOp>
telsoa014fcda012018-03-09 14:13:49 +000024 void Unroll(Func operationFunc,
25 unsigned int dimension,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010026 DecoderOp& inData0,
27 DecoderOp& inData1,
28 EncoderOp& outData)
telsoa014fcda012018-03-09 14:13:49 +000029 {
30 if (dimension >= GetNumDimensions())
31 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010032 outData.Set(operationFunc(inData0.Get(), inData1.Get()));
telsoa014fcda012018-03-09 14:13:49 +000033 return;
34 }
35
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010036 unsigned int inData0Movement = 0;
37 unsigned int inData1Movement = 0;
38 unsigned int outDataMovement = 0;
39
telsoa014fcda012018-03-09 14:13:49 +000040 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
41 {
42 Unroll(operationFunc, dimension + 1, inData0, inData1, outData);
43
44 inData0 += m_DimData[dimension].m_Stride1;
45 inData1 += m_DimData[dimension].m_Stride2;
46 outData += m_DimData[dimension].m_StrideOut;
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010047
48 inData0Movement += m_DimData[dimension].m_Stride1;
49 inData1Movement += m_DimData[dimension].m_Stride2;
50 outDataMovement += m_DimData[dimension].m_StrideOut;
telsoa014fcda012018-03-09 14:13:49 +000051 }
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010052
53 // move iterator back to the start
54 inData0 -= inData0Movement;
55 inData1 -= inData1Movement;
56 outData -= outDataMovement;
telsoa014fcda012018-03-09 14:13:49 +000057 }
58
59private:
telsoa01c577f2c2018-08-31 09:22:23 +010060 // Struct to hold the dimension data.
telsoa014fcda012018-03-09 14:13:49 +000061 struct BroadcastDimensionData
62 {
63 unsigned int m_DimSize;
64 unsigned int m_StrideOut;
65 unsigned int m_Stride1;
66 unsigned int m_Stride2;
67 };
68
69 std::vector<BroadcastDimensionData> m_DimData;
70};
71
72} //namespace armnn