blob: b23c416c4f248d9bf10fe04c3da1578760b825de [file] [log] [blame]
Mike Kelly831faed2018-11-28 11:52:08 +00001//
Teresa Charlin2ea403d2023-06-19 12:06:19 +01002// Copyright © 2017-2018,2020-2023 Arm Ltd and Contributors. All rights reserved.
Mike Kelly831faed2018-11-28 11:52:08 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "ClBaseWorkload.hpp"
#include "ClWorkloadUtils.hpp"

#include <arm_compute/runtime/CL/functions/CLBatchToSpaceLayer.h>
#include <arm_compute/runtime/CL/functions/CLReshapeLayer.h>

#include <memory>
Mike Kelly831faed2018-11-28 11:52:08 +000013
14namespace armnn
15{
16
17arm_compute::Status ClBatchToSpaceNdWorkloadValidate(const TensorInfo& input,
18 const TensorInfo& output,
Keith Davisbcd860a2021-08-05 14:20:33 +010019 const BatchToSpaceNdDescriptor& descriptor);
Mike Kelly831faed2018-11-28 11:52:08 +000020
Teresa Charlin588cbdf2022-01-19 15:55:37 +000021class ClBatchToSpaceNdWorkload : public ClBaseWorkload<BatchToSpaceNdQueueDescriptor>
Mike Kelly831faed2018-11-28 11:52:08 +000022{
23public:
Sadik Armagane9444752020-12-02 11:28:58 +000024 ClBatchToSpaceNdWorkload(const BatchToSpaceNdQueueDescriptor& descriptor,
25 const WorkloadInfo& info,
26 const arm_compute::CLCompileContext& clCompileContext);
Mike Kelly831faed2018-11-28 11:52:08 +000027
Teresa Charlin2ea403d2023-06-19 12:06:19 +010028 virtual void Execute() const override;
Mike Kelly831faed2018-11-28 11:52:08 +000029
30private:
Mike Kelly831faed2018-11-28 11:52:08 +000031 mutable arm_compute::CLBatchToSpaceLayer m_Layer;
Teresa Charlin2ea403d2023-06-19 12:06:19 +010032 mutable std::unique_ptr<arm_compute::CLReshapeLayer> m_LayerReshapeInput;
33 mutable std::unique_ptr<arm_compute::CLReshapeLayer> m_LayerReshapeOutput;
34 arm_compute::CLTensor m_ReshapeInputTensor;
35 arm_compute::CLTensor m_ReshapeOutputTensor;
Mike Kelly831faed2018-11-28 11:52:08 +000036};
37
38} //namespace armnn