blob: 88dc979913865b64567aab42081e5026bb444f79 [file] [log] [blame]
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +01001// Copyright (c) 2023, ARM Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14#include "generate.h"
15
16#include <doctest.h>
17
18#include <array>
19#include <string>
20#include <vector>
21
22TEST_SUITE_BEGIN("generate");
23
24TEST_CASE("negative - api")
25{
26 std::string json_cfg = R"({
27 "tensors" : {
28 "in1" : {
29 "generator": "DOT_PRODUCT",
30 "data_type": "FP32",
31 "input_type": "VARIABLE",
32 "shape" : [ 4, 8, 8 ],
33 "input_pos": 0,
34 "op" : "MATMUL",
35 "dot_product_info": {
36 "s": 0,
37 "ks": 10,
38 "acc_type": "FP32"
39 }
40 }
41 }
42 })";
43
44 const std::string tosaName = "in1";
45 const size_t tosaElements = 4 * 8 * 8;
46 const size_t tosaSize = tosaElements * 4;
47
48 SUBCASE("missing input")
49 {
50 REQUIRE_FALSE(tgd_generate_data(NULL, NULL, NULL, 0));
51 }
52 SUBCASE("invalid json")
53 {
54 std::string invalid_json_cfg = R"({
55 "tensors" : {
56 "in1" : {
57 "generator": DOT_PRODUCT,
58 },
59 }
60 })";
61
62 std::vector<float> buffer(tosaElements);
63 REQUIRE_FALSE(tgd_generate_data(invalid_json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
64 }
65 SUBCASE("invalid json - mismatching name")
66 {
67 std::string invalidName = "notFound1";
68
69 std::vector<float> buffer(tosaElements);
70 REQUIRE_FALSE(tgd_generate_data(json_cfg.c_str(), invalidName.c_str(), (void*)buffer.data(), tosaSize));
71 }
72 SUBCASE("mismatching size")
73 {
74 size_t smallElements = 4 * 8 * 7;
75 size_t smallSize = smallElements * 4;
76
77 std::vector<float> buffer(smallElements);
78 REQUIRE_FALSE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), smallSize));
79 }
80}
81
82TEST_CASE("positive - dot product")
83{
84 std::string json_cfg = R"({
85 "tensors" : {
86 "in1" : {
87 "generator": "DOT_PRODUCT",
88 "data_type": "FP32",
89 "input_type": "VARIABLE",
90 "shape" : [ 4, 8, 8 ],
91 "input_pos": 0,
92 "op" : "MATMUL",
93 "dot_product_info": {
94 "s": 0,
95 "ks": 10,
96 "acc_type": "FP32"
97 }
98 }
99 }
100 })";
101
102 const std::string tosaName = "in1";
103 const size_t tosaElements = 4 * 8 * 8;
104 const size_t tosaSize = tosaElements * 4;
105
106 SUBCASE("matmul")
107 {
108 std::vector<float> buffer(tosaElements);
109 REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
110 REQUIRE(buffer[0] == (float)-0.950864);
111 REQUIRE(buffer[1] == 0.f);
112 }
113}
114
115TEST_SUITE_END(); // generate