//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>

#include "RefWorkloadUtils.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>
12
13namespace armnn
14{
15
Finn Williams73c547d2022-02-15 20:47:34 +000016struct RefShapeWorkload : public RefBaseWorkload<ShapeQueueDescriptor>
Keith Davis3ae3f972021-05-21 16:33:48 +010017{
18public:
Finn Williams73c547d2022-02-15 20:47:34 +000019 using RefBaseWorkload<ShapeQueueDescriptor>::RefBaseWorkload;
Keith Davis3ae3f972021-05-21 16:33:48 +010020 virtual void Execute() const override
21 {
22 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
23 }
Matthew Sloyan2d213a72022-06-30 17:13:04 +010024 void ExecuteAsync(ExecutionData& executionData) override
Keith Davis3ae3f972021-05-21 16:33:48 +010025 {
Matthew Sloyan2d213a72022-06-30 17:13:04 +010026 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
27 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Keith Davis3ae3f972021-05-21 16:33:48 +010028 }
29
30private:
31 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
32 {
33 const TensorShape Shape = GetTensorInfo(inputs[0]).GetShape();
34
35 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
36
37 unsigned int numBytes =
38 GetTensorInfo(inputs[0]).GetNumDimensions() * GetDataTypeSize(outputInfo.GetDataType());
39
40 std::memcpy(outputs[0]->Map(), &Shape, numBytes);
41 outputs[0]->Unmap();
42 }
43};
44
45} //namespace armnn
46
47
48
49