blob: 2d5b7f8cfc114d0f800c87c88f11a969ef82bcae [file] [log] [blame]
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +01001// Copyright (c) 2023, ARM Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef GENERATE_UTILS_H_
16#define GENERATE_UTILS_H_
17
#include "dtype.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>
24
25namespace TosaReference
26{
27
28/// \brief Supported generator types
29enum class GeneratorType
30{
31 PseudoRandom,
32 DotProduct,
33 OpFullRange,
34 OpBoundary,
35 OpSpecial,
36};
37
38/// \brief Supported input types
39enum class InputType
40{
41 Variable,
42 Constant,
43};
44
45/// \brief Dot-product generator meta-data
46struct DotProductInfo
47{
48 DotProductInfo() = default;
49
50 int32_t s;
51 int32_t ks;
52 DType accType;
53 int32_t axis;
54 std::array<int32_t, 2> kernel;
55};
56
57/// \brief Generator configuration
58struct GenerateConfig
59{
60 GeneratorType generatorType;
61 DType dataType;
62 InputType inputType;
63 std::vector<int32_t> shape;
64 int32_t inputPos;
65 tosa::Op opType;
66 DotProductInfo dotProductInfo;
67};
68
69/// \brief Parse the generator config when given in JSON form
70std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName);
71
72/// \brief Extract number of total elements
73int64_t numElementsFromShape(const std::vector<int32_t>& shape);
74
75/// \brief Size in bytes of a given type
76size_t elementSizeFromType(DType type);
77
78}; // namespace TosaReference
79
80#endif // GENERATE_UTILS_H_