blob: 0fb029dea01b6340baae34d04b4e62cf6a741267 [file] [log] [blame]
Matthew Benthamf61c2702019-04-23 16:43:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
9
10#include "../ConversionUtils.hpp"
11
12namespace armnn_driver
13{
14
15inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape,
16 const armnn::TensorShape &weightsShape)
17{
18 if (inputShape.GetNumDimensions() > 2U)
19 {
20 unsigned int dim0 = inputShape[0];
21 unsigned int dim1 = inputShape[1];
22
23 for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i)
24 {
25 dim1 *= inputShape[i];
26 }
27
28 unsigned int divisor = weightsShape[1] / dim1;
29 if(dim0 % divisor != 0)
30 {
31 throw std::runtime_error("Failed to deduce tensor shape");
32 }
33
34 return armnn::TensorShape({dim0 / divisor, dim1 * divisor});
35 }
36 else
37 {
38 return inputShape;
39 }
40}
41
42}