blob: d7cc5d8f91b812dfd8bfb1bf6d28e925860031eb [file] [log] [blame]
Cathal Corbett9c9d5b92022-08-17 17:30:16 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <Layer.hpp>
7
8#include <tosaCommon/TosaMappings.hpp>
9
10#include <doctest/doctest.h>
11
12using namespace armnn;
13using namespace tosa;
14
15void AssertTosaOneToOneMappingBasicBlock(TosaSerializationBasicBlock* basicBlock,
16 std::vector<int32_t> shape,
17 uint32_t numInputs,
18 uint32_t numOutputs,
19 Op tosaOp,
20 std::string operatorString,
Matthew Sloyanda824cc2022-10-10 12:43:20 +010021 DType dataType = DType_FP32)
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010022{
23 std::string blockStr = operatorString + "_block_";
Matthew Sloyane497ed12022-10-10 15:41:19 +010024 CHECK(basicBlock->GetName().find(blockStr) != std::string::npos);
25 CHECK(basicBlock->GetInputs().size() == numInputs);
26 CHECK(basicBlock->GetOutputs().size() == numOutputs);
27 CHECK(basicBlock->GetOperators().size() == 1);
28 CHECK(basicBlock->GetTensors().size() == (numInputs + numOutputs));
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010029
30 TosaSerializationOperator* op = basicBlock->GetOperators().at(0);
Matthew Sloyane497ed12022-10-10 15:41:19 +010031 CHECK(op->GetInputTensorNames().size() == numInputs);
32 CHECK(op->GetOutputTensorNames().size() == numOutputs);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010033
34 for (uint32_t i = 0; i < numInputs; i++)
35 {
36 std::basic_string<char> blockInputName = basicBlock->GetInputs()[i];
37 std::basic_string<char> operatorInputName = op->GetInputTensorNames()[i];
38 std::basic_string<char> tensorName = basicBlock->GetTensors()[i]->GetName();
39
40 std::string opStr = operatorString + "_input" + std::to_string(i) + "_";
41
Matthew Sloyane497ed12022-10-10 15:41:19 +010042 CHECK(blockInputName == operatorInputName);
43 CHECK(tensorName == operatorInputName);
44 CHECK(blockInputName.find(opStr) != std::string::npos);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010045 }
46
47 for (uint32_t i = 0; i < numOutputs; i++)
48 {
49 std::basic_string<char> blockOutputName = basicBlock->GetOutputs()[i];
50 std::basic_string<char> operatorOutputName = op->GetOutputTensorNames()[i];
51 std::basic_string<char> tensorName = basicBlock->GetTensors()[numInputs + i]->GetName();
52
53 std::string opStr = operatorString + "_output" + std::to_string(i) + "_";
54
Matthew Sloyane497ed12022-10-10 15:41:19 +010055 CHECK(blockOutputName == operatorOutputName);
56 CHECK(tensorName == operatorOutputName);
57 CHECK(blockOutputName.find(opStr) != std::string::npos);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010058 }
59
Matthew Sloyane497ed12022-10-10 15:41:19 +010060 CHECK(op->GetAttributeType() == Attribute_NONE);
61 CHECK(op->GetOp() == tosaOp);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010062
63 TosaSerializationTensor* tensor0 = basicBlock->GetTensors()[0];
Matthew Sloyane497ed12022-10-10 15:41:19 +010064 CHECK(tensor0->GetDtype() == dataType);
65 CHECK(tensor0->GetData().size() == 0);
66 CHECK(tensor0->GetShape() == shape);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010067}
68
69TEST_SUITE("TosaOperatorMappingOneToOneTests")
70{
71TEST_CASE("GetTosaMapping_AdditionLayer")
72{
73 TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32, 0.0f, 0, true);
74 TosaSerializationBasicBlock* basicBlock =
Matthew Sloyan5c54c382022-11-09 16:28:51 +000075 GetTosaMapping(LayerType::Addition, {&info, &info}, {&info}, BaseDescriptor(), false);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +010076 AssertTosaOneToOneMappingBasicBlock(basicBlock, { 1, 2, 4, 2 }, 2, 1, Op::Op_ADD, "Op_ADD");
77}
78
79TEST_CASE("GetTosaMappingFromLayer_AdditionLayer")
80{
81 IRuntime::CreationOptions options;
82 IRuntimePtr runtime(IRuntime::Create(options));
83
84 // Builds up the structure of the network.
85 INetworkPtr net(INetwork::Create());
86
87 IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
88 IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
89 IConnectableLayer* add = net->AddAdditionLayer("add");
90 IConnectableLayer* output = net->AddOutputLayer(0, "output");
91
92 input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
93 input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
94 add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
95
96 TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32, 0.0f, 0, true);
97
98 input0->GetOutputSlot(0).SetTensorInfo(info);
99 input1->GetOutputSlot(0).SetTensorInfo(info);
100 add->GetOutputSlot(0).SetTensorInfo(info);
101
102 TosaSerializationBasicBlock* basicBlock =
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000103 GetTosaMappingFromLayer(PolymorphicDowncast<Layer*>(add), false);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +0100104 AssertTosaOneToOneMappingBasicBlock(basicBlock, { 1, 2, 4, 2 }, 2, 1, Op::Op_ADD, "Op_ADD");
105}
106
107TEST_CASE("GetTosaMapping_Unimplemented")
108{
109 TosaSerializationBasicBlock* basicBlock =
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000110 GetTosaMapping(LayerType::UnidirectionalSequenceLstm, {}, {}, BaseDescriptor(), false);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +0100111
Matthew Sloyane497ed12022-10-10 15:41:19 +0100112 CHECK(basicBlock->GetName() == "");
113 CHECK(basicBlock->GetTensors().size() == 0);
114 CHECK(basicBlock->GetOperators().size() == 1);
115 CHECK(basicBlock->GetInputs().size() == 0);
116 CHECK(basicBlock->GetOutputs().size() == 0);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +0100117
118 TosaSerializationOperator* op = basicBlock->GetOperators()[0];
Matthew Sloyane497ed12022-10-10 15:41:19 +0100119 CHECK(op->GetAttributeType() == Attribute_NONE);
120 CHECK(op->GetOp() == tosa::Op_UNKNOWN);
121 CHECK(op->GetInputTensorNames().size() == 0);
122 CHECK(op->GetOutputTensorNames().size() == 0);
Cathal Corbett9c9d5b92022-08-17 17:30:16 +0100123}
124}