e3e662b52b
Change-Id: I2648909483b8dc32fcd720c18608c5ca32399045
1097 linhas
50 KiB
C++
1097 linhas
50 KiB
C++
/*
|
|
Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in
|
|
all copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
THE SOFTWARE.
|
|
*/
|
|
|
|
#include <hip_test_common.hh>
#include <hip/hip_fp8.h>

#include <bitset>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>
|
|
|
|
/*
|
|
* Tests for fp8 conversions on host
|
|
* Both FNUZ and OCP types are supported on host
|
|
*/
|
|
|
|
// Host-side check of the implicit fp8 -> bool conversion for the OCP e4m3 and
// e5m2 types: nonzero inputs must convert to true, +0.0 and -0.0 to false.
TEST_CASE("Unit_fp8_ocp_bool_host") {
  // clang-format off
  const std::vector<float> inputs{-10.0f, -1.0f, -0.0f, 0.0f, 1.0f, 10.0f};
  std::vector<bool> expected     {  true,  true, false, false, true,  true};
  // clang-format on

  // Written by whichever SECTION is active. A plain array is used because
  // std::vector<bool> is bit-packed and has no data() member.
  bool actual[] = {false, false, false, false, false, false};

  SECTION("e4m3_ocp-cpu") {
    for (size_t idx = 0; idx < expected.size(); ++idx) {
      __hip_fp8_e4m3 converted(inputs[idx]);
      actual[idx] = converted;
    }
  }

  SECTION("e5m2_ocp-cpu") {
    for (size_t idx = 0; idx < expected.size(); ++idx) {
      __hip_fp8_e5m2 converted(inputs[idx]);
      actual[idx] = converted;
    }
  }

  // Catch2 re-runs the test body once per SECTION, so this verification loop
  // checks the results of each section in turn.
  for (size_t idx = 0; idx < expected.size(); ++idx) {
    INFO("Check for: " << inputs[idx] << " expected: " << expected[idx] << " result: " << actual[idx]);
    REQUIRE(actual[idx] == expected[idx]);
  }
}
|
|
|
|
// test to check we are putting in data correctly in vector types
|
|
TEST_CASE("Unit_all_fp8_ocp_vector_cvt") {
|
|
float2 in2{1.0f, 2.0f};
|
|
float4 in4{3.0f, 4.0f, 5.0f, 6.0f};
|
|
|
|
SECTION("e4m3_ocp x2") {
|
|
__hip_fp8x2_e4m3 in(in2);
|
|
float2 out = in;
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out: " << out.x << " - " << out.y);
|
|
REQUIRE(out == in2);
|
|
}
|
|
SECTION("e4m3_ocp x4") {
|
|
__hip_fp8x4_e4m3 in(in4);
|
|
float4 out = in;
|
|
INFO("In: " << in4.x << " - " << in4.y << " - " << in4.z << " - " << in4.w);
|
|
INFO("Out: " << out.x << " - " << out.y << " - " << out.z << " - " << out.w);
|
|
REQUIRE(out == in4);
|
|
}
|
|
|
|
SECTION("e5m2_ocp x2") {
|
|
__hip_fp8x2_e5m2 in(in2);
|
|
float2 out = in;
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out: " << out.x << " - " << out.y);
|
|
REQUIRE(out == in2);
|
|
}
|
|
|
|
SECTION("e5m2_ocp x4") {
|
|
__hip_fp8x4_e5m2 in(in4);
|
|
float4 out = in;
|
|
INFO("In: " << in4.x << " - " << in4.y << " - " << in4.z << " - " << in4.w);
|
|
INFO("Out: " << out.x << " - " << out.y << " - " << out.z << " - " << out.w);
|
|
REQUIRE(out == in4);
|
|
}
|
|
|
|
SECTION("half x2 e4m3_ocp") {
|
|
__hip_fp8x2_e4m3 in(in2);
|
|
auto hr2 = __hip_cvt_fp8x2_to_halfraw2(in.__x, __HIP_E4M3);
|
|
float2 fout1 = in;
|
|
float2 fout2 = __half22float2(__half2(hr2));
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out from f8 : " << fout1.x << " - " << fout1.y);
|
|
INFO("Out from half: " << fout2.x << " - " << fout2.y);
|
|
REQUIRE(fout1 == fout2);
|
|
}
|
|
|
|
SECTION("half x2 e5m2_ocp") {
|
|
__hip_fp8x2_e5m2 in(in2);
|
|
auto hr2 = __hip_cvt_fp8x2_to_halfraw2(in.__x, __HIP_E5M2);
|
|
float2 fout1 = in;
|
|
float2 fout2 = __half22float2(__half2(hr2));
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out from f8 : " << fout1.x << " - " << fout1.y);
|
|
INFO("Out from half: " << fout2.x << " - " << fout2.y);
|
|
REQUIRE(fout1 == fout2);
|
|
}
|
|
}
|
|
|
|
|
|
// Round-trip check over every finite OCP fp8 value (e4m3 and e5m2), instantiated
// for float and double inputs. Each value is converted to fp8 twice: through the
// fp8 class constructor, and through the raw __hip_cvt_{float,double}_to_fp8
// helper with __HIP_SATFINITE. Both fp8 results are converted back to float.
// The round trip must reproduce the input (within Approx) and both paths must agree.
TEMPLATE_TEST_CASE("Unit_fp8_ocp_correctness", "", float, double) {
  // Bit pattern of a float/double, used for diagnostics only. std::memcpy
  // avoids the strict-aliasing UB of reading the value through an
  // unsigned int*. The bitset width follows sizeof(value), so a double is
  // shown in full rather than only its first four bytes.
  auto bits_of = [](auto value) {
    using Value = decltype(value);
    static_assert(sizeof(Value) == sizeof(std::uint32_t) || sizeof(Value) == sizeof(std::uint64_t),
                  "bits_of expects a 32- or 64-bit value");
    using Raw = typename std::conditional<sizeof(Value) == sizeof(std::uint64_t), std::uint64_t,
                                          std::uint32_t>::type;
    Raw raw;
    std::memcpy(&raw, &value, sizeof(raw));
    return std::bitset<sizeof(Value) * 8>(raw);
  };

  SECTION("e4m3_ocp") {
    /* These are all the finite fp8 e4m3_ocp numbers: the positives, then -0 and the
     * negatives. They can be generated by iterating over 0'0000'000 .. 1'1111'111,
     * converting each pattern to an fp32 number and skipping the nan/inf encodings.
     * Note: -0 must be written as -0.0; the integer literal -0 is just +0. */
    // clang-format off
    std::vector<TestType> e4m3_ocp_nums = {0,           0.00195312,  0.00390625,
                                           0.00585938,  0.0078125,   0.00976562,
                                           0.0117188,   0.0136719,   0.015625,
                                           0.0175781,   0.0195312,   0.0214844,
                                           0.0234375,   0.0253906,   0.0273438,
                                           0.0292969,   0.03125,     0.0351562,
                                           0.0390625,   0.0429688,   0.046875,
                                           0.0507812,   0.0546875,   0.0585938,
                                           0.0625,      0.0703125,   0.078125,
                                           0.0859375,   0.09375,     0.101562,
                                           0.109375,    0.117188,    0.125,
                                           0.140625,    0.15625,     0.171875,
                                           0.1875,      0.203125,    0.21875,
                                           0.234375,    0.25,        0.28125,
                                           0.3125,      0.34375,     0.375,
                                           0.40625,     0.4375,      0.46875,
                                           0.5,         0.5625,      0.625,
                                           0.6875,      0.75,        0.8125,
                                           0.875,       0.9375,      1,
                                           1.125,       1.25,        1.375,
                                           1.5,         1.625,       1.75,
                                           1.875,       2,           2.25,
                                           2.5,         2.75,        3,
                                           3.25,        3.5,         3.75,
                                           4,           4.5,         5,
                                           5.5,         6,           6.5,
                                           7,           7.5,         8,
                                           9,           10,          11,
                                           12,          13,          14,
                                           15,          16,          18,
                                           20,          22,          24,
                                           26,          28,          30,
                                           32,          36,          40,
                                           44,          48,          52,
                                           56,          60,          64,
                                           72,          80,          88,
                                           96,          104,         112,
                                           120,         128,         144,
                                           160,         176,         192,
                                           208,         224,         240,
                                           256,         288,         320,
                                           352,         384,         416,
                                           448,         -0.0,        -0.00195312,
                                           -0.00390625, -0.00585938, -0.0078125,
                                           -0.00976562, -0.0117188,  -0.0136719,
                                           -0.015625,   -0.0175781,  -0.0195312,
                                           -0.0214844,  -0.0234375,  -0.0253906,
                                           -0.0273438,  -0.0292969,  -0.03125,
                                           -0.0351562,  -0.0390625,  -0.0429688,
                                           -0.046875,   -0.0507812,  -0.0546875,
                                           -0.0585938,  -0.0625,     -0.0703125,
                                           -0.078125,   -0.0859375,  -0.09375,
                                           -0.101562,   -0.109375,   -0.117188,
                                           -0.125,      -0.140625,   -0.15625,
                                           -0.171875,   -0.1875,     -0.203125,
                                           -0.21875,    -0.234375,   -0.25,
                                           -0.28125,    -0.3125,     -0.34375,
                                           -0.375,      -0.40625,    -0.4375,
                                           -0.46875,    -0.5,        -0.5625,
                                           -0.625,      -0.6875,     -0.75,
                                           -0.8125,     -0.875,      -0.9375,
                                           -1,          -1.125,      -1.25,
                                           -1.375,      -1.5,        -1.625,
                                           -1.75,       -1.875,      -2,
                                           -2.25,       -2.5,        -2.75,
                                           -3,          -3.25,       -3.5,
                                           -3.75,       -4,          -4.5,
                                           -5,          -5.5,        -6,
                                           -6.5,        -7,          -7.5,
                                           -8,          -9,          -10,
                                           -11,         -12,         -13,
                                           -14,         -15,         -16,
                                           -18,         -20,         -22,
                                           -24,         -26,         -28,
                                           -30,         -32,         -36,
                                           -40,         -44,         -48,
                                           -52,         -56,         -60,
                                           -64,         -72,         -80,
                                           -88,         -96,         -104,
                                           -112,        -120,        -128,
                                           -144,        -160,        -176,
                                           -192,        -208,        -224,
                                           -240,        -256,        -288,
                                           -320,        -352,        -384,
                                           -416,        -448};
    // clang-format on

    for (const auto& orig : e4m3_ocp_nums) {
      // Path 1: fp8 class constructor.
      __hip_fp8_e4m3 fp8(orig);
      float cvt1 = fp8;

      // Path 2: raw conversion helper matching TestType.
      __hip_fp8_e4m3 tmp;
      tmp.__x = std::is_same<TestType, float>::value
          ? __hip_cvt_float_to_fp8(orig, __HIP_SATFINITE, __HIP_E4M3)
          : __hip_cvt_double_to_fp8(orig, __HIP_SATFINITE, __HIP_E4M3);
      float cvt2 = tmp;

      INFO("Original: " << bits_of(orig));
      INFO("Cvt back: " << bits_of(cvt1));
      REQUIRE(cvt1 == Approx(orig));
      REQUIRE(cvt2 == cvt1);
    }
  }

  SECTION("e5m2_ocp") {
    /* These are all the finite fp8 e5m2_ocp numbers: the positives, then -0 and the
     * negatives. They can be generated by iterating over 0'00000'00 .. 1'11111'11,
     * converting each pattern to an fp32 number and skipping the nan/inf encodings.
     * Note: -0 must be written as -0.0; the integer literal -0 is just +0. */
    // clang-format off
    std::vector<TestType> e5m2_ocp_nums = {0,            1.52588e-05,  3.05176e-05,
                                           4.57764e-05,  6.10352e-05,  7.62939e-05,
                                           9.15527e-05,  0.000106812,  0.00012207,
                                           0.000152588,  0.000183105,  0.000213623,
                                           0.000244141,  0.000305176,  0.000366211,
                                           0.000427246,  0.000488281,  0.000610352,
                                           0.000732422,  0.000854492,  0.000976562,
                                           0.0012207,    0.00146484,   0.00170898,
                                           0.00195312,   0.00244141,   0.00292969,
                                           0.00341797,   0.00390625,   0.00488281,
                                           0.00585938,   0.00683594,   0.0078125,
                                           0.00976562,   0.0117188,    0.0136719,
                                           0.015625,     0.0195312,    0.0234375,
                                           0.0273438,    0.03125,      0.0390625,
                                           0.046875,     0.0546875,    0.0625,
                                           0.078125,     0.09375,      0.109375,
                                           0.125,        0.15625,      0.1875,
                                           0.21875,      0.25,         0.3125,
                                           0.375,        0.4375,       0.5,
                                           0.625,        0.75,         0.875,
                                           1,            1.25,         1.5,
                                           1.75,         2,            2.5,
                                           3,            3.5,          4,
                                           5,            6,            7,
                                           8,            10,           12,
                                           14,           16,           20,
                                           24,           28,           32,
                                           40,           48,           56,
                                           64,           80,           96,
                                           112,          128,          160,
                                           192,          224,          256,
                                           320,          384,          448,
                                           512,          640,          768,
                                           896,          1024,         1280,
                                           1536,         1792,         2048,
                                           2560,         3072,         3584,
                                           4096,         5120,         6144,
                                           7168,         8192,         10240,
                                           12288,        14336,        16384,
                                           20480,        24576,        28672,
                                           32768,        40960,        49152,
                                           57344,        -0.0,         -1.52588e-05,
                                           -3.05176e-05, -4.57764e-05, -6.10352e-05,
                                           -7.62939e-05, -9.15527e-05, -0.000106812,
                                           -0.00012207,  -0.000152588, -0.000183105,
                                           -0.000213623, -0.000244141, -0.000305176,
                                           -0.000366211, -0.000427246, -0.000488281,
                                           -0.000610352, -0.000732422, -0.000854492,
                                           -0.000976562, -0.0012207,   -0.00146484,
                                           -0.00170898,  -0.00195312,  -0.00244141,
                                           -0.00292969,  -0.00341797,  -0.00390625,
                                           -0.00488281,  -0.00585938,  -0.00683594,
                                           -0.0078125,   -0.00976562,  -0.0117188,
                                           -0.0136719,   -0.015625,    -0.0195312,
                                           -0.0234375,   -0.0273438,   -0.03125,
                                           -0.0390625,   -0.046875,    -0.0546875,
                                           -0.0625,      -0.078125,    -0.09375,
                                           -0.109375,    -0.125,       -0.15625,
                                           -0.1875,      -0.21875,     -0.25,
                                           -0.3125,      -0.375,       -0.4375,
                                           -0.5,         -0.625,       -0.75,
                                           -0.875,       -1,           -1.25,
                                           -1.5,         -1.75,        -2,
                                           -2.5,         -3,           -3.5,
                                           -4,           -5,           -6,
                                           -7,           -8,           -10,
                                           -12,          -14,          -16,
                                           -20,          -24,          -28,
                                           -32,          -40,          -48,
                                           -56,          -64,          -80,
                                           -96,          -112,         -128,
                                           -160,         -192,         -224,
                                           -256,         -320,         -384,
                                           -448,         -512,         -640,
                                           -768,         -896,         -1024,
                                           -1280,        -1536,        -1792,
                                           -2048,        -2560,        -3072,
                                           -3584,        -4096,        -5120,
                                           -6144,        -7168,        -8192,
                                           -10240,       -12288,       -14336,
                                           -16384,       -20480,       -24576,
                                           -28672,       -32768,       -40960,
                                           -49152,       -57344};
    // clang-format on

    for (const auto& orig : e5m2_ocp_nums) {
      // Path 1: fp8 class constructor.
      __hip_fp8_e5m2 fp8(orig);
      float cvt1 = fp8;

      // Path 2: raw conversion helper matching TestType.
      __hip_fp8_e5m2 tmp;
      tmp.__x = std::is_same<TestType, float>::value
          ? __hip_cvt_float_to_fp8(orig, __HIP_SATFINITE, __HIP_E5M2)
          : __hip_cvt_double_to_fp8(orig, __HIP_SATFINITE, __HIP_E5M2);
      float cvt2 = tmp;

      INFO("Original: " << bits_of(orig));
      INFO("Cvt back: " << bits_of(cvt1));
      REQUIRE(cvt1 == Approx(orig));
      REQUIRE(cvt1 == cvt2);
    }
  }
}
|
|
|
|
// Check the orientation encoded is correct
|
|
TEST_CASE("Unit_fp8_ocp_vector_basic_conversions") {
|
|
float f1 = 1.0f;
|
|
float2 f2 = {1.0f, 2.0f};
|
|
float4 f4 = {1.0f, 2.0f, 3.0f, 4.0f};
|
|
|
|
SECTION("e4m3-ocp cvt float") {
|
|
__hip_fp8_e4m3 f8_1 = f1;
|
|
__hip_fp8x2_e4m3 f8_2 = f2;
|
|
__hip_fp8x4_e4m3 f8_4 = f4;
|
|
|
|
float cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e4m3 tmp;
|
|
tmp.__x = __hip_cvt_float2_to_fp8x2(cf2, __HIP_SATFINITE, __HIP_E4M3);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(f1 == cf1);
|
|
REQUIRE(f2 == cf2);
|
|
REQUIRE(f4 == cf4);
|
|
|
|
REQUIRE(xtmp == f2);
|
|
}
|
|
|
|
SECTION("e5m2-ocp cvt float") {
|
|
__hip_fp8_e5m2 f8_1 = f1;
|
|
__hip_fp8x2_e5m2 f8_2 = f2;
|
|
__hip_fp8x4_e5m2 f8_4 = f4;
|
|
|
|
float cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e5m2 tmp;
|
|
tmp.__x = __hip_cvt_float2_to_fp8x2(cf2, __HIP_SATFINITE, __HIP_E5M2);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(f1 == cf1);
|
|
REQUIRE(f2 == cf2);
|
|
REQUIRE(f4 == cf4);
|
|
|
|
REQUIRE(xtmp == f2);
|
|
}
|
|
|
|
SECTION("e4m3-ocp cvt double") {
|
|
double d1 = f1;
|
|
double2 d2 = {f2.x, f2.y};
|
|
double4 d4 = {f4.x, f4.y, f4.z, f4.w};
|
|
__hip_fp8_e4m3 f8_1 = d1;
|
|
__hip_fp8x2_e4m3 f8_2 = d2;
|
|
__hip_fp8x4_e4m3 f8_4 = d4;
|
|
|
|
double cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e4m3 tmp;
|
|
tmp.__x = __hip_cvt_double2_to_fp8x2(d2, __HIP_SATFINITE, __HIP_E4M3);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(d1 == cf1);
|
|
REQUIRE(d2 == double2{cf2.x, cf2.y});
|
|
REQUIRE(d4 == double4{cf4.x, cf4.y, cf4.z, cf4.w});
|
|
|
|
REQUIRE(double2{xtmp.x, xtmp.y} == d2);
|
|
}
|
|
|
|
SECTION("e5m2-ocp cvt double") {
|
|
double d1 = f1;
|
|
double2 d2 = {f2.x, f2.y};
|
|
double4 d4 = {f4.x, f4.y, f4.z, f4.w};
|
|
__hip_fp8_e5m2 f8_1 = d1;
|
|
__hip_fp8x2_e5m2 f8_2 = d2;
|
|
__hip_fp8x4_e5m2 f8_4 = d4;
|
|
|
|
double cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e5m2 tmp;
|
|
tmp.__x = __hip_cvt_double2_to_fp8x2(d2, __HIP_SATFINITE, __HIP_E5M2);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(d1 == cf1);
|
|
REQUIRE(d2 == double2{cf2.x, cf2.y});
|
|
REQUIRE(d4 == double4{cf4.x, cf4.y, cf4.z, cf4.w});
|
|
|
|
REQUIRE(double2{xtmp.x, xtmp.y} == d2);
|
|
}
|
|
|
|
SECTION("e4m3-ocp half2/bfloat162") {
|
|
auto bf16_val = __float22bfloat162_rn(f2);
|
|
auto half2_val = __float22half2_rn(f2);
|
|
|
|
__hip_fp8x2_e4m3 x1(bf16_val);
|
|
__hip_fp8x2_e4m3 x2(half2_val);
|
|
|
|
__hip_fp8x2_e4m3 tmp1;
|
|
tmp1.__x = __hip_cvt_bfloat16raw2_to_fp8x2(bf16_val, __HIP_SATFINITE, __HIP_E4M3);
|
|
float2 bf2_1 = tmp1;
|
|
|
|
tmp1.__x = __hip_cvt_halfraw2_to_fp8x2(half2_val, __HIP_SATFINITE, __HIP_E4M3);
|
|
float2 h2_1 = tmp1;
|
|
|
|
float2 f2_1 = x1;
|
|
float2 f2_2 = x2;
|
|
|
|
REQUIRE(f2_1 == f2);
|
|
REQUIRE(f2_2 == f2);
|
|
|
|
REQUIRE(f2 == bf2_1);
|
|
REQUIRE(f2 == h2_1);
|
|
}
|
|
|
|
SECTION("e5m2-ocp half2/bfloat162") {
|
|
auto bf16_val = __float22bfloat162_rn(f2);
|
|
auto half2_val = __float22half2_rn(f2);
|
|
|
|
__hip_fp8x2_e5m2 x1(bf16_val);
|
|
__hip_fp8x2_e5m2 x2(half2_val);
|
|
|
|
__hip_fp8x2_e5m2 tmp1;
|
|
tmp1.__x = __hip_cvt_bfloat16raw2_to_fp8x2(bf16_val, __HIP_SATFINITE, __HIP_E5M2);
|
|
float2 bf2_1 = tmp1;
|
|
|
|
tmp1.__x = __hip_cvt_halfraw2_to_fp8x2(half2_val, __HIP_SATFINITE, __HIP_E5M2);
|
|
float2 h2_1 = tmp1;
|
|
|
|
float2 f2_1 = x1;
|
|
float2 f2_2 = x2;
|
|
|
|
REQUIRE(f2_1 == f2);
|
|
REQUIRE(f2_2 == f2);
|
|
|
|
REQUIRE(f2 == bf2_1);
|
|
REQUIRE(f2 == h2_1);
|
|
}
|
|
}
|
|
|
|
|
|
// Host-side check of the implicit fp8 -> bool conversion for the FNUZ e4m3 and
// e5m2 types: nonzero inputs must convert to true, +0.0 and -0.0 to false.
TEST_CASE("Unit_fp8_fnuz_bool_host") {
  // clang-format off
  const std::vector<float> inputs{-10.0f, -1.0f, -0.0f, 0.0f, 1.0f, 10.0f};
  std::vector<bool> expected     {  true,  true, false, false, true,  true};
  // clang-format on

  // Written by whichever SECTION is active. A plain array is used because
  // std::vector<bool> is bit-packed and has no data() member.
  bool actual[] = {false, false, false, false, false, false};

  SECTION("e4m3_fnuz-cpu") {
    for (size_t idx = 0; idx < expected.size(); ++idx) {
      __hip_fp8_e4m3_fnuz converted(inputs[idx]);
      actual[idx] = converted;
    }
  }

  SECTION("e5m2_fnuz-cpu") {
    for (size_t idx = 0; idx < expected.size(); ++idx) {
      __hip_fp8_e5m2_fnuz converted(inputs[idx]);
      actual[idx] = converted;
    }
  }

  // Catch2 re-runs the test body once per SECTION, so this verification loop
  // checks the results of each section in turn.
  for (size_t idx = 0; idx < expected.size(); ++idx) {
    INFO("Check for: " << inputs[idx] << " expected: " << expected[idx] << " result: " << actual[idx]);
    REQUIRE(actual[idx] == expected[idx]);
  }
}
|
|
|
|
// test to check we are putting in data correctly in vector types
|
|
TEST_CASE("Unit_all_fp8_fnuz_vector_cvt") {
|
|
float2 in2{1.0f, 2.0f};
|
|
float4 in4{3.0f, 4.0f, 5.0f, 6.0f};
|
|
|
|
SECTION("e4m3_fnuz x2") {
|
|
__hip_fp8x2_e4m3_fnuz in(in2);
|
|
float2 out = in;
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out: " << out.x << " - " << out.y);
|
|
REQUIRE(out == in2);
|
|
}
|
|
SECTION("e4m3_fnuz x4") {
|
|
__hip_fp8x4_e4m3_fnuz in(in4);
|
|
float4 out = in;
|
|
INFO("In: " << in4.x << " - " << in4.y << " - " << in4.z << " - " << in4.w);
|
|
INFO("Out: " << out.x << " - " << out.y << " - " << out.z << " - " << out.w);
|
|
REQUIRE(out == in4);
|
|
}
|
|
|
|
SECTION("e5m2_fnuz x2") {
|
|
__hip_fp8x2_e5m2_fnuz in(in2);
|
|
float2 out = in;
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out: " << out.x << " - " << out.y);
|
|
REQUIRE(out == in2);
|
|
}
|
|
|
|
SECTION("e5m2_fnuz x4") {
|
|
__hip_fp8x4_e5m2_fnuz in(in4);
|
|
float4 out = in;
|
|
INFO("In: " << in4.x << " - " << in4.y << " - " << in4.z << " - " << in4.w);
|
|
INFO("Out: " << out.x << " - " << out.y << " - " << out.z << " - " << out.w);
|
|
REQUIRE(out == in4);
|
|
}
|
|
|
|
SECTION("half x2 e4m3_fnuz") {
|
|
__hip_fp8x2_e4m3_fnuz in(in2);
|
|
auto hr2 = __hip_cvt_fp8x2_to_halfraw2(in.__x, __HIP_E4M3_FNUZ);
|
|
float2 fout1 = in;
|
|
float2 fout2 = __half22float2(__half2(hr2));
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out from f8 : " << fout1.x << " - " << fout1.y);
|
|
INFO("Out from half: " << fout2.x << " - " << fout2.y);
|
|
REQUIRE(fout1 == fout2);
|
|
}
|
|
|
|
SECTION("half x2 e5m2_fnuz") {
|
|
__hip_fp8x2_e5m2_fnuz in(in2);
|
|
auto hr2 = __hip_cvt_fp8x2_to_halfraw2(in.__x, __HIP_E5M2_FNUZ);
|
|
float2 fout1 = in;
|
|
float2 fout2 = __half22float2(__half2(hr2));
|
|
INFO("In: " << in2.x << " - " << in2.y);
|
|
INFO("Out from f8 : " << fout1.x << " - " << fout1.y);
|
|
INFO("Out from half: " << fout2.x << " - " << fout2.y);
|
|
REQUIRE(fout1 == fout2);
|
|
}
|
|
}
|
|
|
|
// Round-trip check over every finite FNUZ fp8 value (e4m3 and e5m2), instantiated
// for float and double inputs. Each value is converted to fp8 twice: through the
// fp8 class constructor, and through the raw __hip_cvt_{float,double}_to_fp8
// helper with __HIP_SATFINITE. Both fp8 results are converted back to float.
// The round trip must reproduce the input (within Approx) and both paths must agree.
TEMPLATE_TEST_CASE("Unit_fp8_fnuz_correctness", "", float, double) {
  // Bit pattern of a float/double, used for diagnostics only. std::memcpy
  // avoids the strict-aliasing UB of reading the value through an
  // unsigned int*. The bitset width follows sizeof(value), so a double is
  // shown in full rather than only its first four bytes.
  auto bits_of = [](auto value) {
    using Value = decltype(value);
    static_assert(sizeof(Value) == sizeof(std::uint32_t) || sizeof(Value) == sizeof(std::uint64_t),
                  "bits_of expects a 32- or 64-bit value");
    using Raw = typename std::conditional<sizeof(Value) == sizeof(std::uint64_t), std::uint64_t,
                                          std::uint32_t>::type;
    Raw raw;
    std::memcpy(&raw, &value, sizeof(raw));
    return std::bitset<sizeof(Value) * 8>(raw);
  };

  SECTION("e4m3_fnuz") {
    /* These are all the finite fp8 e4m3_fnuz numbers: zero, the positives, then the
     * negatives (the table has no -0 entry; FNUZ formats have no negative zero).
     * They can be generated by iterating over 0'0000'000 .. 1'1111'111, converting
     * each pattern to an fp32 number and skipping the nan encoding. */
    // clang-format off
    std::vector<TestType> e4m3_fnuz_nums = {0,           0.000976562, 0.00195312,
                                            0.00292969,  0.00390625,  0.00488281,
                                            0.00585938,  0.00683594,  0.0078125,
                                            0.00878906,  0.00976562,  0.0107422,
                                            0.0117188,   0.0126953,   0.0136719,
                                            0.0146484,   0.015625,    0.0175781,
                                            0.0195312,   0.0214844,   0.0234375,
                                            0.0253906,   0.0273438,   0.0292969,
                                            0.03125,     0.0351562,   0.0390625,
                                            0.0429688,   0.046875,    0.0507812,
                                            0.0546875,   0.0585938,   0.0625,
                                            0.0703125,   0.078125,    0.0859375,
                                            0.09375,     0.101562,    0.109375,
                                            0.117188,    0.125,       0.140625,
                                            0.15625,     0.171875,    0.1875,
                                            0.203125,    0.21875,     0.234375,
                                            0.25,        0.28125,     0.3125,
                                            0.34375,     0.375,       0.40625,
                                            0.4375,      0.46875,     0.5,
                                            0.5625,      0.625,       0.6875,
                                            0.75,        0.8125,      0.875,
                                            0.9375,      1,           1.125,
                                            1.25,        1.375,       1.5,
                                            1.625,       1.75,        1.875,
                                            2,           2.25,        2.5,
                                            2.75,        3,           3.25,
                                            3.5,         3.75,        4,
                                            4.5,         5,           5.5,
                                            6,           6.5,         7,
                                            7.5,         8,           9,
                                            10,          11,          12,
                                            13,          14,          15,
                                            16,          18,          20,
                                            22,          24,          26,
                                            28,          30,          32,
                                            36,          40,          44,
                                            48,          52,          56,
                                            60,          64,          72,
                                            80,          88,          96,
                                            104,         112,         120,
                                            128,         144,         160,
                                            176,         192,         208,
                                            224,         240,         -0.000976562,
                                            -0.00195312, -0.00292969, -0.00390625,
                                            -0.00488281, -0.00585938, -0.00683594,
                                            -0.0078125,  -0.00878906, -0.00976562,
                                            -0.0107422,  -0.0117188,  -0.0126953,
                                            -0.0136719,  -0.0146484,  -0.015625,
                                            -0.0175781,  -0.0195312,  -0.0214844,
                                            -0.0234375,  -0.0253906,  -0.0273438,
                                            -0.0292969,  -0.03125,    -0.0351562,
                                            -0.0390625,  -0.0429688,  -0.046875,
                                            -0.0507812,  -0.0546875,  -0.0585938,
                                            -0.0625,     -0.0703125,  -0.078125,
                                            -0.0859375,  -0.09375,    -0.101562,
                                            -0.109375,   -0.117188,   -0.125,
                                            -0.140625,   -0.15625,    -0.171875,
                                            -0.1875,     -0.203125,   -0.21875,
                                            -0.234375,   -0.25,       -0.28125,
                                            -0.3125,     -0.34375,    -0.375,
                                            -0.40625,    -0.4375,     -0.46875,
                                            -0.5,        -0.5625,     -0.625,
                                            -0.6875,     -0.75,       -0.8125,
                                            -0.875,      -0.9375,     -1,
                                            -1.125,      -1.25,       -1.375,
                                            -1.5,        -1.625,      -1.75,
                                            -1.875,      -2,          -2.25,
                                            -2.5,        -2.75,       -3,
                                            -3.25,       -3.5,        -3.75,
                                            -4,          -4.5,        -5,
                                            -5.5,        -6,          -6.5,
                                            -7,          -7.5,        -8,
                                            -9,          -10,         -11,
                                            -12,         -13,         -14,
                                            -15,         -16,         -18,
                                            -20,         -22,         -24,
                                            -26,         -28,         -30,
                                            -32,         -36,         -40,
                                            -44,         -48,         -52,
                                            -56,         -60,         -64,
                                            -72,         -80,         -88,
                                            -96,         -104,        -112,
                                            -120,        -128,        -144,
                                            -160,        -176,        -192,
                                            -208,        -224,        -240};
    // clang-format on

    for (const auto& orig : e4m3_fnuz_nums) {
      // Path 1: fp8 class constructor.
      __hip_fp8_e4m3_fnuz fp8(orig);
      float cvt1 = fp8;

      // Path 2: raw conversion helper matching TestType.
      __hip_fp8_e4m3_fnuz tmp;
      tmp.__x = std::is_same<TestType, float>::value
          ? __hip_cvt_float_to_fp8(orig, __HIP_SATFINITE, __HIP_E4M3_FNUZ)
          : __hip_cvt_double_to_fp8(orig, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
      float cvt2 = tmp;

      INFO("Original: " << bits_of(orig));
      INFO("Cvt back: " << bits_of(cvt1));
      REQUIRE(cvt1 == Approx(orig));
      REQUIRE(cvt2 == cvt1);
    }
  }

  SECTION("e5m2_fnuz") {
    /* These are all the finite fp8 e5m2_fnuz numbers: zero, the positives, then the
     * negatives (the table has no -0 entry; FNUZ formats have no negative zero).
     * They can be generated by iterating over 0'00000'00 .. 1'11111'11, converting
     * each pattern to an fp32 number and skipping the nan encoding. Rows below hold
     * one binade each (four mantissa values), after the leading zero + subnormals. */
    // clang-format off
    std::vector<TestType> e5m2_fnuz_nums = {
        0,            7.62939e-06,  1.52588e-05,  2.28882e-05,
        3.05176e-05,  3.8147e-05,   4.57764e-05,  5.34058e-05,
        6.10352e-05,  7.62939e-05,  9.15527e-05,  0.000106812,
        0.00012207,   0.000152588,  0.000183105,  0.000213623,
        0.000244141,  0.000305176,  0.000366211,  0.000427246,
        0.000488281,  0.000610352,  0.000732422,  0.000854492,
        0.000976562,  0.0012207,    0.00146484,   0.00170898,
        0.00195312,   0.00244141,   0.00292969,   0.00341797,
        0.00390625,   0.00488281,   0.00585938,   0.00683594,
        0.0078125,    0.00976562,   0.0117188,    0.0136719,
        0.015625,     0.0195312,    0.0234375,    0.0273438,
        0.03125,      0.0390625,    0.046875,     0.0546875,
        0.0625,       0.078125,     0.09375,      0.109375,
        0.125,        0.15625,      0.1875,       0.21875,
        0.25,         0.3125,       0.375,        0.4375,
        0.5,          0.625,        0.75,         0.875,
        1,            1.25,         1.5,          1.75,
        2,            2.5,          3,            3.5,
        4,            5,            6,            7,
        8,            10,           12,           14,
        16,           20,           24,           28,
        32,           40,           48,           56,
        64,           80,           96,           112,
        128,          160,          192,          224,
        256,          320,          384,          448,
        512,          640,          768,          896,
        1024,         1280,         1536,         1792,
        2048,         2560,         3072,         3584,
        4096,         5120,         6144,         7168,
        8192,         10240,        12288,        14336,
        16384,        20480,        24576,        28672,
        32768,        40960,        49152,        57344,
                      -7.62939e-06, -1.52588e-05, -2.28882e-05,
        -3.05176e-05, -3.8147e-05,  -4.57764e-05, -5.34058e-05,
        -6.10352e-05, -7.62939e-05, -9.15527e-05, -0.000106812,
        -0.00012207,  -0.000152588, -0.000183105, -0.000213623,
        -0.000244141, -0.000305176, -0.000366211, -0.000427246,
        -0.000488281, -0.000610352, -0.000732422, -0.000854492,
        -0.000976562, -0.0012207,   -0.00146484,  -0.00170898,
        -0.00195312,  -0.00244141,  -0.00292969,  -0.00341797,
        -0.00390625,  -0.00488281,  -0.00585938,  -0.00683594,
        -0.0078125,   -0.00976562,  -0.0117188,   -0.0136719,
        -0.015625,    -0.0195312,   -0.0234375,   -0.0273438,
        -0.03125,     -0.0390625,   -0.046875,    -0.0546875,
        -0.0625,      -0.078125,    -0.09375,     -0.109375,
        -0.125,       -0.15625,     -0.1875,      -0.21875,
        -0.25,        -0.3125,      -0.375,       -0.4375,
        -0.5,         -0.625,       -0.75,        -0.875,
        -1,           -1.25,        -1.5,         -1.75,
        -2,           -2.5,         -3,           -3.5,
        -4,           -5,           -6,           -7,
        -8,           -10,          -12,          -14,
        -16,          -20,          -24,          -28,
        -32,          -40,          -48,          -56,
        -64,          -80,          -96,          -112,
        -128,         -160,         -192,         -224,
        -256,         -320,         -384,         -448,
        -512,         -640,         -768,         -896,
        -1024,        -1280,        -1536,        -1792,
        -2048,        -2560,        -3072,        -3584,
        -4096,        -5120,        -6144,        -7168,
        -8192,        -10240,       -12288,       -14336,
        -16384,       -20480,       -24576,       -28672,
        -32768,       -40960,       -49152,       -57344};
    // clang-format on

    for (const auto& orig : e5m2_fnuz_nums) {
      // Path 1: fp8 class constructor.
      __hip_fp8_e5m2_fnuz fp8(orig);
      float cvt1 = fp8;

      // Path 2: raw conversion helper matching TestType.
      __hip_fp8_e5m2_fnuz tmp;
      tmp.__x = std::is_same<TestType, float>::value
          ? __hip_cvt_float_to_fp8(orig, __HIP_SATFINITE, __HIP_E5M2_FNUZ)
          : __hip_cvt_double_to_fp8(orig, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
      float cvt2 = tmp;

      INFO("Original: " << bits_of(orig));
      INFO("Cvt back: " << bits_of(cvt1));
      REQUIRE(cvt1 == Approx(orig));
      REQUIRE(cvt1 == cvt2);
    }
  }
}
|
|
|
|
// Check the orientation encoded is correct
|
|
TEST_CASE("Unit_fp8_fnuz_vector_basic_conversions") {
|
|
float f1 = 1.0f;
|
|
float2 f2 = {1.0f, 2.0f};
|
|
float4 f4 = {1.0f, 2.0f, 3.0f, 4.0f};
|
|
|
|
SECTION("e4m3-fnuz cvt float") {
|
|
__hip_fp8_e4m3_fnuz f8_1 = f1;
|
|
__hip_fp8x2_e4m3_fnuz f8_2 = f2;
|
|
__hip_fp8x4_e4m3_fnuz f8_4 = f4;
|
|
|
|
float cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e4m3_fnuz tmp;
|
|
tmp.__x = __hip_cvt_float2_to_fp8x2(cf2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(f1 == cf1);
|
|
REQUIRE(f2 == cf2);
|
|
REQUIRE(f4 == cf4);
|
|
|
|
REQUIRE(xtmp == f2);
|
|
}
|
|
|
|
SECTION("e5m2-fnuz cvt float") {
|
|
__hip_fp8_e5m2_fnuz f8_1 = f1;
|
|
__hip_fp8x2_e5m2_fnuz f8_2 = f2;
|
|
__hip_fp8x4_e5m2_fnuz f8_4 = f4;
|
|
|
|
float cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e5m2_fnuz tmp;
|
|
tmp.__x = __hip_cvt_float2_to_fp8x2(cf2, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(f1 == cf1);
|
|
REQUIRE(f2 == cf2);
|
|
REQUIRE(f4 == cf4);
|
|
|
|
REQUIRE(xtmp == f2);
|
|
}
|
|
|
|
SECTION("e4m3-fnuz cvt double") {
|
|
double d1 = f1;
|
|
double2 d2 = {f2.x, f2.y};
|
|
double4 d4 = {f4.x, f4.y, f4.z, f4.w};
|
|
__hip_fp8_e4m3_fnuz f8_1 = d1;
|
|
__hip_fp8x2_e4m3_fnuz f8_2 = d2;
|
|
__hip_fp8x4_e4m3_fnuz f8_4 = d4;
|
|
|
|
double cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e4m3_fnuz tmp;
|
|
tmp.__x = __hip_cvt_double2_to_fp8x2(d2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(d1 == cf1);
|
|
REQUIRE(d2 == double2{cf2.x, cf2.y});
|
|
REQUIRE(d4 == double4{cf4.x, cf4.y, cf4.z, cf4.w});
|
|
|
|
REQUIRE(double2{xtmp.x, xtmp.y} == d2);
|
|
}
|
|
|
|
SECTION("e5m2-fnuz cvt double") {
|
|
double d1 = f1;
|
|
double2 d2 = {f2.x, f2.y};
|
|
double4 d4 = {f4.x, f4.y, f4.z, f4.w};
|
|
__hip_fp8_e5m2_fnuz f8_1 = d1;
|
|
__hip_fp8x2_e5m2_fnuz f8_2 = d2;
|
|
__hip_fp8x4_e5m2_fnuz f8_4 = d4;
|
|
|
|
double cf1 = f8_1;
|
|
float2 cf2 = f8_2;
|
|
float4 cf4 = f8_4;
|
|
|
|
__hip_fp8x2_e5m2_fnuz tmp;
|
|
tmp.__x = __hip_cvt_double2_to_fp8x2(d2, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
|
|
float2 xtmp = tmp;
|
|
|
|
REQUIRE(d1 == cf1);
|
|
REQUIRE(d2 == double2{cf2.x, cf2.y});
|
|
REQUIRE(d4 == double4{cf4.x, cf4.y, cf4.z, cf4.w});
|
|
|
|
REQUIRE(double2{xtmp.x, xtmp.y} == d2);
|
|
}
|
|
|
|
SECTION("e4m3-fnuz half2/bfloat162") {
|
|
auto bf16_val = __float22bfloat162_rn(f2);
|
|
auto half2_val = __float22half2_rn(f2);
|
|
|
|
__hip_fp8x2_e4m3_fnuz x1(bf16_val);
|
|
__hip_fp8x2_e4m3_fnuz x2(half2_val);
|
|
|
|
__hip_fp8x2_e4m3_fnuz tmp1;
|
|
tmp1.__x = __hip_cvt_bfloat16raw2_to_fp8x2(bf16_val, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
|
|
float2 bf2_1 = tmp1;
|
|
|
|
tmp1.__x = __hip_cvt_halfraw2_to_fp8x2(half2_val, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
|
|
float2 h2_1 = tmp1;
|
|
|
|
float2 f2_1 = x1;
|
|
float2 f2_2 = x2;
|
|
|
|
REQUIRE(f2_1 == f2);
|
|
REQUIRE(f2_2 == f2);
|
|
|
|
REQUIRE(f2 == bf2_1);
|
|
REQUIRE(f2 == h2_1);
|
|
}
|
|
|
|
SECTION("e5m2-fnuz half2/bfloat162") {
|
|
auto bf16_val = __float22bfloat162_rn(f2);
|
|
auto half2_val = __float22half2_rn(f2);
|
|
|
|
__hip_fp8x2_e5m2_fnuz x1(bf16_val);
|
|
__hip_fp8x2_e5m2_fnuz x2(half2_val);
|
|
|
|
__hip_fp8x2_e5m2_fnuz tmp1;
|
|
tmp1.__x = __hip_cvt_bfloat16raw2_to_fp8x2(bf16_val, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
|
|
float2 bf2_1 = tmp1;
|
|
|
|
tmp1.__x = __hip_cvt_halfraw2_to_fp8x2(half2_val, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
|
|
float2 h2_1 = tmp1;
|
|
|
|
float2 f2_1 = x1;
|
|
float2 f2_2 = x2;
|
|
|
|
REQUIRE(f2_1 == f2);
|
|
REQUIRE(f2_2 == f2);
|
|
|
|
REQUIRE(f2 == bf2_1);
|
|
REQUIRE(f2 == h2_1);
|
|
}
|
|
}
|