blob: 2b3a31488854875e526dbb02457396e5e4911092 [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "armnn/backends/Workload.hpp"
9
Matthew Sloyan5c54c382022-11-09 16:28:51 +000010#include <graph_status.h>
11#include <model_runner.h>
12
Francis Murtagh9270d9e2022-08-12 13:54:17 +010013#include <memory>
14#include <string>
15#include <vector>
16
17namespace armnn
18{
19
20bool TosaRefPreCompiledWorkloadValidate(std::string* reasonIfUnsupported);
21
22class TosaRefPreCompiledWorkload : public BaseWorkload<PreCompiledQueueDescriptor>
23{
24public:
25 TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
26 const WorkloadInfo& info);
27 void Execute() const override;
28
29private:
30 bool SupportsTensorHandleReplacement() const override
31 {
32 return true;
33 }
34
35 void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
36 {
37 this->m_Data.m_Inputs[slot] = tensorHandle;
38 }
39
40 void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
41 {
42 this->m_Data.m_Outputs[slot] = tensorHandle;
43 }
Matthew Sloyan5c54c382022-11-09 16:28:51 +000044
45 template <typename T>
46 void SetInput(TosaReference::IModelRunner& runner, std::string inputName, uint32_t inputIndex) const;
47
48 template <typename T>
49 void GetOutput(TosaReference::IModelRunner& runner, std::string outputName, uint32_t outputIndex) const;
50
51 WorkloadInfo m_workloadInfo;
Francis Murtagh9270d9e2022-08-12 13:54:17 +010052};
53
54} //namespace armnn