blob: 3a6184d22e9e0dcd23197d652a0fc91678ffe11a [file] [log] [blame]
Idriss Chaouch98e383e2023-08-28 14:28:31 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefBroadcastToWorkload.hpp"
7#include "RefWorkloadUtils.hpp"
8#include "Profiling.hpp"
9#include "Broadcast.hpp"
10
11#include "Decoders.hpp"
12#include "Encoders.hpp"
13
14namespace armnn
15{
16
17RefBroadcastToWorkload::RefBroadcastToWorkload(const BroadcastToQueueDescriptor& descriptor, const WorkloadInfo& info)
18 : RefBaseWorkload(descriptor, info)
19{}
20
21void RefBroadcastToWorkload::Execute() const
22{
23 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
24}
25
26void RefBroadcastToWorkload::ExecuteAsync(ExecutionData& executionData)
27{
28 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
29 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
30}
31
32void RefBroadcastToWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
33{
34 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefBroadcastToWorkload_Execute");
35 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
36 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
37
38 std::unique_ptr<Decoder<float>> input = MakeDecoder<float>(inputInfo, inputs[0]->Map());
39 std::unique_ptr<Encoder<float>> output= MakeEncoder<float>(outputInfo, outputs[0]->Map());
40
41 auto broadcastTo = [](float x)
42 {
43 return x;
44 };
45 BroadcastLoop(inputInfo.GetShape(), outputInfo.GetShape()).Unroll(broadcastTo,
46 0, *input, *output);
47}
48} // namespace armnn