blob: cbfac4b36e9575dcd4b1d71bd10d52e0e0f1f770 [file] [log] [blame]
// Copyright (c) 2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "generate_dot_product.h"
namespace
{
//---------------------------------------------------------------------------//
// MatMul //
//---------------------------------------------------------------------------//
// Populates the buffer for MatMul input slot 0 (operand "A") with values
// drawn from the dot-product generator.
//
// cfg       - generation config; cfg.shape[0..2] are read (the caller is
//             expected to have verified that three dimensions are present).
// generator - callable queried once per element with the dot-product index k.
// data      - destination buffer, reinterpreted as an array of float.
// size      - not consulted by this function.
//
// The number of elements written is the product of the three shape entries.
// Element i receives generator(i % shape[2]), i.e. k is the position along
// the last dimension (k = c).
void generateMatMulA(const TosaReference::GenerateConfig& cfg,
                     TosaReference::IDotProductGenerator& generator,
                     void* data,
                     size_t size)
{
    const uint32_t depth     = cfg.shape[2];
    const uint32_t elemCount = cfg.shape[0] * cfg.shape[1] * cfg.shape[2];
    float* const out         = reinterpret_cast<float*>(data);

    uint32_t idx = 0;
    while (idx < elemCount)
    {
        // The last dimension varies fastest, so the flat index modulo its
        // extent gives the position along it.
        const uint32_t k = idx % depth;
        out[idx]         = generator(k);
        ++idx;
    }
}
// Populates the buffer for MatMul input slot 1 (operand "B") with values
// drawn from the dot-product generator.
//
// cfg       - generation config; cfg.shape[0..2] are read (the caller is
//             expected to have verified that three dimensions are present).
// generator - callable queried once per element with the dot-product index k.
// data      - destination buffer, reinterpreted as an array of float.
// size      - not consulted by this function.
//
// The number of elements written is the product of the three shape entries.
// Element i receives generator((i / shape[2]) % shape[1]), i.e. k is the
// position along the middle dimension (k = c).
void generateMatMulB(const TosaReference::GenerateConfig& cfg,
                     TosaReference::IDotProductGenerator& generator,
                     void* data,
                     size_t size)
{
    const uint32_t depth     = cfg.shape[1];
    const uint32_t width     = cfg.shape[2];
    const uint32_t elemCount = cfg.shape[0] * cfg.shape[1] * cfg.shape[2];
    float* const out         = reinterpret_cast<float*>(data);

    uint32_t idx = 0;
    while (idx < elemCount)
    {
        // Strip the fastest-varying (last) dimension, then wrap on the
        // middle dimension to recover the position along it.
        const uint32_t k = (idx / width) % depth;
        out[idx]         = generator(k);
        ++idx;
    }
}
// Validates the MatMul generation request and fills `data` for the requested
// input slot.
//
// Checks, in order: the data type is FP32, the shape has exactly three
// dimensions, and the input slot is 0 or 1. Each failed check emits a warning
// and returns false without touching `data`. Otherwise slot 0 is filled by
// generateMatMulA, slot 1 by generateMatMulB, and true is returned.
bool generateMatMul(const TosaReference::GenerateConfig& cfg,
                    TosaReference::IDotProductGenerator& generator,
                    void* data,
                    size_t size)
{
    const bool isFp32 = (cfg.dataType == DType::DType_FP32);
    if (!isFp32)
    {
        WARNING("[Generator][DP][MatMul] Only supports FP32.");
        return false;
    }

    const bool isRank3 = (cfg.shape.size() == 3);
    if (!isRank3)
    {
        WARNING("[Generator][DP][MatMul] Tensor shape expected 3 dimensions.");
        return false;
    }

    // MatMul has exactly two inputs, so only slots 0 and 1 are valid.
    if (!(cfg.inputPos <= 1 && cfg.inputPos >= 0))
    {
        WARNING("[Generator][DP][MatMul] Invalid input tensor slot position to operator.");
        return false;
    }

    if (cfg.inputPos == 0)
    {
        generateMatMulA(cfg, generator, data, size);
    }
    else
    {
        generateMatMulB(cfg, generator, data, size);
    }
    return true;
}
} // namespace
namespace TosaReference
{
// Generates dot-product test data for the operator described by `cfg`.
//
// cfg  - generation config; cfg.opType selects the operator-specific filler.
// data - destination buffer, forwarded unchanged to the filler.
// size - forwarded unchanged to the filler.
//
// Returns true when the data was generated. Returns false (after emitting a
// warning) when no dot-product generator could be created for `cfg`, when the
// operator is not supported, or when the operator-specific filler rejects the
// request.
bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
{
    auto generator = pickDotProductGenerator(cfg);
    if (!generator)
    {
        WARNING("[Generator][DP] Requested generator could not be created!");
        return false;
    }

    // Select which generator to use
    switch (cfg.opType)
    {
        case tosa::Op_MATMUL:
            return generateMatMul(cfg, *generator, data, size);
        default:
            WARNING("[Generator][DP] Unsupported operator.");
            break;
    }

    // Only reached for unsupported operators.
    return false;
}
} // namespace TosaReference