blob: 741cd79a057b368832df7645a3f62fc03bf4a863 [file] [log] [blame]
// Copyright (c) 2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "generate.h"

#include <limits>

#include "generate_dot_product.h"
#include "generate_pseudo_random.h"
#include "generate_utils.h"

#include "func_debug.h"
#include "model_common.h"

24namespace TosaReference
25{
26
27bool generate(const GenerateConfig& cfg, void* data, size_t size)
28{
29 switch (cfg.generatorType)
30 {
31 case GeneratorType::DotProduct: {
32 return generateDotProduct(cfg, data, size);
33 break;
34 }
Jeremy Johnsond41feb72023-10-12 16:03:15 +010035 case GeneratorType::PseudoRandom: {
36 return generatePseudoRandom(cfg, data, size);
37 break;
38 }
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010039 default: {
40 WARNING("[Generator] Unsupported generation mode.");
41 break;
42 }
43 }
44 return false;
45}
46
47} // namespace TosaReference
48
49extern "C"
50{
51 bool tgd_generate_data(const char* config_json, const char* tensor_name, void* data, size_t size)
52 {
53 // Check inputs for nullptr
54 if (!config_json || !tensor_name || !data)
55 {
56 WARNING("[Generator] One of the inputs is missing.");
57 return false;
58 }
59
60 // Check JSON config validity
61 auto cfg = TosaReference::parseGenerateConfig(config_json, tensor_name);
62 if (!cfg)
63 {
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010064 return false;
65 }
66
67 // Check size
68 const size_t totalBytesNeeded =
69 TosaReference::numElementsFromShape(cfg->shape) * TosaReference::elementSizeFromType(cfg->dataType);
70 if (totalBytesNeeded > size)
71 {
72 WARNING("[Generator] Not enough space in provided buffer.");
73 return false;
74 }
75
76 // Run generator
77 return generate(cfg.value(), data, size);
78 }
79} // extern "C"