//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "GpuFsaActivation.hpp"

#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h>
#include <arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h>
#include <arm_compute/dynamic_fusion/sketch/gpu/operators/GpuTanh.h>
#include <arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSigmoid.h>
#include <arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h>

using namespace arm_compute::experimental::dynamic_fusion;
using namespace armnn::armcomputetensorutils;

namespace armnn
{

22arm_compute::Status GpuFsaActivationValidate(const TensorInfo& input,
23 const ActivationDescriptor& descriptor)
24{
25 // Create a new workload sketch, for validation purposes
26 auto compileCtx = arm_compute::CLKernelLibrary::get().get_compile_context();
27 auto workloadContext = GpuWorkloadContext(&compileCtx);
28 GpuWorkloadSketch sketch{ &workloadContext };
29
30 arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, input.GetNumDimensions());
31 aclInputInfo.set_are_values_constant(input.IsConstant());
32
33 arm_compute::ITensorInfo* inputInfo = workloadContext.create_tensor_info(aclInputInfo);
34
35 switch (descriptor.m_Function)
36 {
37 case ActivationFunction::TanH:
38 {
39 if ( descriptor.m_A != 1 || descriptor.m_B != 1)
40 {
41 return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
42 "Activation function TanH only works with a=1 and b=1");
43 }
44 return GpuTanh::validate_op(sketch, inputInfo);
45 }
46 case ActivationFunction::Sigmoid:
47 {
48 return GpuSigmoid::validate_op(sketch, inputInfo);
49 }
50 default:
51 return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
52 std::string("Activation function currently not supported in GpuFsa: ")
53 + GetActivationFunctionAsCString(descriptor.m_Function));
54 }
55
56}
57
58void GpuFsaActivationCreateOp(GpuFsaPreCompiledBlob* blob,
59 const TensorInfo& input,
60 const ActivationDescriptor& descriptor)
61{
62 GpuWorkloadSketch* sketch = blob->sketch.get();
63 GpuWorkloadContext* workloadContext = blob->workloadContext.get();
64 std::vector<arm_compute::ITensorInfo*> inputTensorInfos = {};
65 std::vector<arm_compute::ITensorInfo*> outputTensorInfos = {};
66
67 arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input, input.GetNumDimensions());
68
69 aclInput0Info.set_are_values_constant(input.IsConstant());
70
71 inputTensorInfos.emplace_back(workloadContext->create_tensor_info(aclInput0Info));
72
73 // Validate operator, check status and update reasonIfUnsupported
74 arm_compute::Status aclStatus{};
75 switch (descriptor.m_Function)
76 {
77 case ActivationFunction::TanH:
78 {
79 aclStatus = GpuTanh::validate_op(*sketch, inputTensorInfos[0]);
80 break;
81 }
82 case ActivationFunction::Sigmoid:
83 {
84 aclStatus = GpuSigmoid::validate_op(*sketch, inputTensorInfos[0]);
85 break;
86 }
87 default:
88 throw InvalidArgumentException(std::string("Activation function currently not supported in GpuFsa: ")
89 + GetActivationFunctionAsCString(descriptor.m_Function));
90
91 }
92 const bool supported = aclStatus.error_code() == arm_compute::ErrorCode::OK;
93 if (!supported)
94 {
95 throw BackendCapabilityException("\"GpuFsa\" backend failed during Activation layer validation");
96 }
97
98 arm_compute::ITensorInfo* activationOutputInfo{};
99 switch (descriptor.m_Function)
100 {
101 case ActivationFunction::TanH:
102 {
103 activationOutputInfo = GpuTanh::create_op(*sketch, inputTensorInfos[0]);
104 break;
105 }
106 case ActivationFunction::Sigmoid:
107 {
108 activationOutputInfo = GpuSigmoid::create_op(*sketch, inputTensorInfos[0]);
109 break;
110 }
111 default:
112 throw InvalidArgumentException(std::string("Activation function currently not supported in GpuFsa: ")
113 + GetActivationFunctionAsCString(descriptor.m_Function));
114
115 }
116
117 // Temporary fix until fusing attempt is make for GpuFsa backend and Output layer workload is created.
118 outputTensorInfos.emplace_back(workloadContext->create_tensor_info());
119 GpuOutput::create_op(*sketch, activationOutputInfo, outputTensorInfos[0]);
120
121 // Store the TensorInfos within the blob as unique_ptrs to be used later
122 blob->inputTensorInfos = std::make_unique<std::vector<arm_compute::ITensorInfo*>>(inputTensorInfos);
123 blob->outputTensorInfos = std::make_unique<std::vector<arm_compute::ITensorInfo*>>(outputTensorInfos);
124}

} // namespace armnn