blob: 80f3bb85e900655f5e034540e1bd67eae15f989e [file] [log] [blame]
/*
* Copyright (c) 2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
#include "arm_compute/core/Log.h"
#include "src/runtime/CL/mlgo/MLGOParser.h"
#include "src/runtime/CL/mlgo/Utils.h"
#include <fstream>
namespace arm_compute
{
namespace mlgo
{
// Equality for native GEMM configurations: both operands must agree on the
// m0, n0 and k0 fields (compared in that order).
bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
{
    return lhs.m0 == rhs.m0
           && lhs.n0 == rhs.n0
           && lhs.k0 == rhs.k0;
}
// Equality for reshaped-only-RHS GEMM configurations: every field listed below
// must match between the two operands.
bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs)
{
    // Build a tuple of references per operand so the member-wise comparison is a single expression.
    const auto lhs_fields = std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image);
    const auto rhs_fields = std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
    return lhs_fields == rhs_fields;
}
// Equality for reshaped GEMM configurations: every field listed below must
// match between the two operands.
bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs)
{
    // Build a tuple of references per operand so the member-wise comparison is a single expression.
    const auto lhs_fields = std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0,
                                     lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image);
    const auto rhs_fields = std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0,
                                     rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
    return lhs_fields == rhs_fields;
}
// Namespace-scope definition of the static constexpr data member. Before C++17 a
// definition like this is required whenever the member is odr-used; from C++17 on
// the in-class declaration is implicitly inline and this line is a redundant
// (deprecated but still valid) redeclaration.
constexpr size_t MLGOHeuristics::_max_num_trees;
// Creates an empty heuristics object: no tree ids, no trees, no cached check
// results, and the "valid" flag cleared until a DotMLGO is loaded successfully.
MLGOHeuristics::MLGOHeuristics()
    : _indices(),
      _trees(),
      _tree_valid(),
      _valid(false)
{
}
std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) const
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str());
const auto invalid = GEMMType::RESHAPED;
if(!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
return { false, invalid };
}
auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type);
GEMMShape shape_query{ query.m, query.n, query.k, query.b };
if(_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
return { false, invalid };
}
return _trees.at(index).query<GEMMType>(shape_query);
}
std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str());
const auto invalid = GEMMConfigNative{};
if(!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
return { false, invalid };
}
auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type);
GEMMShape shape_query{ query.m, query.n, query.k, query.b };
if(_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
return { false, invalid };
}
return _trees.at(index).query<GEMMConfigNative>(shape_query);
}
std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str());
const auto invalid = GEMMConfigReshapedOnlyRHS{};
if(!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
return { false, invalid };
}
auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type);
GEMMShape shape_query{ query.m, query.n, query.k, query.b };
if(_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
return { false, invalid };
}
return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query);
}
std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str());
const auto invalid = GEMMConfigReshaped{};
if(!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
return { false, invalid };
}
auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type);
GEMMShape shape_query{ query.m, query.n, query.k, query.b };
if(_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
return { false, invalid };
}
return _trees.at(index).query<GEMMConfigReshaped>(shape_query);
}
// Runs the tree's own check() on the tree registered under `id` and, on success,
// records the tree as valid in _tree_valid.
//
// Returns false (leaving _tree_valid untouched) if the tree cannot be retrieved
// or if its check() fails; returns true otherwise.
bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
{
    const std::pair<bool, HeuristicTree *> lookup = get_heuristic_tree(id);
    if(!lookup.first)
    {
        return false;
    }

    const bool tree_ok = lookup.second->check();
    if(!tree_ok)
    {
        return false;
    }

    // Cache the positive result so that check_all() can see it.
    _tree_valid[id] = true;
    return true;
}
// Top-level consistency check: succeeds only if every entry in _tree_valid is
// marked true, i.e. every added tree has passed check_heuristic_tree().
//
// NOTE(review): with no entries in _tree_valid the loop body never runs and this
// returns true, even though the log text below mentions the "no trees" case -
// confirm which behaviour is intended.
bool MLGOHeuristics::check_all() const
{
    // Tree validities are already checked and cached; only the cached flags are inspected here.
    for(const auto &entry : _tree_valid)
    {
        if(!entry.second)
        {
            ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
            return false;
        }
    }

    // Other top level checks...
    return true;
}
// Resolves a tree id to the stored tree in two steps: id -> index (via _indices),
// then index -> tree (via _trees).
//
// Returns { true, pointer to the stored tree } on success, or { false, nullptr }
// if either lookup step fails. The pointer refers to the element held in _trees.
std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
{
    const auto index_pos = _indices.find(id);
    if(index_pos == _indices.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
        return std::make_pair(false, nullptr);
    }

    const auto tree_pos = _trees.find(index_pos->second);
    if(tree_pos == _trees.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
        return std::make_pair(false, nullptr);
    }

    HeuristicTree *found = &tree_pos->second;
    return std::make_pair(true, found);
}
// Takes ownership of `t` and registers it under both its id and its index.
//
// The tree is rejected (returning false, nothing stored) when the number of
// registered ids has already reached _max_num_trees, or when either its id or
// its index is already present. On success the tree is stored, its cached
// validity is initialised to false, and true is returned.
bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
{
    const bool at_capacity = _indices.size() >= _max_num_trees;
    if(at_capacity)
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
        return false;
    }

    // PRE: correctness of t is guaranteed by the tree construction process

    // The id must not be registered yet
    const auto tree_id = t.id();
    if(_indices.count(tree_id) != 0)
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", tree_id);
        return false;
    }

    // The index must not be registered yet
    const auto tree_index = t.index();
    if(_trees.count(tree_index) != 0)
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
        return false;
    }

    _indices[tree_id]    = tree_index;
    _trees[tree_index]   = std::move(t);
    _tree_valid[tree_id] = false; // Not validated until check_heuristic_tree() succeeds for this id
    return true;
}
// Opens `filename` for reading and delegates to reload_from_stream().
//
// If the file cannot be opened, the heuristics are marked invalid and false is
// returned; otherwise the result of reload_from_stream() is returned. The
// stream is configured to throw on badbit before it is opened.
bool MLGOHeuristics::reload_from_file(const std::string &filename)
{
    std::ifstream file_stream;
    file_stream.exceptions(std::ifstream::badbit);
    file_stream.open(filename, std::ios::in);

    if(file_stream.is_open())
    {
        return reload_from_stream(file_stream);
    }

    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
    _valid = false;
    return false;
}
// Parses a DotMLGO description from `in` and, on success, replaces the contents
// of this object with the parsed heuristics.
//
// Returns true and marks the heuristics valid when parsing succeeds. When
// parsing fails, only the valid flag is cleared (the other members are left as
// they were) and false is returned.
bool MLGOHeuristics::reload_from_stream(std::istream &in)
{
    auto parse_result = parser::parse_mlgo(in);

    if(parse_result.first)
    {
        *this = std::move(parse_result.second);
        ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully");
        _valid = true;
        return true;
    }

    ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
    _valid = false;
    return false;
}
} // namespace mlgo
} // namespace arm_compute