//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ArithmeticBaseLayer.hpp"

#include "InternalTypes.hpp"
#include "armnn/Exceptions.hpp"
#include <armnn/TypesUtils.hpp>

#include <boost/assert.hpp>

#include <algorithm>
#include <string>
#include <vector>
12
13namespace armnn
14{
15
16ArithmeticBaseLayer::ArithmeticBaseLayer(unsigned int numInputSlots, unsigned int numOutputSlots,
17 LayerType type, const char* name)
18 : Layer(numInputSlots, numOutputSlots, type, name)
19{
20}
21
22std::vector<TensorShape> ArithmeticBaseLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
23{
24 BOOST_ASSERT(inputShapes.size() == 2);
25 auto& input0 = inputShapes[0];
26 auto& input1 = inputShapes[1];
27
28 // Get the max of the inputs.
29 BOOST_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
30 unsigned int numDims = input0.GetNumDimensions();
31 std::vector<unsigned int> dims(numDims);
32
33 for (unsigned int i = 0; i < numDims; i++)
34 {
35 unsigned int dim0 = input0[i];
36 unsigned int dim1 = input1[i];
37
38#if !NDEBUG
39 // Validate inputs are broadcast compatible.
40 BOOST_ASSERT_MSG(dim0 == dim1 || dim0 == 1 || dim1 == 1,
41 "Dimensions should either match or one should be of size 1.");
42#endif
43
44 dims[i] = std::max(dim0, dim1);
45 }
46
47 return std::vector<TensorShape>({ TensorShape(numDims, dims.data()) });
48}
49
50void ArithmeticBaseLayer::ValidateTensorShapesFromInputs()
51{
52 VerifyLayerConnections(2, CHECK_LOCATION());
53
54 auto inferredShapes = InferOutputShapes({
55 GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
56 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
57 });
58
59 BOOST_ASSERT(inferredShapes.size() == 1);
60
61 std::string msg = GetLayerTypeAsCString(GetType());
62 msg += "Layer: TensorShape set on OutputSlot[0] does not match the inferred shape.";
63 ConditionalThrowIfNotEqual<LayerValidationException>(msg,
64 GetOutputSlot(0).GetTensorInfo().GetShape(),
65 inferredShapes[0]);
66}
67
68} // namespace armnn