Enable lazy tensor allocation

- Previously, ref_model allocated the memory for all tensors in the
graph upfront, which is unnecessary and wasteful.
- This patch allocates only the initial input tensors at the entry
point, using the allocateInputTensors() function.
- A node's output tensors are guaranteed to be allocated before the
node is executed. These output tensors are the inputs for the next node.
- When a node's evaluation is finished, its input tensors will be freed
if they will no longer be used by anyone else.

Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Ibb3e8c9e6344f6cd9eb20532a03b2097b93247f9
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 0c86cbd..070eb33 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -109,9 +109,9 @@
             goto done;
         }
 
-        if (main_gt.allocateTensor())
+        if (main_gt.allocateInputTensors())
         {
-            WARNING("Failed to allocate tensor. Evaluation aborted.");
+            WARNING("Failed to allocate input tensors. Evaluation aborted.");
             goto done;
         }
 
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index ce548e9..be97644 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -327,9 +327,9 @@
         return _main_gt->getGraphStatus();
     }
 
-    if (_main_gt->allocateTensor())
+    if (_main_gt->allocateInputTensors())
     {
-        WARNING("Failed to allocate tensor.");
+        WARNING("Failed to allocate input tensors.");
         return _main_gt->getGraphStatus();
     }
 
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 0afb7e2..6bbc587 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -43,7 +43,8 @@
     ERROR_IF(block_sgt.linkTensorsAndNodes(), "evalBlock(): Failed to link tensors and nodes for %s",
              block_name.c_str());
     ERROR_IF(block_sgt.validateGraph(), "evalBlock(): Failed to validate subgraph for %s", block_name.c_str());
-    ERROR_IF(block_sgt.allocateTensor(), "evalBlock(): Failed to allocate tensor for %s", block_name.c_str());
+    ERROR_IF(block_sgt.allocateInputTensors(), "evalBlock(): Failed to allocate input tensors for %s",
+             block_name.c_str());
 
     int num_input_tensors  = block_sgt.getNumInputTensors();
     int num_output_tensors = block_sgt.getNumOutputTensors();
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index a7ef5e9..5675be9 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -397,146 +397,156 @@
     return 0;
 }
 
-int SubgraphTraverser::allocateTensor()
+int SubgraphTraverser::allocateInputTensors()
 {
-    for (auto ts : block->GetTensors())
+    auto input_tensor_names_vec = block->GetInputs();
+
+    for (auto input_tensor_name : input_tensor_names_vec)
     {
-        // Bail out if tensor is used and any of its dimension is invalid.
-        auto got = used_tensor_name_set.find(ts->GetName());
-        if (got != used_tensor_name_set.end())
+        this->allocateTensor(input_tensor_name);
+    }
+
+    return 0;
+}
+
+int SubgraphTraverser::allocateTensor(std::string name)
+{
+    auto ts = block->GetTensorByName(name);
+
+    // Bail out if tensor is used and any of its dimension is invalid.
+    auto got = used_tensor_name_set.find(ts->GetName());
+    if (got != used_tensor_name_set.end())
+    {
+        uint32_t elements = 1;
+        for (auto& dim : ts->GetShape())
         {
-            uint32_t elements = 1;
-            for (auto& dim : ts->GetShape())
+            if (dim <= 0)
             {
-                if (dim <= 0)
-                {
-                    DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(),
-                               dim);
-                    this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
-                    return 1;
-                }
-                if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
-                {
-                    // Size greather than maximum defined in spec
-                    DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
-                    this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
-                    return 1;
-                }
+                DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim);
+                this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+                return 1;
+            }
+            if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
+            {
+                // Size greather than maximum defined in spec
+                DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
+                this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+                return 1;
             }
         }
+    }
 
-        TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
-        SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
+    TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
+    SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
 
-        DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
-        if (tensor->allocate())
+    DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
+    if (tensor->allocate())
+    {
+        FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
+    }
+
+    if (!ts->GetData().empty())
+    {
+        DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
+        auto serialization_dtype = ts->GetDtype();
+        switch (serialization_dtype)
         {
-            FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
-        }
-
-        if (!ts->GetData().empty())
-        {
-            DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
-            auto serialization_dtype = ts->GetDtype();
-            switch (serialization_dtype)
-            {
-                case DType_INT4: {
-                    std::vector<int8_t> i4_data;
-                    TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
-                    std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
-                    tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
-                }
-                break;
-                case DType_INT8: {
-                    std::vector<int8_t> i8_data;
-                    TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
-                    std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
-                    tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
-                }
-                break;
-                case DType_INT16: {
-                    std::vector<int16_t> i16_data;
-                    TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
-                    std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
-                    tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
-                }
-                break;
-                case DType_INT32: {
-                    std::vector<int32_t> i32_data;
-                    TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
-                    tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
-                }
-                break;
-                case DType_INT48: {
-                    std::vector<int64_t> i64_data;
-                    TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
-                    tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
-                }
-                break;
-                case DType_FP16: {
-                    // Interpret f16 data as float
-                    std::vector<float> f16_data;
-                    TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
-                    if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
-                    {
-                        std::vector<double> f64_data(f16_data.begin(), f16_data.end());
-                        tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
-                    }
-                    else
-                    {
-                        tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
-                    }
-                }
-                break;
-                case DType_BF16: {
-                    std::vector<float> fp32_data;
-                    TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
-                    // Ensure valid bfloat16 stored in each float
-                    for (auto f : fp32_data)
-                        ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
-                    if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
-                    {
-                        std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
-                        tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
-                    }
-                    else
-                    {
-                        tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
-                    }
-                }
-                break;
-                case DType_FP32: {
-                    std::vector<float> fp32_data;
-                    TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
-                    if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
-                    {
-                        std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
-                        tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
-                    }
-                    else
-                    {
-                        tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
-                    }
-                }
-                break;
-                case DType_BOOL: {
-                    std::vector<bool> bool_data;
-                    TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
-
-                    // std::vector<bool>::data() will return bit mask instead of array of bool array.
-                    // Need to translate manually.
-                    bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
-                    for (size_t i = 0; i < bool_data.size(); i++)
-                    {
-                        bool_array[i] = bool_data[i];
-                    }
-                    tensor->setTensorValueBool(bool_data.size(), bool_array);
-                }
-                break;
-                default:
-                    SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
-                                      EnumNameDType(ts->GetDtype()));
+            case DType_INT4: {
+                std::vector<int8_t> i4_data;
+                TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
+                std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
+                tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
             }
+            break;
+            case DType_INT8: {
+                std::vector<int8_t> i8_data;
+                TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
+                std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
+                tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+            }
+            break;
+            case DType_INT16: {
+                std::vector<int16_t> i16_data;
+                TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
+                std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
+                tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+            }
+            break;
+            case DType_INT32: {
+                std::vector<int32_t> i32_data;
+                TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
+                tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+            }
+            break;
+            case DType_INT48: {
+                std::vector<int64_t> i64_data;
+                TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
+                tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
+            }
+            break;
+            case DType_FP16: {
+                // Interpret f16 data as float
+                std::vector<float> f16_data;
+                TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
+                if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+                {
+                    std::vector<double> f64_data(f16_data.begin(), f16_data.end());
+                    tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+                }
+                else
+                {
+                    tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
+                }
+            }
+            break;
+            case DType_BF16: {
+                std::vector<float> fp32_data;
+                TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+                // Ensure valid bfloat16 stored in each float
+                for (auto f : fp32_data)
+                    ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
+                if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+                {
+                    std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+                    tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+                }
+                else
+                {
+                    tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+                }
+            }
+            break;
+            case DType_FP32: {
+                std::vector<float> fp32_data;
+                TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+                if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+                {
+                    std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+                    tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+                }
+                else
+                {
+                    tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+                }
+            }
+            break;
+            case DType_BOOL: {
+                std::vector<bool> bool_data;
+                TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
+
+                // std::vector<bool>::data() will return bit mask instead of array of bool array.
+                // Need to translate manually.
+                bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
+                for (size_t i = 0; i < bool_data.size(); i++)
+                {
+                    bool_array[i] = bool_data[i];
+                }
+                tensor->setTensorValueBool(bool_data.size(), bool_array);
+            }
+            break;
+            default:
+                SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
+                                  EnumNameDType(ts->GetDtype()));
         }
     }
 
@@ -593,11 +603,13 @@
     for (auto tensor : currNode->getOutputs())
     {
         if (!tensor->is_allocated())
-            if (tensor->allocate())
+        {
+            if (this->allocateTensor(tensor->getName()))
             {
                 FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s",
                             tensor->getName().c_str());
             }
+        }
     }
 
     if (currNode->eval())
@@ -607,41 +619,40 @@
     }
 
     // free input tensor if all of its consumers have all of their outputs ready and it's not block's output
-    if (!currNode->getInMainBlock())
-    {    // we don't free it if the node is in main block and has nested blocks
-        for (auto tensor : currNode->getInputs())
+    for (auto tensor : currNode->getInputs())
+    {
+        bool in_use = false;
+
+        auto tensor_check = findTensorByName(tensor->getName());
+        if (tensor_check->getIsParentGraphOutput())
         {
-            bool in_use = false;
+            // if it's parent's block output tensor, we can't free it
+            continue;
+        }
 
-            auto tensor_check = findTensorByName(tensor->getName());
-            if (tensor_check->getIsParentGraphOutput())
+        for (auto node : tensor->getConsumers())
+        {
+            // If the node is inside a loop, the input tensor is still needed
+            if (!node->hasAllOutputsReady())
             {
-                // if it's parent's block output tensor, we can't free it
-                continue;
-            }
-
-            for (auto node : tensor->getConsumers())
-            {
-                // If the node is inside a loop, the input tensor is still needed
-                if (!node->hasAllOutputsReady())
-                {
-                    in_use = true;
-                }
-            }
-            for (auto name : block->GetOutputs())
-            {
-                if (name == tensor->getName())
-                {
-                    in_use = true;
-                }
-            }
-
-            if (!in_use)
-            {
-                tensor->deallocate();
+                in_use = true;
             }
         }
+
+        for (auto name : block->GetOutputs())
+        {
+            if (name == tensor->getName())
+            {
+                in_use = true;
+            }
+        }
+
+        if (!in_use)
+        {
+            tensor->deallocate();
+        }
     }
+
     // Search the output tensors of this node to see if
     // there are now new ready nodes available from completing this node
     for (TosaReference::Tensor* tensor : currNode->getOutputs())
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 00989ee..ef6ea42 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -49,7 +49,8 @@
 
     int linkTensorsAndNodes();
     int validateGraph();
-    int allocateTensor();
+    int allocateInputTensors();
+    int allocateTensor(std::string name);
 
     int dumpGraph(FILE* out) const;
     int dumpNextNodeList(FILE* out) const;