
// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "verify_utils.h"

#include <nlohmann/json.hpp>

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
#include <map>
#include <numeric>
#include <string>

25namespace tosa
26{
27
28NLOHMANN_JSON_SERIALIZE_ENUM(DType,
29 {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010030 { DType::DType_UNKNOWN, "UNKNOWN" },
Georgios Pinitas7021ef02023-08-22 08:25:57 +010031 { DType::DType_BOOL, "BOOL" },
32 { DType::DType_INT4, "INT4" },
33 { DType::DType_INT8, "INT8" },
34 { DType::DType_INT16, "INT16" },
35 { DType::DType_INT32, "INT32" },
36 { DType::DType_INT48, "INT48" },
37 { DType::DType_FP16, "FP16" },
38 { DType::DType_BF16, "BF16" },
39 { DType::DType_FP32, "FP32" },
Won Jeon2c34b462024-02-06 18:37:00 +000040 { DType::DType_FP8E4M3, "FP8E4M3" },
41 { DType::DType_FP8E5M2, "FP8E5M2" },
Georgios Pinitas7021ef02023-08-22 08:25:57 +010042 })
43
44} // namespace tosa
45
46namespace TosaReference
47{
48
49NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode,
50 {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010051 { VerifyMode::Unknown, "UNKNOWN" },
Georgios Pinitas7021ef02023-08-22 08:25:57 +010052 { VerifyMode::Exact, "EXACT" },
53 { VerifyMode::Ulp, "ULP" },
54 { VerifyMode::DotProduct, "DOT_PRODUCT" },
Georgios Pinitas7021ef02023-08-22 08:25:57 +010055 { VerifyMode::FpSpecial, "FP_SPECIAL" },
Jeremy Johnson9a758382023-11-07 16:27:35 +000056 { VerifyMode::ReduceProduct, "REDUCE_PRODUCT" },
57 { VerifyMode::AbsError, "ABS_ERROR" },
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +000058 { VerifyMode::Relative, "RELATIVE" },
Georgios Pinitas7021ef02023-08-22 08:25:57 +010059 })
60
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +000061void from_json(const nlohmann::json& j, UlpVerifyInfo& ulpInfo)
Georgios Pinitas7021ef02023-08-22 08:25:57 +010062{
63 j.at("ulp").get_to(ulpInfo.ulp);
64}
65
66void from_json(const nlohmann::json& j, DotProductVerifyInfo& dotProductInfo)
67{
Jeremy Johnsonb2d3bff2024-02-26 16:08:07 +000068 j.at("s").get_to(dotProductInfo.setNumber);
69 j.at("ks").get_to(dotProductInfo.kernelSize);
Georgios Pinitas7021ef02023-08-22 08:25:57 +010070}
71
Jack Frankland12ee1a72023-09-20 09:08:34 +010072void from_json(const nlohmann::json& j, ReduceProductVerifyInfo& reduceProduceInfo)
73{
Jeremy Johnsonb2d3bff2024-02-26 16:08:07 +000074 j.at("n").get_to(reduceProduceInfo.numberOfProducts);
Jack Frankland12ee1a72023-09-20 09:08:34 +010075}
76
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +000077void from_json(const nlohmann::json& j, AbsErrorVerifyInfo& absErrorInfo)
78{
79 if (j.contains("lower_bound"))
80 {
81 j.at("lower_bound").get_to(absErrorInfo.lowerBound);
82 }
Jerry Ge51bd4f52024-02-20 11:21:19 -080083 if (j.contains("normal_divisor"))
84 {
85 j.at("normal_divisor").get_to(absErrorInfo.normalDivisor);
86 }
Jeremy Johnson1eb14552024-04-11 16:21:54 +010087 if (j.contains("bound_as_magnitude"))
88 {
89 j.at("bound_as_magnitude").get_to(absErrorInfo.boundAsMagnitude);
90 }
91 if (j.contains("bound_addition"))
92 {
93 j.at("bound_addition").get_to(absErrorInfo.boundAddition);
94 }
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +000095}
96
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +000097void from_json(const nlohmann::json& j, RelativeVerifyInfo& rInfo)
98{
99 j.at("max").get_to(rInfo.max);
100 j.at("scale").get_to(rInfo.scale);
101}
102
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100103void from_json(const nlohmann::json& j, VerifyConfig& cfg)
104{
105 j.at("mode").get_to(cfg.mode);
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100106 j.at("data_type").get_to(cfg.dataType);
Jeremy Johnsonb2d3bff2024-02-26 16:08:07 +0000107 cfg.ulpInfo.ulp = 0;
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100108 if (j.contains("ulp_info"))
109 {
110 j.at("ulp_info").get_to(cfg.ulpInfo);
111 }
Jeremy Johnsonb2d3bff2024-02-26 16:08:07 +0000112 cfg.dotProductInfo.setNumber = 0;
113 cfg.dotProductInfo.kernelSize = 0;
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100114 if (j.contains("dot_product_info"))
115 {
116 j.at("dot_product_info").get_to(cfg.dotProductInfo);
117 }
Jeremy Johnsonb2d3bff2024-02-26 16:08:07 +0000118 cfg.reduceProductInfo.numberOfProducts = 0;
Jack Frankland12ee1a72023-09-20 09:08:34 +0100119 if (j.contains("reduce_product_info"))
120 {
121 j.at("reduce_product_info").get_to(cfg.reduceProductInfo);
122 }
Jeremy Johnson1eb14552024-04-11 16:21:54 +0100123 cfg.absErrorInfo.lowerBound = 0;
124 cfg.absErrorInfo.normalDivisor = 1;
125 cfg.absErrorInfo.boundAsMagnitude = false;
126 cfg.absErrorInfo.boundAddition = 0;
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000127 if (j.contains("abs_error_info"))
128 {
129 j.at("abs_error_info").get_to(cfg.absErrorInfo);
130 }
Jeremy Johnsonb2d3bff2024-02-26 16:08:07 +0000131 cfg.relativeInfo.max = 0;
132 cfg.relativeInfo.scale = 0;
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000133 if (j.contains("relative_info"))
134 {
135 j.at("relative_info").get_to(cfg.relativeInfo);
136 }
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100137}
138
139std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json)
140{
141 if (!tensorName)
142 return std::nullopt;
143
144 auto jsonCfg = nlohmann::json::parse(json, nullptr, /* allow exceptions */ false);
145
146 if (jsonCfg.is_discarded())
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100147 {
148 WARNING("[Verifier] Invalid json config.");
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100149 return std::nullopt;
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100150 }
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100151 if (!jsonCfg.contains("tensors"))
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100152 {
153 WARNING("[Verifier] Missing tensors in json config.");
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100154 return std::nullopt;
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100155 }
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100156
157 const auto& tensors = jsonCfg["tensors"];
158 if (!tensors.contains(tensorName))
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100159 if (!tensors.contains(tensorName))
160 {
161 WARNING("[Verifier] Missing tensor %s in json config.", tensorName);
162 return std::nullopt;
163 }
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100164 const auto& namedTensor = tensors[tensorName];
165 return namedTensor.get<VerifyConfig>();
166}
167
// Returns the product of all dimensions in shape (1 for an empty shape).
int64_t numElements(const std::vector<int32_t>& shape)
{
    // The accumulator type is taken from the initial value, so it must be
    // int64_t; a plain `1` would truncate the running product to int.
    return std::accumulate(std::begin(shape), std::end(shape), int64_t{ 1 }, std::multiplies<int64_t>());
}
172
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000173std::vector<int32_t> indexToPosition(int64_t index, const std::vector<int32_t>& shape)
174{
175 std::vector<int32_t> pos;
176 for (auto d = shape.end() - 1; d >= shape.begin(); --d)
177 {
178 pos.insert(pos.begin(), index % *d);
179 index /= *d;
180 }
Jeremy Johnson1f752322024-01-23 15:02:34 +0000181 ASSERT_MSG(index == 0, "index too large for given shape")
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000182 return pos;
183}
184
// Formats a position as a bracketed, comma separated list, e.g. "[1,2,3]".
// An empty position yields "[]".
std::string positionToString(const std::vector<int32_t>& pos)
{
    std::string text = "[";
    bool needSeparator = false;
    for (const int32_t coord : pos)
    {
        if (needSeparator)
        {
            text.append(",");
        }
        text.append(std::to_string(coord));
        needSeparator = true;
    }
    text.append("]");
    return text;
}
199
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100200DType mapToDType(tosa_datatype_t dataType)
201{
202 static std::map<tosa_datatype_t, DType> typeMap = {
Won Jeon2c34b462024-02-06 18:37:00 +0000203 { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 },
204 { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 },
205 { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 },
206 { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 },
207 { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 },
208 { tosa_datatype_shape_t, DType_SHAPE }, { tosa_datatype_fp8e4m3_t, DType_FP8E4M3 },
209 { tosa_datatype_fp8e5m2_t, DType_FP8E5M2 },
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100210 };
211
212 if (typeMap.count(dataType))
213 {
214 return typeMap[dataType];
215 }
216
217 return DType_UNKNOWN;
218}
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100219
220// Like const_exp2 but for use during runtime
221double exp2(int32_t n)
222{
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000223 if (n < -1075)
224 {
225 return 0.0; // smaller than smallest denormal
226 }
227 TOSA_REF_REQUIRE(n <= 1023, " Invalid exponent value (%d) in exp2", n);
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100228 return const_exp2(n);
229}
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100230
231int32_t ilog2(double v)
232{
233 TOSA_REF_REQUIRE(0.0 < v && v < std::numeric_limits<double>::infinity(), " Value out of range (%g) in ilog2", v);
234 int32_t n = 0;
235 while (v >= 2.0)
236 {
237 v = v / 2.0;
238 n++;
239 }
240 while (v < 1.0)
241 {
242 v = v * 2.0;
243 n--;
244 }
245 return n;
246}
Jeremy Johnson9a758382023-11-07 16:27:35 +0000247
// The bounds based checks below assume float and double are IEEE 754 types;
// refuse to build otherwise.
static_assert(std::numeric_limits<float>::is_iec559,
              "TOSA Reference Model has not been built with standard IEEE 754 32-bit float support; "
              "Bounds based verification is invalid");
static_assert(std::numeric_limits<double>::is_iec559,
              "TOSA Reference Model has not been built with standard IEEE 754 64-bit float support; "
              "Bounds based verification is invalid");
254
Jeremy Johnson718f3472023-11-30 14:18:19 +0000255template <typename OutType>
Jeremy Johnson08965d32024-02-19 13:57:21 +0000256bool tosaCheckFloatBound(
257 OutType testValue, double referenceValue, double errorBound, double& resultDifference, std::string& resultWarning)
Jeremy Johnson9a758382023-11-07 16:27:35 +0000258{
259 // Both must be NaNs to be correct
260 if (std::isnan(referenceValue) || std::isnan(testValue))
261 {
262 if (std::isnan(referenceValue) && std::isnan(testValue))
263 {
Jeremy Johnson08965d32024-02-19 13:57:21 +0000264 resultDifference = 0.0;
Jeremy Johnson9a758382023-11-07 16:27:35 +0000265 return true;
266 }
Jeremy Johnson08965d32024-02-19 13:57:21 +0000267 char buff[200];
268 snprintf(buff, 200, "Non-matching NaN values - ref (%g) versus test (%g).", referenceValue,
269 static_cast<double>(testValue));
270 resultWarning.assign(buff);
271 resultDifference = std::numeric_limits<double>::quiet_NaN();
Jeremy Johnson9a758382023-11-07 16:27:35 +0000272 return false;
273 }
274
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000275 // Check the errorBound
276 TOSA_REF_REQUIRE(errorBound >= 0.f, " Invalid error bound (%g)", errorBound);
277
Jeremy Johnson9a758382023-11-07 16:27:35 +0000278 // Make the sign of the reference value positive
279 // and adjust the test value appropriately.
280 if (referenceValue < 0)
281 {
282 referenceValue = -referenceValue;
283 testValue = -testValue;
284 }
Jeremy Johnson9a758382023-11-07 16:27:35 +0000285
286 // At this point we are ready to calculate the ULP bounds for the reference value.
287 double referenceMin, referenceMax;
288
289 // If the reference is infinity e.g. the result of an overflow the test value must
290 // be infinity of an appropriate sign.
291 if (std::isinf(referenceValue))
292 {
293 // We already canonicalized the input such that the reference value is positive
294 // so no need to check again here.
Jeremy Johnson718f3472023-11-30 14:18:19 +0000295 referenceMin = std::numeric_limits<OutType>::infinity();
296 referenceMax = std::numeric_limits<OutType>::infinity();
Jeremy Johnson9a758382023-11-07 16:27:35 +0000297 }
298 else if (referenceValue == 0)
299 {
300 // For zero we require that the results match exactly with the correct sign.
301 referenceMin = 0;
302 referenceMax = 0;
303 }
304 else
305 {
306
307 // Scale by the number of ULPs requested by the user.
308 referenceMax = referenceValue + errorBound;
309 referenceMin = referenceValue - errorBound;
310
311 // Handle the overflow cases.
Jeremy Johnson718f3472023-11-30 14:18:19 +0000312 if (referenceMax > AccPrecision<OutType>::normal_max)
Jeremy Johnson9a758382023-11-07 16:27:35 +0000313 {
Jeremy Johnson718f3472023-11-30 14:18:19 +0000314 referenceMax = std::numeric_limits<OutType>::infinity();
Jeremy Johnson9a758382023-11-07 16:27:35 +0000315 }
316
Jeremy Johnson718f3472023-11-30 14:18:19 +0000317 if (referenceMin > AccPrecision<OutType>::normal_max)
Jeremy Johnson9a758382023-11-07 16:27:35 +0000318 {
Jeremy Johnson718f3472023-11-30 14:18:19 +0000319 referenceMin = std::numeric_limits<OutType>::infinity();
Jeremy Johnson9a758382023-11-07 16:27:35 +0000320 }
321
322 // And the underflow cases.
Jeremy Johnson718f3472023-11-30 14:18:19 +0000323 if (referenceMax < AccPrecision<OutType>::normal_min)
Jeremy Johnson9a758382023-11-07 16:27:35 +0000324 {
Jeremy Johnson718f3472023-11-30 14:18:19 +0000325 referenceMax = AccPrecision<OutType>::normal_min;
Jeremy Johnson9a758382023-11-07 16:27:35 +0000326 }
327
Jeremy Johnson718f3472023-11-30 14:18:19 +0000328 if (referenceMin < AccPrecision<OutType>::normal_min)
Jeremy Johnson9a758382023-11-07 16:27:35 +0000329 {
Jeremy Johnson1eb14552024-04-11 16:21:54 +0100330 // Large error bounds could mean referenceMin is negative
331 referenceMin = std::min(0.0, referenceMin);
Jeremy Johnson9a758382023-11-07 16:27:35 +0000332 }
333 }
334
335 // And finally... Do the comparison.
336 double testValue64 = static_cast<double>(testValue);
337 bool withinBound = testValue64 >= referenceMin && testValue64 <= referenceMax;
Jeremy Johnson08965d32024-02-19 13:57:21 +0000338 resultDifference = testValue64 - referenceValue;
Jeremy Johnson9a758382023-11-07 16:27:35 +0000339 if (!withinBound)
340 {
Jeremy Johnson08965d32024-02-19 13:57:21 +0000341 char buff[300];
342 snprintf(buff, 300,
343 "value %.*g has a difference of %.*g compared to an error bound of +/- %.*g (range: %.*g <= ref %.*g "
344 "<= %.*g).",
345 DBL_DIG, testValue64, DBL_DIG, resultDifference, DBL_DIG, errorBound, DBL_DIG, referenceMin, DBL_DIG,
346 referenceValue, DBL_DIG, referenceMax);
347 resultWarning.assign(buff);
Jeremy Johnson9a758382023-11-07 16:27:35 +0000348 }
349 return withinBound;
350}
Jeremy Johnson718f3472023-11-30 14:18:19 +0000351
Jeremy Johnson08965d32024-02-19 13:57:21 +0000352template <typename OutType>
353bool validateData(const double* referenceData,
354 const double* boundsData,
355 const OutType* implementationData,
356 const std::vector<int32_t>& shape,
357 const std::string& modeStr,
358 const void* cfgPtr,
359 double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr))
360{
361 const size_t T = static_cast<size_t>(numElements(shape));
362 TOSA_REF_REQUIRE(T > 0, "Invalid shape for reference tensor");
363 TOSA_REF_REQUIRE(referenceData != nullptr, "Missing data for reference tensor");
364 TOSA_REF_REQUIRE(implementationData != nullptr, "Missing data for implementation tensor");
365 // NOTE: Bounds data tensor is allowed to be null as it may not be needed
366 TOSA_REF_REQUIRE(cfgPtr != nullptr, "Missing config for validation");
367 TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation");
368
369 std::string warning, worstWarning;
evacha019c96eef2024-02-07 11:21:55 +0000370 double worstDifference = 0.0;
371 // Set to invalid index
372 size_t worstIndex = T;
373 bool compliant = true;
Jeremy Johnson08965d32024-02-19 13:57:21 +0000374
375 for (size_t i = 0; i < T; ++i)
376 {
evacha019c96eef2024-02-07 11:21:55 +0000377 double difference = 0.0;
378 double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
379 double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
380 bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
Jeremy Johnson08965d32024-02-19 13:57:21 +0000381 if (!valid)
382 {
383 compliant = false;
384 if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference))
385 {
evacha019c96eef2024-02-07 11:21:55 +0000386 worstIndex = i;
Jeremy Johnson08965d32024-02-19 13:57:21 +0000387 worstDifference = difference;
388 worstWarning.assign(warning);
389 if (std::isnan(difference))
390 {
391 // Worst case is difference in NaN
392 break;
393 }
394 }
evacha019c96eef2024-02-07 11:21:55 +0000395 else if (std::abs(difference) == 0.0)
396 {
397 auto pos = indexToPosition(i, shape);
398 WARNING("[Verifier][%s] Invalid error bound, no difference found. Location: %s", modeStr.c_str(),
399 positionToString(pos).c_str());
400 return false;
401 }
Jeremy Johnson08965d32024-02-19 13:57:21 +0000402 }
403 }
404 if (!compliant)
405 {
evacha019c96eef2024-02-07 11:21:55 +0000406 auto pos = indexToPosition(worstIndex, shape);
Jeremy Johnson08965d32024-02-19 13:57:21 +0000407 WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(),
408 worstWarning.c_str());
409 }
410 return compliant;
411}
412
Jeremy Johnson718f3472023-11-30 14:18:19 +0000413// Instantiate the needed check functions
Jeremy Johnson08965d32024-02-19 13:57:21 +0000414template bool validateData(const double* referenceData,
415 const double* boundsData,
416 const float* implementationData,
417 const std::vector<int32_t>& shape,
418 const std::string& modeStr,
419 const void* cfgPtr,
420 double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr));
421template bool validateData(const double* referenceData,
422 const double* boundsData,
423 const half_float::half* implementationData,
424 const std::vector<int32_t>& shape,
425 const std::string& modeStr,
426 const void* cfgPtr,
427 double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr));
428
Georgios Pinitas7021ef02023-08-22 08:25:57 +0100429} // namespace TosaReference