blob: 287c71ebc7de2c0d925d31a7f201d6646298078a [file] [log] [blame]
Matteo Martincighf02e6cd2019-05-17 12:15:30 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "CommonTestUtils.hpp"
7
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00008#include <armnn/backends/IBackendInternal.hpp>
Matteo Martincighf02e6cd2019-05-17 12:15:30 +01009
10using namespace armnn;
11
Matteo Martincighf02e6cd2019-05-17 12:15:30 +010012SubgraphView::InputSlots CreateInputsFrom(const std::vector<Layer*>& layers)
13{
14 SubgraphView::InputSlots result;
15 for (auto&& layer : layers)
16 {
17 for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
18 {
19 result.push_back(&(*it));
20 }
21 }
22 return result;
23}
24
25SubgraphView::OutputSlots CreateOutputsFrom(const std::vector<Layer*>& layers)
26{
27 SubgraphView::OutputSlots result;
28 for (auto && layer : layers)
29 {
30 for (auto&& it = layer->BeginOutputSlots(); it != layer->EndOutputSlots(); ++it)
31 {
32 result.push_back(&(*it));
33 }
34 }
35 return result;
36}
37
38SubgraphView::SubgraphViewPtr CreateSubgraphViewFrom(SubgraphView::InputSlots&& inputs,
39 SubgraphView::OutputSlots&& outputs,
40 SubgraphView::Layers&& layers)
41{
42 return std::make_unique<SubgraphView>(std::move(inputs), std::move(outputs), std::move(layers));
43}
44
45armnn::IBackendInternalUniquePtr CreateBackendObject(const armnn::BackendId& backendId)
46{
47 auto& backendRegistry = BackendRegistryInstance();
48 auto backendFactory = backendRegistry.GetFactory(backendId);
49 auto backendObjPtr = backendFactory();
50
51 return backendObjPtr;
52}
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010053
54armnn::TensorShape MakeTensorShape(unsigned int batches,
55 unsigned int channels,
56 unsigned int height,
57 unsigned int width,
58 armnn::DataLayout layout)
59{
60 using namespace armnn;
61 switch (layout)
62 {
63 case DataLayout::NCHW:
64 return TensorShape{ batches, channels, height, width };
65 case DataLayout::NHWC:
66 return TensorShape{ batches, height, width, channels };
67 default:
68 throw InvalidArgumentException(std::string("Unsupported data layout: ") + GetDataLayoutName(layout));
69 }
70}