blob: cbfac4b36e9575dcd4b1d71bd10d52e0e0f1f770 [file] [log] [blame]
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +01001// Copyright (c) 2023, ARM Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "generate_dot_product.h"
16
17namespace
18{
19//---------------------------------------------------------------------------//
20// MatMul //
21//---------------------------------------------------------------------------//
22
23void generateMatMulA(const TosaReference::GenerateConfig& cfg,
24 TosaReference::IDotProductGenerator& generator,
25 void* data,
26 size_t size)
27{
28 float* a = reinterpret_cast<float*>(data);
29 const uint32_t T = cfg.shape[0] * cfg.shape[1] * cfg.shape[2];
30 const uint32_t C = cfg.shape[2];
31
32 for (uint32_t t = 0; t < T; ++t)
33 {
34 a[t] = generator(t % C); // k = c
35 }
36}
37
38void generateMatMulB(const TosaReference::GenerateConfig& cfg,
39 TosaReference::IDotProductGenerator& generator,
40 void* data,
41 size_t size)
42{
43 float* b = reinterpret_cast<float*>(data);
44 const uint32_t T = cfg.shape[0] * cfg.shape[1] * cfg.shape[2];
45 const uint32_t C = cfg.shape[1];
46 const uint32_t W = cfg.shape[2];
47
48 for (uint32_t t = 0; t < T; ++t)
49 {
50 b[t] = generator((t / W) % C); // k = c
51 }
52}
53
54bool generateMatMul(const TosaReference::GenerateConfig& cfg,
55 TosaReference::IDotProductGenerator& generator,
56 void* data,
57 size_t size)
58{
Jeremy Johnsond41feb72023-10-12 16:03:15 +010059 if (cfg.dataType != DType::DType_FP32)
60 {
61 WARNING("[Generator][DP][MatMul] Only supports FP32.");
62 return false;
63 }
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010064 if (cfg.shape.size() != 3)
65 {
66 WARNING("[Generator][DP][MatMul] Tensor shape expected 3 dimensions.");
67 return false;
68 }
69 if (cfg.inputPos > 1 || cfg.inputPos < 0)
70 {
71 WARNING("[Generator][DP][MatMul] Invalid input tensor slot position to operator.");
72 return false;
73 }
74
75 (cfg.inputPos == 0) ? generateMatMulA(cfg, generator, data, size) : generateMatMulB(cfg, generator, data, size);
76
77 return true;
78}
79} // namespace
80
81namespace TosaReference
82{
83
84bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
85{
86 auto generator = pickDotProductGenerator(cfg);
87 if (!generator)
88 {
89 WARNING("[Generator][DP] Requested generator could not be created!");
90 return 0;
91 }
92
93 // Select which generator to use
94 switch (cfg.opType)
95 {
96 case tosa::Op_MATMUL:
97 return generateMatMul(cfg, *generator, data, size);
98 default:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010099 WARNING("[Generator][DP] Unsupported operator.");
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100100 return false;
101 }
102
103 return false;
104}
105} // namespace TosaReference