blob: 7fe302a5adce70f29bc99b61bf52a4bd3944caf2 [file] [log] [blame]
Narumol Prangnawaratea54a012020-03-16 16:36:10 +00001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefConvertFp32ToBf16Workload.hpp"
7#include "RefWorkloadUtils.hpp"
8
9#include <armnnUtils/FloatingPointConverter.hpp>
10
11#include <BFloat16.hpp>
12
13namespace armnn
14{
15
16void RefConvertFp32ToBf16Workload::Execute() const
17{
Finn Williamsb8181f72021-04-07 10:23:21 +010018 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
19}
20
21void RefConvertFp32ToBf16Workload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
22{
23 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
24}
25
26void RefConvertFp32ToBf16Workload::Execute(std::vector<ITensorHandle*> inputs,
27 std::vector<ITensorHandle*> outputs) const
28{
Narumol Prangnawaratea54a012020-03-16 16:36:10 +000029 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConvertFp32ToBf16Workload_Execute");
30
Finn Williamsb8181f72021-04-07 10:23:21 +010031 const float* const input = reinterpret_cast<const float*>(inputs[0]->Map());
32 BFloat16* const output = reinterpret_cast<BFloat16*>(outputs[0]->Map());
Narumol Prangnawaratea54a012020-03-16 16:36:10 +000033
Finn Williamsb8181f72021-04-07 10:23:21 +010034 unsigned int numElements = GetTensorInfo(inputs[0]).GetNumElements();
Narumol Prangnawaratea54a012020-03-16 16:36:10 +000035 armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(input, numElements, output);
36}
37
38} //namespace armnn