blob: b7ed761e0c0ef3562cffb280fc13226ac65316fe [file] [log] [blame]
Keith Davis3ae3f972021-05-21 16:33:48 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "RefBaseWorkload.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnn/Exceptions.hpp>
#include <armnn/backends/WorkloadData.hpp>

#include <cstdint>
#include <vector>
12
13namespace armnn
14{
15
Finn Williams73c547d2022-02-15 20:47:34 +000016struct RefShapeWorkload : public RefBaseWorkload<ShapeQueueDescriptor>
Keith Davis3ae3f972021-05-21 16:33:48 +010017{
18public:
Finn Williams73c547d2022-02-15 20:47:34 +000019 using RefBaseWorkload<ShapeQueueDescriptor>::RefBaseWorkload;
Keith Davis3ae3f972021-05-21 16:33:48 +010020 virtual void Execute() const override
21 {
22 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
23 }
24 void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
25 {
26 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
27 }
28
29private:
30 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
31 {
32 const TensorShape Shape = GetTensorInfo(inputs[0]).GetShape();
33
34 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
35
36 unsigned int numBytes =
37 GetTensorInfo(inputs[0]).GetNumDimensions() * GetDataTypeSize(outputInfo.GetDataType());
38
39 std::memcpy(outputs[0]->Map(), &Shape, numBytes);
40 outputs[0]->Unmap();
41 }
42};
43
44} //namespace armnn
45
46
47
48