blob: 857d29bbf0ba0197f7e3f69f9f82d9ae182d3ddc [file] [log] [blame]
arovir01b0717b52018-09-05 17:03:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "HalPolicy.hpp"
7
8#include "../1.0/HalPolicy.hpp"
9
10namespace armnn_driver
11{
12namespace hal_1_1
13{
14
15bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data)
16{
17 if (compliantWithV1_0(operation))
18 {
19 hal_1_0::HalPolicy::Operation v10Operation = convertToV1_0(operation);
20 hal_1_0::HalPolicy::Model v10Model = convertToV1_0(model);
21
22 return hal_1_0::HalPolicy::ConvertOperation(v10Operation, v10Model, data);
23 }
24 else
25 {
26 switch (operation.type)
27 {
28 case V1_1::OperationType::DIV:
29 return ConvertDiv(operation, model, data);
David Beck38e12942018-09-12 16:02:24 +010030 case V1_1::OperationType::SUB:
31 return ConvertSub(operation, model, data);
arovir01b0717b52018-09-05 17:03:25 +010032 default:
33 return Fail("%s: Operation type %s not supported in ArmnnDriver",
34 __func__, toString(operation.type).c_str());
35 }
36 }
37}
38
39bool HalPolicy::ConvertDiv(const Operation& operation, const Model& model, ConversionData& data)
40{
41 LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data);
42 LayerInputHandle input1 = ConvertToLayerInputHandle(operation, 1, model, data);
43
44 if (!input0.IsValid() || !input1.IsValid())
45 {
46 return Fail("%s: Operation has invalid inputs", __func__);
47 }
48
49 // The FuseActivation parameter is always the input index 2
50 // and it should be optional
51 ActivationFn activationFunction;
52 if (!GetOptionalInputActivation(operation, 2, activationFunction, model, data))
53 {
54 return Fail("%s: Operation has invalid inputs", __func__);
55 }
56
57 const Operand* outputOperand = GetOutputOperand(operation, 0, model);
58 if (!outputOperand)
59 {
60 return false;
61 }
62
63 const armnn::TensorInfo& outInfo = GetTensorInfoForOperand(*outputOperand);
64
65 if (!IsLayerSupported(__func__,
66 armnn::IsDivisionSupported,
67 data.m_Compute,
68 input0.GetTensorInfo(),
69 input1.GetTensorInfo(),
70 outInfo))
71 {
72 return false;
73 }
74
75 armnn::IConnectableLayer* const startLayer = data.m_Network->AddDivisionLayer();
76 armnn::IConnectableLayer* const endLayer = ProcessActivation(outInfo, activationFunction, startLayer, data);
77
78 const armnn::TensorInfo& inputTensorInfo0 = input0.GetTensorInfo();
79 const armnn::TensorInfo& inputTensorInfo1 = input1.GetTensorInfo();
80
81 if (endLayer)
82 {
83 BroadcastTensor(input0, input1, startLayer, *data.m_Network);
84 return SetupAndTrackLayerOutputSlot(operation, 0, *endLayer, model, data);
85 }
86
87 return Fail("%s: ProcessActivation failed", __func__);
88}
89
David Beck38e12942018-09-12 16:02:24 +010090bool HalPolicy::ConvertSub(const Operation& operation, const Model& model, ConversionData& data)
91{
92 LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data);
93 LayerInputHandle input1 = ConvertToLayerInputHandle(operation, 1, model, data);
94
95 if (!input0.IsValid() || !input1.IsValid())
96 {
97 return Fail("%s: Operation has invalid inputs", __func__);
98 }
99
100 // The FuseActivation parameter is always the input index 2
101 // and it should be optional
102 ActivationFn activationFunction;
103 if (!GetOptionalInputActivation(operation, 2, activationFunction, model, data))
104 {
105 return Fail("%s: Operation has invalid inputs", __func__);
106 }
107
108 const Operand* outputOperand = GetOutputOperand(operation, 0, model);
109 if (!outputOperand)
110 {
111 return false;
112 }
113
114 const armnn::TensorInfo& outInfo = GetTensorInfoForOperand(*outputOperand);
115
116 if (!IsLayerSupported(__func__,
117 armnn::IsSubtractionSupported,
118 data.m_Compute,
119 input0.GetTensorInfo(),
120 input1.GetTensorInfo(),
121 outInfo))
122 {
123 return false;
124 }
125
126 armnn::IConnectableLayer* const startLayer = data.m_Network->AddSubtractionLayer();
127 armnn::IConnectableLayer* const endLayer = ProcessActivation(outInfo, activationFunction, startLayer, data);
128
129 const armnn::TensorInfo& inputTensorInfo0 = input0.GetTensorInfo();
130 const armnn::TensorInfo& inputTensorInfo1 = input1.GetTensorInfo();
131
132 if (endLayer)
133 {
134 BroadcastTensor(input0, input1, startLayer, *data.m_Network);
135 return SetupAndTrackLayerOutputSlot(operation, 0, *endLayer, model, data);
136 }
137
138 return Fail("%s: ProcessActivation failed", __func__);
139}
140
arovir01b0717b52018-09-05 17:03:25 +0100141} // namespace hal_1_1
142} // namespace armnn_driver