blob: e92ed0598db2173593f966b9921a59e2fa717d24 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#pragma once

#include <armnn/Tensor.hpp>

#include <functional>
#include <vector>
9
10namespace armnn
11{
12
13struct BroadcastLoop
14{
15 BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
16
17 unsigned int GetNumDimensions()
18 {
19 return static_cast<unsigned int>(m_DimData.size());
20 }
21
22 template <typename T0, typename T1, typename U, typename Func>
23 void Unroll(Func operationFunc,
24 unsigned int dimension,
25 const T0* inData0,
26 const T1* inData1,
27 U* outData)
28 {
29 if (dimension >= GetNumDimensions())
30 {
31 *outData = operationFunc(*inData0, *inData1);
32 return;
33 }
34
35 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
36 {
37 Unroll(operationFunc, dimension + 1, inData0, inData1, outData);
38
39 inData0 += m_DimData[dimension].m_Stride1;
40 inData1 += m_DimData[dimension].m_Stride2;
41 outData += m_DimData[dimension].m_StrideOut;
42 }
43 }
44
45private:
telsoa01c577f2c2018-08-31 09:22:23 +010046 // Struct to hold the dimension data.
telsoa014fcda012018-03-09 14:13:49 +000047 struct BroadcastDimensionData
48 {
49 unsigned int m_DimSize;
50 unsigned int m_StrideOut;
51 unsigned int m_Stride1;
52 unsigned int m_Stride2;
53 };
54
55 std::vector<BroadcastDimensionData> m_DimData;
56};
57
58} //namespace armnn