IVGCVSW-5742 'NonConstWeights: Update FullyConnected in android-nn-driver'

* Enabled weights and bias to be provided as non-constant inputs to the
  FULLY_CONNECTED operator.

!armnn:5180
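
For reviewers, a minimal sketch of the non-constant path this enables
(network, inputLayer, weightsLayer and biasLayer are placeholder objects,
not part of this change):

    armnn::FullyConnectedDescriptor desc;
    desc.m_TransposeWeightMatrix = true;
    desc.m_BiasEnabled           = true;
    desc.m_ConstantWeights       = false;

    // With no constant tensors attached, weights and bias flow in through
    // the layer's input slots at execution time.
    armnn::IConnectableLayer* fc = network->AddFullyConnectedLayer(
            desc, armnn::EmptyOptional(), armnn::EmptyOptional());

    inputLayer->GetOutputSlot(0).Connect(fc->GetInputSlot(0));   // data
    weightsLayer->GetOutputSlot(0).Connect(fc->GetInputSlot(1)); // weights
    biasLayer->GetOutputSlot(0).Connect(fc->GetInputSlot(2));    // bias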

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Id325a8bf5be5a772191d27ae89485e992f0c48fa
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 3432d9f..e5f99ed 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -3034,26 +3034,76 @@
     const armnn::TensorInfo& inputInfo  = input.GetTensorInfo();
     const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
 
-    ConstTensorPin weightsPin = DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 1);
-    ConstTensorPin biasPin    = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
-
-    if (!weightsPin.IsValid())
+    LayerInputHandle weightsInput = LayerInputHandle();
+    const HalOperand* weightsOperand = GetInputOperand<HalPolicy>(operation, 1, model);
+    if (!weightsOperand)
     {
-        return Fail("%s: Operation has invalid weights", __func__);
+        return Fail("%s: Could not read weights", __func__);
     }
 
-    if (!biasPin.IsValid())
+    const armnn::TensorInfo& weightsInfo = GetTensorInfoForOperand(*weightsOperand);
+    bool constantWeights = IsOperandConstant<HalPolicy>(*weightsOperand);
+
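+    // If the weights arrive as a runtime input, wire them up through a
+    // LayerInputHandle; otherwise dequantize them into a ConstTensor as before.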
+    armnn::Optional<armnn::ConstTensor> optionalWeights = armnn::EmptyOptional();
+    if (!constantWeights)
     {
-        return Fail("%s: Operation has invalid bias", __func__);
+        weightsInput = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);
+        if (!weightsInput.IsValid())
+        {
+            return Fail("%s: Operation has invalid weights input", __func__);
+        }
+    }
+    else
+    {
+        ConstTensorPin weightsPin = DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 1);
+        if (!weightsPin.IsValid())
+        {
+            return Fail("%s: Operation has invalid weights", __func__);
+        }
+        optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsPin.GetConstTensor());
     }
 
-    armnn::ConstTensor weights = weightsPin.GetConstTensor();
-    armnn::ConstTensor bias    = biasPin.GetConstTensor();
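+    // The bias follows the same pattern: either a runtime input or a constant tensor.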
+    LayerInputHandle biasInput = LayerInputHandle();
+    const HalOperand* biasOperand = GetInputOperand<HalPolicy>(operation, 2, model);
+    if (!biasOperand)
+    {
+        return Fail("%s: Could not read bias", __func__);
+    }
+    armnn::TensorInfo biasInfo = GetTensorInfoForOperand(*biasOperand);
+    bool constantBias = IsOperandConstant<HalPolicy>(*biasOperand);
+
+    armnn::Optional<armnn::ConstTensor> optionalBias = armnn::EmptyOptional();
+    if (!constantBias)
+    {
+        biasInput = ConvertToLayerInputHandle<HalPolicy>(operation, 2, model, data);
+        if (!biasInput.IsValid())
+        {
+            return Fail("%s: Operation has invalid bias input", __func__);
+        }
+    }
+    else
+    {
+        ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
+        if (!biasPin.IsValid())
+        {
+            return Fail("%s: Operation has invalid bias", __func__);
+        }
+        optionalBias = armnn::Optional<armnn::ConstTensor>(biasPin.GetConstTensor());
+    }
+
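+    // A mix of constant and non-constant weights and bias is not supported.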
+    if (constantWeights != constantBias)
+    {
+        return Fail("%s: Weights and bias must both be constant or both be non-constant", __func__);
+    }
+
     armnn::TensorInfo reshapedInfo = inputInfo;
-
     try
     {
-        reshapedInfo.SetShape(FlattenFullyConnectedInput(inputInfo.GetShape(), weights.GetInfo().GetShape()));
+        reshapedInfo.SetShape(FlattenFullyConnectedInput(inputInfo.GetShape(), weightsInfo.GetShape()));
     }
     catch (const std::exception& e)
     {
@@ -3061,7 +3111,7 @@
     }
 
     // ensuring that the bias value is within 1% of the weights input (small float differences can exist)
-    SanitizeBiasQuantizationScale(bias.GetInfo(), weights.GetInfo(), reshapedInfo);
+    SanitizeBiasQuantizationScale(biasInfo, weightsInfo, reshapedInfo);
 
     ActivationFn activationFunction;
     if (!GetInputActivationFunction<HalPolicy>(operation, 3, activationFunction, model, data))
@@ -3072,12 +3122,14 @@
     armnn::FullyConnectedDescriptor desc;
     desc.m_TransposeWeightMatrix = true;
     desc.m_BiasEnabled           = true;
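+    // Record in the descriptor whether the weights are constant.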
+    desc.m_ConstantWeights       = constantWeights;
 
     bool isSupported = false;
     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
     {
         if (!VerifyFullyConnectedShapes(reshapedInfo.GetShape(),
-                                        weights.GetInfo().GetShape(),
+                                        weightsInfo.GetShape(),
                                         outputInfo.GetShape(),
                                         desc.m_TransposeWeightMatrix))
         {
@@ -3087,14 +3139,14 @@
         }
 
         FORWARD_LAYER_SUPPORT_FUNC(__func__,
-                               IsFullyConnectedSupported,
-                               data.m_Backends,
-                               isSupported,
-                               reshapedInfo,
-                               outputInfo,
-                               weights.GetInfo(),
-                               bias.GetInfo(),
-                               desc);
+                                   IsFullyConnectedSupported,
+                                   data.m_Backends,
+                                   isSupported,
+                                   reshapedInfo,
+                                   outputInfo,
+                                   weightsInfo,
+                                   biasInfo,
+                                   desc);
     };
 
     if(!IsDynamicTensor(outputInfo))
@@ -3112,7 +3164,10 @@
     }
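+    // Weights and bias are passed as Optionals; any non-constant tensors are connected as inputs below.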
 
     armnn::IConnectableLayer* startLayer =
-            data.m_Network->AddFullyConnectedLayer(desc, weights, armnn::Optional<armnn::ConstTensor>(bias));
+            data.m_Network->AddFullyConnectedLayer(desc,
+                                                   optionalWeights,
+                                                   optionalBias);
 
     if (inputInfo.GetNumDimensions() > 2U)
     {
@@ -3130,6 +3185,13 @@
         input.Connect(startLayer->GetInputSlot(0));
     }
 
+    // Connect the non-constant weights and bias inputs
+    if (!desc.m_ConstantWeights)
+    {
+        weightsInput.Connect(startLayer->GetInputSlot(1));
+        biasInput.Connect(startLayer->GetInputSlot(2));
+    }
+
     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *startLayer, model,
                                                    data, nullptr, validateFunc, activationFunction);
 }
diff --git a/test/FullyConnected.cpp b/test/FullyConnected.cpp
index 8550c8d..a68a587 100644
--- a/test/FullyConnected.cpp
+++ b/test/FullyConnected.cpp
@@ -264,4 +264,120 @@
     BOOST_TEST(outdata[7] == 8);
 }
 
+BOOST_AUTO_TEST_CASE(TestFullyConnectedWeightsAsInput)
+{
+    auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
+
+    V1_0::ErrorStatus error;
+    std::vector<bool> sup;
+
+    ArmnnDriver::getSupportedOperations_cb cb = [&](V1_0::ErrorStatus status, const std::vector<bool>& supported)
+    {
+        error = status;
+        sup = supported;
+    };
+
+    HalPolicy::Model model = {};
+
+    // operands
+    int32_t actValue      = 0;
+    float   weightValue[] = {1, 0, 0, 0, 0, 0, 0, 0,
+                             0, 1, 0, 0, 0, 0, 0, 0,
+                             0, 0, 1, 0, 0, 0, 0, 0,
+                             0, 0, 0, 1, 0, 0, 0, 0,
+                             0, 0, 0, 0, 1, 0, 0, 0,
+                             0, 0, 0, 0, 0, 1, 0, 0,
+                             0, 0, 0, 0, 0, 0, 1, 0,
+                             0, 0, 0, 0, 0, 0, 0, 1}; // identity matrix
+    float   biasValue[]   = {0, 0, 0, 0, 0, 0, 0, 0};
+
+    // fully connected operation
+    AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1, 1, 8});
+    AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{8, 8});
+    AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{8});
+    AddIntOperand<HalPolicy>(model, actValue);
+    AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 8});
+
+    model.operations.resize(1);
+
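+    // Operands 1 and 2 (weights and bias) are model inputs rather than constants,
+    // so their data is supplied at execution time via the request pools.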
+    model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
+    model.operations[0].inputs  = hidl_vec<uint32_t>{0, 1, 2, 3};
+    model.operations[0].outputs = hidl_vec<uint32_t>{4};
+
+    // make the prepared model
+    android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
+
+    // construct the request for input
+    V1_0::DataLocation inloc = {};
+    inloc.poolIndex          = 0;
+    inloc.offset             = 0;
+    inloc.length             = 8 * sizeof(float);
+    RequestArgument input    = {};
+    input.location           = inloc;
+    input.dimensions         = hidl_vec<uint32_t>{1, 1, 1, 8};
+
+    // construct the request for weights as input
+    V1_0::DataLocation wloc = {};
+    wloc.poolIndex          = 1;
+    wloc.offset             = 0;
+    wloc.length             = 64 * sizeof(float);
+    RequestArgument weights = {};
+    weights.location        = wloc;
+    weights.dimensions      = hidl_vec<uint32_t>{8, 8};
+
+    // construct the request for bias as input
+    V1_0::DataLocation bloc = {};
+    bloc.poolIndex          = 2;
+    bloc.offset             = 0;
+    bloc.length             = 8 * sizeof(float);
+    RequestArgument bias    = {};
+    bias.location           = bloc;
+    bias.dimensions         = hidl_vec<uint32_t>{8};
+
+    V1_0::DataLocation outloc = {};
+    outloc.poolIndex          = 3;
+    outloc.offset             = 0;
+    outloc.length             = 8 * sizeof(float);
+    RequestArgument output    = {};
+    output.location           = outloc;
+    output.dimensions         = hidl_vec<uint32_t>{1, 8};
+
+    V1_0::Request request = {};
+    request.inputs  = hidl_vec<RequestArgument>{input, weights, bias};
+    request.outputs = hidl_vec<RequestArgument>{output};
+
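+    // Pools are added in call order, so these calls must match the poolIndex
+    // values above: input = 0, weights = 1, bias = 2, output = 3.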
+    // set the input data
+    float indata[] = {1, 2, 3, 4, 5, 6, 7, 8};
+    AddPoolAndSetData(8, request, indata);
+
+    // set the weights data
+    AddPoolAndSetData(64, request, weightValue);
+    // set the bias data
+    AddPoolAndSetData(8, request, biasValue);
+
+    // add memory for the output
+    android::sp<IMemory> outMemory = AddPoolAndGetData<float>(8, request);
+    float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
+
+    // run the execution
+    if (preparedModel != nullptr)
+    {
+        Execute(preparedModel, request);
+    }
+
+    // check the result
+    BOOST_TEST(outdata[0] == 1);
+    BOOST_TEST(outdata[1] == 2);
+    BOOST_TEST(outdata[2] == 3);
+    BOOST_TEST(outdata[3] == 4);
+    BOOST_TEST(outdata[4] == 5);
+    BOOST_TEST(outdata[5] == 6);
+    BOOST_TEST(outdata[6] == 7);
+    BOOST_TEST(outdata[7] == 8);
+}
+
 BOOST_AUTO_TEST_SUITE_END()