blob: fa36f49003c14e97e95aeb87d41be1a9c570d06c [file] [log] [blame]
Keith Davis3ae3f972021-05-21 16:33:48 +01001//
Mike Kelly7cbe7812023-07-25 17:37:33 +01002// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
Keith Davis3ae3f972021-05-21 16:33:48 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Finn Williams73c547d2022-02-15 20:47:34 +00008#include "RefBaseWorkload.hpp"
Colm Donelan0c479742021-12-10 12:43:54 +00009#include <armnn/backends/WorkloadData.hpp>
Keith Davis3ae3f972021-05-21 16:33:48 +010010
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
Finn Williams73c547d2022-02-15 20:47:34 +000016struct RefShapeWorkload : public RefBaseWorkload<ShapeQueueDescriptor>
Keith Davis3ae3f972021-05-21 16:33:48 +010017{
18public:
Finn Williams73c547d2022-02-15 20:47:34 +000019 using RefBaseWorkload<ShapeQueueDescriptor>::RefBaseWorkload;
Keith Davis3ae3f972021-05-21 16:33:48 +010020 virtual void Execute() const override
21 {
22 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
23 }
Matthew Sloyan2d213a72022-06-30 17:13:04 +010024 void ExecuteAsync(ExecutionData& executionData) override
Keith Davis3ae3f972021-05-21 16:33:48 +010025 {
Matthew Sloyan2d213a72022-06-30 17:13:04 +010026 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
27 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Keith Davis3ae3f972021-05-21 16:33:48 +010028 }
29
30private:
31 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
32 {
Mike Kelly7cbe7812023-07-25 17:37:33 +010033 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefShapeWorkload_Execute");
34
Keith Davis3ae3f972021-05-21 16:33:48 +010035 const TensorShape Shape = GetTensorInfo(inputs[0]).GetShape();
36
37 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
38
39 unsigned int numBytes =
40 GetTensorInfo(inputs[0]).GetNumDimensions() * GetDataTypeSize(outputInfo.GetDataType());
41
42 std::memcpy(outputs[0]->Map(), &Shape, numBytes);
43 outputs[0]->Unmap();
44 }
45};
46
47} //namespace armnn
48
49
50
51