blob: 8e2a410b0cd3e9e350728514aae6cbc9471fdcb8 [file] [log] [blame]
Keith Davis3ae3f972021-05-21 16:33:48 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>

#include "RefWorkloadUtils.hpp"

#include <cstdint>
#include <cstring>
#include <vector>
12
13namespace armnn
14{
15
16struct RefShapeWorkload : public BaseWorkload<ShapeQueueDescriptor>
17{
18public:
19 using BaseWorkload<ShapeQueueDescriptor>::BaseWorkload;
20 virtual void Execute() const override
21 {
22 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
23 }
24 void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
25 {
26 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
27 }
28
29private:
30 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
31 {
32 const TensorShape Shape = GetTensorInfo(inputs[0]).GetShape();
33
34 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
35
36 unsigned int numBytes =
37 GetTensorInfo(inputs[0]).GetNumDimensions() * GetDataTypeSize(outputInfo.GetDataType());
38
39 std::memcpy(outputs[0]->Map(), &Shape, numBytes);
40 outputs[0]->Unmap();
41 }
42};
43
44} //namespace armnn
45
46
47
48