blob: 26d61e4c7e971faeb875a1c21ad203583ab7b17b [file] [log] [blame]
Matthew Benthamf61c2702019-04-23 16:43:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
9
10#include "../ConversionUtils.hpp"
11
12namespace armnn_driver
13{
14
15inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape,
16 const armnn::TensorShape &weightsShape)
17{
18 if (inputShape.GetNumDimensions() > 2U)
19 {
Matthew Benthame9aa4692019-04-24 14:36:20 +010020 unsigned int totalInputElements = inputShape.GetNumElements();
21 unsigned int inputSize = weightsShape[1];
Matthew Benthamf61c2702019-04-23 16:43:27 +010022
Matthew Benthame9aa4692019-04-24 14:36:20 +010023 unsigned int batchSize = totalInputElements / inputSize;
Matthew Benthamf61c2702019-04-23 16:43:27 +010024
Matthew Benthame9aa4692019-04-24 14:36:20 +010025 if(totalInputElements % batchSize != 0)
Matthew Benthamf61c2702019-04-23 16:43:27 +010026 {
27 throw std::runtime_error("Failed to deduce tensor shape");
28 }
29
Matthew Benthame9aa4692019-04-24 14:36:20 +010030 return armnn::TensorShape({batchSize, inputSize});
Matthew Benthamf61c2702019-04-23 16:43:27 +010031 }
32 else
33 {
34 return inputShape;
35 }
36}
37
38}