blob: 9ed15bebbf6bb54e49f4942fa5015ddb5a036657 [file] [log] [blame]
surmeh013537c2c2018-05-18 16:31:43 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
surmeh013537c2c2018-05-18 16:31:43 +01004//
5#pragma once
6
7#include "LayerWithParameters.hpp"
8
9namespace armnn
10{
11
12class ScopedCpuTensorHandle;
13
14class BatchNormalizationLayer : public LayerWithParameters<BatchNormalizationDescriptor>
15{
16public:
17 std::unique_ptr<ScopedCpuTensorHandle> m_Mean;
18 std::unique_ptr<ScopedCpuTensorHandle> m_Variance;
19 std::unique_ptr<ScopedCpuTensorHandle> m_Beta;
20 std::unique_ptr<ScopedCpuTensorHandle> m_Gamma;
21
22 virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
23 const IWorkloadFactory& factory) const override;
24
25 BatchNormalizationLayer* Clone(Graph& graph) const override;
26
27 void ValidateTensorShapesFromInputs() override;
28
29protected:
30 BatchNormalizationLayer(const BatchNormalizationDescriptor& param, const char* name);
31 ~BatchNormalizationLayer() = default;
telsoa01c577f2c2018-08-31 09:22:23 +010032
33 ConstantTensors GetConstantTensorsByRef() override;
surmeh013537c2c2018-05-18 16:31:43 +010034};
35
36} // namespace