blob: fc95ef439e7b9cf2d9ff15113c3f9dfb1eb17ded [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include "NeonInterceptorScheduler.hpp"
#include <boost/assert.hpp>
namespace armnn{
/// Builds an interceptor that forwards scheduling work to another Arm Compute
/// scheduler while timing each call.
///
/// @param kernels       Destination for the timing Measurements that schedule() and
///                      run_workloads() append.
/// @param realScheduler Scheduler that actually executes kernels/workloads.
///
/// NOTE(review): both arguments are taken by non-const reference; if the members are
/// declared as references in the header, the caller must keep them alive for the
/// lifetime of this object — confirm against the class declaration.
NeonInterceptorScheduler::NeonInterceptorScheduler(NeonTimer::KernelMeasurements& kernels,
                                                   arm_compute::IScheduler& realScheduler)
    : m_Kernels(kernels)
    , m_RealScheduler(realScheduler)
{
    // Nothing else to set up: all state is supplied by the caller.
}
void NeonInterceptorScheduler::set_num_threads(unsigned int numThreads)
{
m_RealScheduler.set_num_threads(numThreads);
}
/// @return The thread count reported by the wrapped scheduler.
unsigned int NeonInterceptorScheduler::num_threads() const
{
    // The interceptor has no thread count of its own; ask the real scheduler.
    const unsigned int threadCount = m_RealScheduler.num_threads();
    return threadCount;
}
/// Runs @p kernel on the wrapped scheduler and records how long the call took.
///
/// On success one Measurement, renamed to kernel->name(), is appended to the
/// kernel-measurement container. If the timer returns no measurements, nothing is
/// recorded: debug builds assert, release builds skip instead of invoking UB.
///
/// @param kernel Kernel to execute. It is dereferenced to obtain its name, so it
///               must not be null.
/// @param hints  Scheduling hints; only hints.split_dimension() is forwarded to
///               the real scheduler.
void NeonInterceptorScheduler::schedule(arm_compute::ICPPKernel* kernel, const Hints& hints)
{
    m_Timer.Start();
    m_RealScheduler.schedule(kernel, hints.split_dimension());
    m_Timer.Stop();

    std::vector<Measurement> measurements = m_Timer.GetMeasurements();
    BOOST_ASSERT(!measurements.empty());
    // BOOST_ASSERT is compiled out in release (NDEBUG) builds; calling front() on an
    // empty vector would then be undefined behaviour, so bail out instead.
    if (measurements.empty())
    {
        return;
    }

    // NOTE: 1st measurement is delta. The local vector is discarded afterwards, so
    // move the element out rather than copying it.
    Measurement measurement(std::move(measurements.front()));
    measurement.m_Name = kernel->name();
    m_Kernels.push_back(std::move(measurement));
}
void NeonInterceptorScheduler::run_workloads(std::vector <Workload>& workloads)
{
m_Timer.Start();
m_RealScheduler.run_workloads(workloads);
m_Timer.Stop();
std::vector<Measurement> measurements = m_Timer.GetMeasurements();
BOOST_ASSERT_MSG(measurements.size() == 3, "WallClockTimer does not have correct amount of measurements.");
// WallClockTimer has 3 measurements, duration always being the first.
Measurement measurement(measurements.front());
measurement.m_Name = "Workload";
m_Kernels.push_back(std::move(measurement));
}
} // namespace armnn