blob: edef4a1cf9aa0bd3e27b57e9d68d41e7688b4226 [file] [log] [blame]
Kevin May1bea6be2023-12-12 11:18:46 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TosaTestUtils.hpp"
7
8using namespace armnn;
9using namespace tosa;
10
11void VerifySplit(TosaSerializationBasicBlock* splitBlock,
12 std::vector<std::vector<int32_t>> inputShape,
13 std::vector<std::vector<int32_t>> outputShape,
14 const BaseDescriptor& splitDescriptor,
15 DType dataType = DType_FP32)
16{
17 uint32_t numInputs = static_cast<uint32_t>(inputShape.size());
18 uint32_t numOutputs = static_cast<uint32_t>(outputShape.size());
19
20 std::string blockStr = "Op_SPLIT_block_";
21 CHECK(splitBlock->GetName().find(blockStr) != std::string::npos);
22 CHECK(splitBlock->GetInputs().size() == numInputs);
23 CHECK(splitBlock->GetOutputs().size() == numOutputs);
24 CHECK(splitBlock->GetOperators().size() == 3);
25 CHECK(splitBlock->GetTensors().size() == 4);
26
27 //
28 // Verify slice operator
29 //
30
31 for (uint32_t i = 0; i < splitBlock->GetOperators().size(); i++)
32 {
33 TosaSerializationOperator *sliceOp = splitBlock->GetOperators().at(i);
34 uint32_t sliceOpOutputs = 1;
35 CHECK(sliceOp->GetInputTensorNames().size() == numInputs);
36 CHECK(sliceOp->GetOutputTensorNames().size() == sliceOpOutputs);
37
38 std::basic_string<char> blockInputName = splitBlock->GetInputs()[0];
39 std::basic_string<char> operatorInputName = sliceOp->GetInputTensorNames()[0];
40
41 std::string opInputStr = "input" + std::to_string(0) + "_";
42
43 CHECK(blockInputName == operatorInputName);
44 CHECK(splitBlock->GetTensorByName(blockInputName));
45 CHECK(blockInputName.find(opInputStr) != std::string::npos);
46
47 TosaSerializationTensor* inputTensor = splitBlock->GetTensorByName(operatorInputName);
48 CHECK(inputTensor->GetDtype() == dataType);
49 CHECK(inputTensor->GetData().size() == 0);
50 CHECK(inputTensor->GetShape() == inputShape[0]);
51
52 std::basic_string<char> blockOutputName = splitBlock->GetOutputs()[i];
53 std::basic_string<char> operatorOutputName = sliceOp->GetOutputTensorNames()[0];
54
55 std::string opOutputStr = "output" + std::to_string(i) + "_";
56
57 CHECK(blockOutputName == operatorOutputName);
58 CHECK(splitBlock->GetTensorByName(blockOutputName));
59 CHECK(blockOutputName.find(opOutputStr) != std::string::npos);
60
61 TosaSerializationTensor* outputTensor = splitBlock->GetTensorByName(operatorOutputName);
62 CHECK(outputTensor->GetDtype() == dataType);
63 CHECK(outputTensor->GetData().size() == 0);
64 CHECK(outputTensor->GetShape() == outputShape[0]);
65
66 CHECK(sliceOp->GetAttributeType() == Attribute_SliceAttribute);
67 CHECK(sliceOp->GetOp() == Op_SLICE);
68
69 VerifyTosaAttribute(splitDescriptor,
70 sliceOp->GetAttribute(),
71 inputShape[0],
72 outputShape[0],
73 LayerType::Splitter,
74 i);
75 }
76
77}