IVGCVSW-6862 Modify GatherNd Neon workload

* Validate all layers used by GatherNd (Mul, ReduceSum, Gather, Reshape)
* Fix convert policy for Mul: always use ConvertPolicy::WRAP

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I0f2bae5107607ba3c02b5546f60dd9623cd95853
diff --git a/src/backends/neon/workloads/NeonGatherNdWorkload.cpp b/src/backends/neon/workloads/NeonGatherNdWorkload.cpp
index 00c66cf..2d56ef2 100644
--- a/src/backends/neon/workloads/NeonGatherNdWorkload.cpp
+++ b/src/backends/neon/workloads/NeonGatherNdWorkload.cpp
@@ -11,33 +11,87 @@
 
 namespace armnn
 {
-arm_compute::Status NeonGatherNdWorkloadValidate(const TensorInfo& paramInfo,
+arm_compute::Status NeonGatherNdWorkloadValidate(const TensorInfo& paramsInfo,
                                                  const TensorInfo& indicesInfo,
                                                  const TensorInfo& outputInfo)
 {
     // Calculate ND, K, W, C.
-    std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(paramInfo, indicesInfo);
+    std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(paramsInfo, indicesInfo);
 
-    /// Call Gather with adequate shapes
-    // Reshape params into { K, C }
-    armnn::TensorInfo params_K_C_Info =  paramInfo;
+    /// Validate Mul
+    // Indices with shape { W, ND }
+    armnn::TensorInfo indices_W_ND_Info = indicesInfo;
+    indices_W_ND_Info.SetShape({ keyIndices["W"], keyIndices["ND"] });
+    const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
+
+    // Flattened coefficients with shape { ND }
+    armnn::TensorInfo flattenedCoeff_Info = indicesInfo;
+    flattenedCoeff_Info.SetShape({ keyIndices["ND"] });
+    const arm_compute::TensorInfo aclFlattenedCoeffInfo = BuildArmComputeTensorInfo(flattenedCoeff_Info);
+
+    // Output of Mul with shape { W, ND }
+    const arm_compute::TensorInfo aclOutputMulInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
+
+    auto statusMul = arm_compute::NEPixelWiseMultiplication::validate(&aclIndicesInfo,
+                                                                      &aclFlattenedCoeffInfo,
+                                                                      &aclOutputMulInfo,
+                                                                      1.0f,
+                                                                      arm_compute::ConvertPolicy::WRAP,
+                                                                      arm_compute::RoundingPolicy::TO_ZERO,
+                                                                      arm_compute::ActivationLayerInfo());
+
+    /// Validate ReduceSum
+    // Flattened indices with shape { W }
+    armnn::TensorInfo flattenedIndices_Info = indicesInfo;
+    flattenedIndices_Info.SetShape({ keyIndices["W"] });
+    const arm_compute::TensorInfo aclFlattenedIndicesInfo = BuildArmComputeTensorInfo(flattenedIndices_Info);
+
+    const std::vector<unsigned int> armnnReduceAxes = { 1 }; // Sum over ND, i.e. axis 1 of { W, ND }
+    arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(aclOutputMulInfo.num_dimensions(),
+                                                                          indices_W_ND_Info.GetNumDimensions(),
+                                                                          armnnReduceAxes);
+
+    auto statusReduceSum = arm_compute::NEReductionOperation::validate(&aclOutputMulInfo,
+                                                                       &aclFlattenedIndicesInfo,
+                                                                       static_cast<unsigned int>(coords[0]),
+                                                                       arm_compute::ReductionOperation::SUM,
+                                                                       false);
+
+    /// Validate Gather
+    // Params with shape { K, C }
+    armnn::TensorInfo params_K_C_Info = paramsInfo;
     params_K_C_Info.SetShape({ keyIndices["K"], keyIndices["C"] });
+    const arm_compute::TensorInfo aclParamsInfo = BuildArmComputeTensorInfo(params_K_C_Info);
 
-    // Reshape indices into { W }
-    armnn::TensorInfo indices_W_Info = indicesInfo;
-    indices_W_Info.SetShape({ keyIndices["W"] });
-
-    // Reshape output to have the shape given by gather { W, C }
-    // (the original outputInfo has the shape given by gatherNd)
+    // Output of gather with shape { W, C }
     armnn::TensorInfo outputGather_Info = outputInfo;
     outputGather_Info.SetShape({ keyIndices["W"], keyIndices["C"] });
-
-    const arm_compute::TensorInfo aclParamsInfo  = BuildArmComputeTensorInfo(params_K_C_Info);
-    const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indices_W_Info);
-    const arm_compute::TensorInfo aclOutputInfo  = BuildArmComputeTensorInfo(outputGather_Info);
+    const arm_compute::TensorInfo aclOutputGatherInfo = BuildArmComputeTensorInfo(outputGather_Info);
 
     auto aclAxis = ComputeAclAxis(0, params_K_C_Info);
-    return arm_compute::NEGather::validate(&aclParamsInfo, &aclIndicesInfo, &aclOutputInfo, aclAxis);
+    auto statusGather =
+            arm_compute::NEGather::validate(&aclParamsInfo, &aclFlattenedIndicesInfo, &aclOutputGatherInfo, aclAxis);
+
+    /// Validate Reshape
+    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
+
+    auto statusReshape = arm_compute::NEReshapeLayer::validate(&aclOutputGatherInfo, &aclOutputInfo);
+
+    /// Return OK if all the layers are valid.
+    /// Otherwise return the first failing status (checked in Mul, ReduceSum, Gather, Reshape order) as-is,
+    /// so the caller receives that layer's ACL error code and description
+    /// rather than a generic failure message.
+    const auto okCode = arm_compute::ErrorCode::OK;
+    for (const arm_compute::Status& status : { statusMul, statusReduceSum, statusGather, statusReshape })
+    {
+        if (status.error_code() != okCode)
+        {
+            return status;
+        }
+    }
+
+    return arm_compute::Status(okCode,
+                               "All GatherND layers validate status OK.");
 }
 
 NeonGatherNdWorkload::NeonGatherNdWorkload(const GatherNdQueueDescriptor& descriptor,
@@ -94,16 +148,11 @@
     armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_outputMul);
 
     // Multiply
-    auto convertPolicy = (IsQuantizedType(info.m_InputTensorInfos[0].GetDataType()) ||
-                          IsQuantizedType(info.m_InputTensorInfos[1].GetDataType())) ?
-                          arm_compute::ConvertPolicy::SATURATE :
-                          arm_compute::ConvertPolicy::WRAP;
-
     m_MulLayer.configure(&indices,
                          &m_FlattenedCoeff,
                          &m_outputMul,
                          1.0f,
-                         convertPolicy,
+                         arm_compute::ConvertPolicy::WRAP,
                          arm_compute::RoundingPolicy::TO_ZERO,
                          arm_compute::ActivationLayerInfo());