// blob: 1a35903934024e2ac6653bcca08b308b63dde101
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "TosaTestUtils.hpp"
using namespace armnn;
using namespace tosa;
void VerifyQuantize(TosaSerializationBasicBlock* quantizeBlock,
std::vector<int32_t> shape,
DType inputDataType = DType_FP32,
DType outputDataType = DType_FP32)
{
std::string blockStr = "Op_QUANTIZE_block_";
CHECK(quantizeBlock->GetName().find(blockStr) != std::string::npos);
CHECK(quantizeBlock->GetInputs().size() == 1);
CHECK(quantizeBlock->GetOutputs().size() == 1);
CHECK(quantizeBlock->GetOperators().size() == 5); // MUL, CONST, ADD, CONST, CAST
CHECK(quantizeBlock->GetTensors().size() == 6);
std::basic_string<char> blockInputName = quantizeBlock->GetInputs()[0];
std::basic_string<char> blockOutputName = quantizeBlock->GetOutputs()[0];
//
// Verify Constants
//
TosaSerializationOperator* constZeroPointOp = quantizeBlock->GetOperators().at(0);
CHECK(constZeroPointOp->GetAttributeType() == Attribute_NONE);
CHECK(constZeroPointOp->GetOp() == tosa::Op_CONST);
TosaSerializationOperator* constScaleOp = quantizeBlock->GetOperators().at(1);
CHECK(constScaleOp->GetAttributeType() == Attribute_NONE);
CHECK(constScaleOp->GetOp() == tosa::Op_CONST);
//
// Verify Multiplication
//
ElementwiseBinaryDescriptor mulDescriptor(BinaryOperation::Mul);
TosaSerializationOperator* mulOp = quantizeBlock->GetOperators().at(2);
CHECK(mulOp->GetAttributeType() == tosa::Attribute_MulAttribute);
CHECK(mulOp->GetOp() == tosa::Op_MUL);
CHECK(mulOp->GetInputTensorNames().size() == 2);
std::basic_string<char> mulInputName0 = mulOp->GetInputTensorNames()[0];
std::basic_string<char> mulInputName1 = mulOp->GetInputTensorNames()[1];
CHECK(blockInputName == mulInputName0);
TosaSerializationTensor* mulInputTensor0 = quantizeBlock->GetTensorByName(mulInputName0);
CHECK(mulInputTensor0->GetDtype() == inputDataType);
CHECK(mulInputTensor0->GetData().size() == 0);
CHECK(mulInputTensor0->GetShape() == shape);
TosaSerializationTensor* mulInputTensor1 = quantizeBlock->GetTensorByName(mulInputName1);
CHECK(mulInputTensor1->GetShape() == shape);
//
// Verify Addition
//
ElementwiseBinaryDescriptor addDescriptor(BinaryOperation::Add);
TosaSerializationOperator* addOp = quantizeBlock->GetOperators().at(3);
CHECK(addOp->GetAttributeType() == Attribute_NONE);
CHECK(addOp->GetOp() == tosa::Op_ADD);
CHECK(addOp->GetInputTensorNames().size() == 2);
std::basic_string<char> addInputName0 = addOp->GetInputTensorNames()[0];
std::basic_string<char> addInputName1 = addOp->GetInputTensorNames()[1];
TosaSerializationTensor* addInputTensor0 = quantizeBlock->GetTensorByName(addInputName0);
CHECK(addInputTensor0->GetDtype() == inputDataType);
CHECK(addInputTensor0->GetData().size() == 0);
CHECK(addInputTensor0->GetShape() == shape);
TosaSerializationTensor* addInputTensor1 = quantizeBlock->GetTensorByName(addInputName1);
CHECK(addInputTensor1->GetShape() == shape);
//
// Verify Cast
//
TosaSerializationOperator* castOp = quantizeBlock->GetOperators().at(4);
CHECK(castOp->GetAttributeType() == Attribute_NONE);
CHECK(castOp->GetOp() == tosa::Op_CAST);
CHECK(castOp->GetInputTensorNames().size() == 1);
CHECK(castOp->GetOutputTensorNames().size() == 1);
std::basic_string<char> castInputName = castOp->GetInputTensorNames()[0];
std::basic_string<char> castOutputName = castOp->GetOutputTensorNames()[0];
TosaSerializationTensor* castInputTensor = quantizeBlock->GetTensorByName(castInputName);
CHECK(castInputTensor->GetDtype() == inputDataType);
CHECK(castInputTensor->GetData().size() == 0);
CHECK(castInputTensor->GetShape() == shape);
TosaSerializationTensor* castOutputTensor = quantizeBlock->GetTensorByName(castOutputName);
CHECK(castOutputTensor->GetDtype() == outputDataType);
CHECK(castOutputTensor->GetData().size() == 0);
CHECK(castOutputTensor->GetShape() == shape);
CHECK(blockOutputName == castOutputName);
}