blob: 56997ad240f2bf1e118909697346e211f8549cb5 [file] [log] [blame]
Matthew Benthamf61c2702019-04-23 16:43:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
9
10#include "../ConversionUtils.hpp"
11
12namespace armnn_driver
13{
14
FinnWilliamsArm7b8d2e62020-01-08 14:57:47 +000015inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape,
16 const armnn::TensorShape& weightsShape)
Matthew Benthamf61c2702019-04-23 16:43:27 +010017{
18 if (inputShape.GetNumDimensions() > 2U)
19 {
Matthew Benthame9aa4692019-04-24 14:36:20 +010020 unsigned int totalInputElements = inputShape.GetNumElements();
21 unsigned int inputSize = weightsShape[1];
Matthew Benthamf61c2702019-04-23 16:43:27 +010022
Matthew Benthame9aa4692019-04-24 14:36:20 +010023 unsigned int batchSize = totalInputElements / inputSize;
Matthew Benthamf61c2702019-04-23 16:43:27 +010024
Matthew Benthame9aa4692019-04-24 14:36:20 +010025 if(totalInputElements % batchSize != 0)
Matthew Benthamf61c2702019-04-23 16:43:27 +010026 {
27 throw std::runtime_error("Failed to deduce tensor shape");
28 }
29
Matthew Benthame9aa4692019-04-24 14:36:20 +010030 return armnn::TensorShape({batchSize, inputSize});
Matthew Benthamf61c2702019-04-23 16:43:27 +010031 }
32 else
33 {
34 return inputShape;
35 }
36}
37
FinnWilliamsArm7b8d2e62020-01-08 14:57:47 +000038inline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape,
39 const armnn::TensorShape& weightsShape,
40 const armnn::TensorShape& outputShape,
41 bool transposeWeightMatrix)
42{
43 unsigned int dimIdx = transposeWeightMatrix ? 0 : 1;
44 return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]);
45}
46
Matthew Benthamf61c2702019-04-23 16:43:27 +010047}