Kevin May | 1bea6be | 2023-12-12 11:18:46 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2023 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "TosaTestUtils.hpp" |
| 7 | |
| 8 | using namespace armnn; |
| 9 | using namespace tosa; |
| 10 | |
| 11 | void VerifySplit(TosaSerializationBasicBlock* splitBlock, |
| 12 | std::vector<std::vector<int32_t>> inputShape, |
| 13 | std::vector<std::vector<int32_t>> outputShape, |
| 14 | const BaseDescriptor& splitDescriptor, |
| 15 | DType dataType = DType_FP32) |
| 16 | { |
| 17 | uint32_t numInputs = static_cast<uint32_t>(inputShape.size()); |
| 18 | uint32_t numOutputs = static_cast<uint32_t>(outputShape.size()); |
| 19 | |
| 20 | std::string blockStr = "Op_SPLIT_block_"; |
| 21 | CHECK(splitBlock->GetName().find(blockStr) != std::string::npos); |
| 22 | CHECK(splitBlock->GetInputs().size() == numInputs); |
| 23 | CHECK(splitBlock->GetOutputs().size() == numOutputs); |
| 24 | CHECK(splitBlock->GetOperators().size() == 3); |
| 25 | CHECK(splitBlock->GetTensors().size() == 4); |
| 26 | |
| 27 | // |
| 28 | // Verify slice operator |
| 29 | // |
| 30 | |
| 31 | for (uint32_t i = 0; i < splitBlock->GetOperators().size(); i++) |
| 32 | { |
| 33 | TosaSerializationOperator *sliceOp = splitBlock->GetOperators().at(i); |
| 34 | uint32_t sliceOpOutputs = 1; |
| 35 | CHECK(sliceOp->GetInputTensorNames().size() == numInputs); |
| 36 | CHECK(sliceOp->GetOutputTensorNames().size() == sliceOpOutputs); |
| 37 | |
| 38 | std::basic_string<char> blockInputName = splitBlock->GetInputs()[0]; |
| 39 | std::basic_string<char> operatorInputName = sliceOp->GetInputTensorNames()[0]; |
| 40 | |
| 41 | std::string opInputStr = "input" + std::to_string(0) + "_"; |
| 42 | |
| 43 | CHECK(blockInputName == operatorInputName); |
| 44 | CHECK(splitBlock->GetTensorByName(blockInputName)); |
| 45 | CHECK(blockInputName.find(opInputStr) != std::string::npos); |
| 46 | |
| 47 | TosaSerializationTensor* inputTensor = splitBlock->GetTensorByName(operatorInputName); |
| 48 | CHECK(inputTensor->GetDtype() == dataType); |
| 49 | CHECK(inputTensor->GetData().size() == 0); |
| 50 | CHECK(inputTensor->GetShape() == inputShape[0]); |
| 51 | |
| 52 | std::basic_string<char> blockOutputName = splitBlock->GetOutputs()[i]; |
| 53 | std::basic_string<char> operatorOutputName = sliceOp->GetOutputTensorNames()[0]; |
| 54 | |
| 55 | std::string opOutputStr = "output" + std::to_string(i) + "_"; |
| 56 | |
| 57 | CHECK(blockOutputName == operatorOutputName); |
| 58 | CHECK(splitBlock->GetTensorByName(blockOutputName)); |
| 59 | CHECK(blockOutputName.find(opOutputStr) != std::string::npos); |
| 60 | |
| 61 | TosaSerializationTensor* outputTensor = splitBlock->GetTensorByName(operatorOutputName); |
| 62 | CHECK(outputTensor->GetDtype() == dataType); |
| 63 | CHECK(outputTensor->GetData().size() == 0); |
| 64 | CHECK(outputTensor->GetShape() == outputShape[0]); |
| 65 | |
| 66 | CHECK(sliceOp->GetAttributeType() == Attribute_SliceAttribute); |
| 67 | CHECK(sliceOp->GetOp() == Op_SLICE); |
| 68 | |
| 69 | VerifyTosaAttribute(splitDescriptor, |
| 70 | sliceOp->GetAttribute(), |
| 71 | inputShape[0], |
| 72 | outputShape[0], |
| 73 | LayerType::Splitter, |
| 74 | i); |
| 75 | } |
| 76 | |
| 77 | } |