// half/bfloat/convert.rs
use crate::leading_zeros::leading_zeros_u16;
use core::mem;
/// Converts an `f32` into the raw bit pattern of a bfloat16 value.
///
/// The result is the upper 16 bits of the `f32` encoding, rounded to nearest
/// with ties going to the even bit pattern. Any NaN input yields a NaN output
/// with the most significant mantissa bit (`0x0040`) set, and keeps whatever
/// payload bits survive the truncation.
#[inline]
pub(crate) const fn f32_to_bf16(value: f32) -> u16 {
    // Everything except the sign bit.
    const MAGNITUDE_MASK: u32 = 0x7FFF_FFFF;
    // Bit pattern of +infinity; any larger magnitude is a NaN.
    const INFINITY_BITS: u32 = 0x7F80_0000;
    // Most significant mantissa bit of the 16-bit result.
    const QUIET_NAN_BIT: u32 = 0x0040;
    // Highest bit that is discarded by the truncation.
    const ROUND_BIT: u32 = 0x0000_8000;
    // Lowest kept bit plus every discarded bit below `ROUND_BIT`.
    const ODD_OR_STICKY_MASK: u32 = 3 * ROUND_BIT - 1;

    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    // SAFETY: `f32` and `u32` have the same size and every bit pattern is a
    // valid `u32`.
    let bits: u32 = unsafe { mem::transmute::<f32, u32>(value) };
    let upper = bits >> 16;

    // NaN: never round (that could turn it into an infinity); just force the
    // top mantissa bit so the truncated value is still a NaN.
    if (bits & MAGNITUDE_MASK) > INFINITY_BITS {
        return (upper | QUIET_NAN_BIT) as u16;
    }

    // Round up when the discarded part is more than half an ULP, or exactly
    // half an ULP with an odd kept value.
    let round_up = (bits & ROUND_BIT) != 0 && (bits & ODD_OR_STICKY_MASK) != 0;
    if round_up {
        upper as u16 + 1
    } else {
        upper as u16
    }
}
/// Converts an `f64` into the raw bit pattern of a bfloat16 value.
///
/// Rounding is to nearest with ties going to the even bit pattern. Magnitudes
/// too large for bfloat16 become a signed infinity, magnitudes too small
/// become a signed zero, and values in between may become bfloat16
/// subnormals. Any NaN input yields a NaN output with the most significant
/// mantissa bit (`0x0040`) set.
///
/// Only the upper 32 bits of the input are used to assemble the result, but
/// the lower 32 bits still take part in the rounding decision as "sticky"
/// bits: a value whose upper bits sit exactly on a rounding midpoint is above
/// that midpoint whenever any lower bit is set, and must round away from zero
/// rather than to even.
#[inline]
pub(crate) const fn f64_to_bf16(value: f64) -> u16 {
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    // SAFETY: `f64` and `u64` have the same size and every bit pattern is a
    // valid `u64`.
    let val: u64 = unsafe { mem::transmute::<f64, u64>(value) };
    // The upper word holds the sign, the exponent and the top 20 mantissa bits.
    let x = (val >> 32) as u32;
    // The lower word can only influence NaN detection and tie breaking.
    let low_sticky = (val as u32) != 0;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // A NaN may have its only set mantissa bits in the lower word.
        let nan_bit = if man == 0 && !low_sticky {
            0
        } else {
            0x0040u32
        };
        return ((sign >> 16) | 0x7F80u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for bfloat16 precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 127;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0xFF {
        return (half_sign | 0x7F80u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 7 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding: round up when above the midpoint (any sticky
        // bit set, including the lower word) or on a tie with an odd result.
        let round_bit: u32 = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && ((man & (3 * round_bit - 1)) != 0 || low_sticky) {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 7;
    let half_man = man >> 13;
    // Check for rounding: same rule as above, with the lower word as sticky.
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && ((man & (3 * round_bit - 1)) != 0 || low_sticky) {
        // Round it; a mantissa carry correctly bumps the exponent (up to infinity)
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
/// Expands raw bfloat16 bits into the `f32` with the same value.
///
/// The 16 input bits become the upper half of the `f32` encoding and the
/// lower half is zero. NaN inputs additionally get the most significant
/// mantissa bit (`0x0040` before shifting) set.
#[inline]
pub(crate) const fn bf16_to_f32(i: u16) -> f32 {
    // With the sign masked off, anything above the infinity pattern is a NaN.
    let is_nan = (i & 0x7FFFu16) > 0x7F80u16;
    let upper_half = if is_nan {
        i as u32 | 0x0040u32
    } else {
        i as u32
    };

    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    // SAFETY: `u32` and `f32` have the same size and every bit pattern is a
    // valid `f32`.
    unsafe { mem::transmute::<u32, f32>(upper_half << 16) }
}
/// Expands raw bfloat16 bits into the `f64` with the same value.
///
/// Signed zeros and infinities map to their `f64` counterparts. NaN inputs
/// produce an `f64` NaN with the most significant mantissa bit set and the
/// seven input mantissa bits copied to the top of the `f64` mantissa.
/// bfloat16 subnormals are renormalised, since every one of them is a normal
/// number in `f64`.
#[inline]
pub(crate) const fn bf16_to_f64(i: u16) -> f64 {
    // Split the input into its fields; the sign is moved straight to bit 63.
    let sign = ((i & 0x8000u16) as u64) << 48;
    let exp_field = (i & 0x7F80u16) as u64;
    let frac = (i & 0x007Fu16) as u64;

    let bits: u64 = if exp_field == 0x7F80u64 {
        // All exponent bits set: infinity when the mantissa is empty, else NaN.
        if frac == 0 {
            sign | 0x7FF0_0000_0000_0000u64
        } else {
            sign | 0x7FF8_0000_0000_0000u64 | (frac << 45)
        }
    } else if exp_field == 0 {
        if frac == 0 {
            // Signed zero
            sign
        } else {
            // Subnormal: `frac` occupies at most 7 of the 16 bits, so it has
            // at least 9 leading zeros; the surplus says how far the leading
            // one must travel to become the implicit bit.
            let e = leading_zeros_u16(frac as u16) - 9;
            // Rebias and adjust exponent
            let exp = ((1023 - 127 - e) as u64) << 52;
            // Shift the leading one up to bit 52, then mask it away.
            let man = (frac << (46 + e)) & 0xF_FFFF_FFFF_FFFFu64;
            sign | exp | man
        }
    } else {
        // Normal number: swap the bfloat16 bias (127) for the f64 bias (1023)
        // and left-align the seven mantissa bits in the 52-bit field.
        let unbiased_exp = ((exp_field as i64) >> 7) - 127;
        let exp = ((unbiased_exp + 1023) as u64) << 52;
        sign | exp | (frac << 45)
    };

    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    // SAFETY: `u64` and `f64` have the same size and every bit pattern is a
    // valid `f64`.
    unsafe { mem::transmute::<u64, f64>(bits) }
}