diff --git a/src/utils/512Math.sol b/src/utils/512Math.sol index f7a58bddd..16c01df34 100644 --- a/src/utils/512Math.sol +++ b/src/utils/512Math.sol @@ -148,20 +148,42 @@ WARNING *** WARNING *** WARNING *** WARNING *** WARNING *** WARNING *** WARNING /// ### Division /// /// * div(uint512,uint256) returns (uint256) +/// * divUp(uint512,uint256) returns (uint256) /// * div(uint512,uint512) returns (uint256) +/// * divUp(uint512,uint512) returns (uint256) /// * odiv(uint512,uint512,uint256) /// * idiv(uint512,uint256) +/// * odivUp(uint512,uint512,uint256) +/// * idivUp(uint512,uint256) /// * odiv(uint512,uint512,uint512) /// * idiv(uint512,uint512) /// * irdiv(uint512,uint512) +/// * odivUp(uint512,uint512,uint512) +/// * idivUp(uint512,uint512) +/// * irdivUp(uint512,uint512) /// * divAlt(uint512,uint512) returns (uint256) -- divAlt(uint512,uint256) is not provided because div(uint512,uint256) is suitable for chains without MODEXP /// * odivAlt(uint512,uint512,uint512) /// * idivAlt(uint512,uint512) /// * irdivAlt(uint512,uint512) +/// * divUpAlt(uint512,uint512) returns (uint256) +/// * odivUpAlt(uint512,uint512,uint512) +/// * idivUpAlt(uint512,uint512) +/// * irdivUpAlt(uint512,uint512) /// /// ### Square root /// /// * sqrt(uint512) returns (uint256) +/// * osqrtUp(uint512,uint512) +/// * isqrtUp(uint512) +/// +/// ### Shifting +/// +/// * oshr(uint512,uint512,uint256) +/// * ishr(uint512,uint256) +/// * oshrUp(uint512,uint512,uint256) +/// * ishrUp(uint512,uint256) +/// * oshl(uint512,uint512,uint256) +/// * ishl(uint512,uint256) type uint512 is bytes32; function alloc() pure returns (uint512 r) { @@ -581,11 +603,15 @@ library Lib512MathArithmetic { //// adapted from Remco Bloemen's work https://2π.com/21/muldiv/ . //// The original code was released under the MIT license. 
- function _roundDown(uint256 x_hi, uint256 x_lo, uint256 d) private pure returns (uint256 r_hi, uint256 r_lo) { + function _roundDown(uint256 x_hi, uint256 x_lo, uint256 d) + private + pure + returns (uint256 r_hi, uint256 r_lo, uint256 rem) + { assembly ("memory-safe") { // Get the remainder [n_hi n_lo] % d (< 2²⁵⁶ - 1) // 2**256 % d = -d % 2**256 % d -- https://2π.com/17/512-bit-division/ - let rem := mulmod(x_hi, sub(0x00, d), d) + rem := mulmod(x_hi, sub(0x00, d), d) rem := addmod(x_lo, rem, d) r_hi := sub(x_hi, gt(rem, x_lo)) @@ -598,7 +624,7 @@ library Lib512MathArithmetic { function _roundDown(uint256 x_hi, uint256 x_lo, uint256 d_hi, uint256 d_lo) private view - returns (uint256 r_hi, uint256 r_lo) + returns (uint256 r_hi, uint256 r_lo, uint256 rem_hi, uint256 rem_lo) { uint512 r; assembly ("memory-safe") { @@ -632,7 +658,7 @@ library Lib512MathArithmetic { // to check for failure. pop(staticcall(gas(), 0x05, r, 0x100, r, 0x40)) } - (uint256 rem_hi, uint256 rem_lo) = r.into(); + (rem_hi, rem_lo) = r.into(); // Round down by subtracting the remainder from the numerator (r_hi, r_lo) = _sub(x_hi, x_lo, rem_hi, rem_lo); } @@ -754,7 +780,7 @@ library Lib512MathArithmetic { function _div(uint256 n_hi, uint256 n_lo, uint256 d) private pure returns (uint256) { // Round the numerator down to a multiple of the denominator. This makes // the division exact without affecting the result. - (n_hi, n_lo) = _roundDown(n_hi, n_lo, d); + (n_hi, n_lo,) = _roundDown(n_hi, n_lo, d); // Make `d` odd so that it has a multiplicative inverse mod 2²⁵⁶. // After this we can discard `n_hi` because our result is only 256 bits @@ -773,6 +799,33 @@ library Lib512MathArithmetic { } } + function _divUp(uint256 n_hi, uint256 n_lo, uint256 d) private pure returns (uint256) { + // Round the numerator down to a multiple of the denominator. This makes + // the division exact without affecting the result. 
Store the remainder + // for later to determine whether we must increment the result in order + // to round up. + uint256 rem; + (n_hi, n_lo, rem) = _roundDown(n_hi, n_lo, d); + + // Make `d` odd so that it has a multiplicative inverse mod 2²⁵⁶. + // After this we can discard `n_hi` because our result is only 256 bits + (n_lo, d) = _toOdd256(n_hi, n_lo, d); + + // We perform division by multiplying by the multiplicative inverse of + // the denominator mod 2²⁵⁶. Since `d` is odd, this inverse + // exists. Compute that inverse + d = _invert256(d); + + unchecked { + // Because the division is now exact (we rounded `n` down to a + // multiple of `d`), we perform it by multiplying with the modular + // inverse of the denominator. This is the floor of the division, + // mod 2²⁵⁶. To obtain the ceiling, we conditionally add 1 if the + // remainder was nonzero. + return (n_lo * d).unsafeInc(0 < rem); + } + } + function div(uint512 n, uint256 d) internal pure returns (uint256) { if (d == 0) { Panic.panic(Panic.DIVISION_BY_ZERO); @@ -786,6 +839,19 @@ library Lib512MathArithmetic { return _div(n_hi, n_lo, d); } + function divUp(uint512 n, uint256 d) internal pure returns (uint256) { + if (d == 0) { + Panic.panic(Panic.DIVISION_BY_ZERO); + } + + (uint256 n_hi, uint256 n_lo) = n.into(); + if (n_hi == 0) { + return n_lo.unsafeDivUp(d); + } + + return _divUp(n_hi, n_lo, d); + } + function _gt(uint256 x_hi, uint256 x_lo, uint256 y_hi, uint256 y_lo) private pure returns (bool r) { assembly ("memory-safe") { r := or(gt(x_hi, y_hi), and(eq(x_hi, y_hi), gt(x_lo, y_lo))) @@ -808,7 +874,7 @@ library Lib512MathArithmetic { // Round the numerator down to a multiple of the denominator. This makes // the division exact without affecting the result. - (n_hi, n_lo) = _roundDown(n_hi, n_lo, d_hi, d_lo); + (n_hi, n_lo,,) = _roundDown(n_hi, n_lo, d_hi, d_lo); // Make `d_lo` odd so that it has a multiplicative inverse mod 2²⁵⁶. 
// After this we can discard `n_hi` and `d_hi` because our result is @@ -828,6 +894,43 @@ library Lib512MathArithmetic { } } + function divUp(uint512 n, uint512 d) internal view returns (uint256) { + (uint256 d_hi, uint256 d_lo) = d.into(); + if (d_hi == 0) { + return divUp(n, d_lo); + } + (uint256 n_hi, uint256 n_lo) = n.into(); + if (d_lo == 0) { + return n_hi.unsafeDiv(d_hi).unsafeInc(0 < (n_lo | n_hi.unsafeMod(d_hi))); + } + + // Round the numerator down to a multiple of the denominator. This makes + // the division exact without affecting the result. Save the remainder + // for later to determine whether we need to increment to round up. + uint256 rem_hi; + uint256 rem_lo; + (n_hi, n_lo, rem_hi, rem_lo) = _roundDown(n_hi, n_lo, d_hi, d_lo); + + // Make `d_lo` odd so that it has a multiplicative inverse mod 2²⁵⁶. + // After this we can discard `n_hi` and `d_hi` because our result is + // only 256 bits + (n_lo, d_lo) = _toOdd256(n_hi, n_lo, d_hi, d_lo); + + // We perform division by multiplying by the multiplicative inverse of + // the denominator mod 2²⁵⁶. Since `d_lo` is odd, this inverse + // exists. Compute that inverse + d_lo = _invert256(d_lo); + + unchecked { + // Because the division is now exact (we rounded `n` down to a + // multiple of `d`), we perform it by multiplying with the modular + // inverse of the denominator. This is the floor of the division, + // mod 2²⁵⁶. To obtain the ceiling, we conditionally add 1 if the + // remainder was nonzero. + return (n_lo * d_lo).unsafeInc(0 < (rem_hi | rem_lo)); + } + } + function odiv(uint512 r, uint512 x, uint256 y) internal pure returns (uint512) { if (y == 0) { Panic.panic(Panic.DIVISION_BY_ZERO); @@ -847,7 +950,7 @@ library Lib512MathArithmetic { // Round the numerator down to a multiple of the denominator. This makes // the division exact without affecting the result. 
- (x_hi, x_lo) = _roundDown(x_hi, x_lo, y); + (x_hi, x_lo,) = _roundDown(x_hi, x_lo, y); // Make `y` odd so that it has a multiplicative inverse mod 2²⁵⁶. After // this we can discard `x_hi` because we have already obtained the upper @@ -875,6 +978,58 @@ library Lib512MathArithmetic { return odiv(r, r, y); } + function odivUp(uint512 r, uint512 x, uint256 y) internal pure returns (uint512) { + if (y == 0) { + Panic.panic(Panic.DIVISION_BY_ZERO); + } + + (uint256 x_hi, uint256 x_lo) = x.into(); + if (x_hi == 0) { + return r.from(0, x_lo.unsafeDivUp(y)); + } + + // The upper word of the quotient is straightforward. We can use + // "normal" division to obtain it. The remainder after that division + // must be carried forward to the later steps, however, because the next + // operation we perform is a `mulmod` of `x_hi` with `y`, there's no + // need to reduce `x_hi` mod `y` as would be ordinarily expected. + uint256 r_hi = x_hi.unsafeDiv(y); + + // Round the numerator down to a multiple of the denominator. This makes + // the division exact without affecting the result. Save the remainder + // for later to determine whether we need to increment to round up. + uint256 rem; + (x_hi, x_lo, rem) = _roundDown(x_hi, x_lo, y); + + // Make `y` odd so that it has a multiplicative inverse mod 2²⁵⁶. After + // this we can discard `x_hi` because we have already obtained the upper + // word. + (x_lo, y) = _toOdd256(x_hi, x_lo, y); + + // The lower word of the quotient is obtained from division by + // multiplying by the multiplicative inverse of the denominator mod + // 2²⁵⁶. Since `y` is odd, this inverse exists. Compute that inverse + y = _invert256(y); + + uint256 r_lo; + unchecked { + // Because the division is now exact (we rounded `x` down to a + // multiple of the original `y`), we perform it by multiplying with + // the modular inverse of the denominator. This is the floor of the + // division, mod 2²⁵⁶. 
+ r_lo = x_lo * y; + } + // To obtain the ceiling, we conditionally add 1 if the remainder was + // nonzero. + (r_hi, r_lo) = _add(r_hi, r_lo, (0 < rem).toUint()); + + return r.from(r_hi, r_lo); + } + + function idivUp(uint512 r, uint256 y) internal pure returns (uint512) { + return odivUp(r, r, y); + } + function odiv(uint512 r, uint512 x, uint512 y) internal view returns (uint512) { (uint256 y_hi, uint256 y_lo) = y.into(); if (y_hi == 0) { @@ -891,7 +1046,7 @@ library Lib512MathArithmetic { // Round the numerator down to a multiple of the denominator. This makes // the division exact without affecting the result. - (x_hi, x_lo) = _roundDown(x_hi, x_lo, y_hi, y_lo); + (x_hi, x_lo,,) = _roundDown(x_hi, x_lo, y_hi, y_lo); // Make `y` odd so that it has a multiplicative inverse mod 2⁵¹² (x_hi, x_lo, y_hi, y_lo) = _toOdd512(x_hi, x_lo, y_hi, y_lo); @@ -916,17 +1071,62 @@ library Lib512MathArithmetic { return odiv(r, y, r); } + function odivUp(uint512 r, uint512 x, uint512 y) internal view returns (uint512) { + (uint256 y_hi, uint256 y_lo) = y.into(); + if (y_hi == 0) { + return odivUp(r, x, y_lo); + } + (uint256 x_hi, uint256 x_lo) = x.into(); + if (y_lo == 0) { + (uint256 r_hi_, uint256 r_lo_) = _add(0, x_hi.unsafeDiv(y_hi), (0 < (x_lo | x_hi.unsafeMod(y_hi))).toUint()); + return r.from(r_hi_, r_lo_); + } + + // Round the numerator down to a multiple of the denominator. This makes + // the division exact without affecting the result. Save the remainder + // for later to determine whether we need to increment to round up. + uint256 rem_hi; + uint256 rem_lo; + (x_hi, x_lo, rem_hi, rem_lo) = _roundDown(x_hi, x_lo, y_hi, y_lo); + + // Make `y` odd so that it has a multiplicative inverse mod 2⁵¹² + (x_hi, x_lo, y_hi, y_lo) = _toOdd512(x_hi, x_lo, y_hi, y_lo); + + // We perform division by multiplying by the multiplicative inverse of + // the denominator mod 2⁵¹². Since `y` is odd, this inverse + // exists. 
Compute that inverse + (y_hi, y_lo) = _invert512(y_hi, y_lo); + + // Because the division is now exact (we rounded `x` down to a multiple + // of `y`), we perform it by multiplying with the modular inverse of the + // denominator. This is the floor of the division. + (uint256 r_hi, uint256 r_lo) = _mul(x_hi, x_lo, y_hi, y_lo); + + // To obtain the ceiling, we conditionally add 1 if the remainder was + // nonzero. + (r_hi, r_lo) = _add(r_hi, r_lo, (0 < (rem_hi | rem_lo)).toUint()); + + return r.from(r_hi, r_lo); + } + + function idivUp(uint512 r, uint512 y) internal view returns (uint512) { + return odivUp(r, r, y); + } + + function irdivUp(uint512 r, uint512 y) internal view returns (uint512) { + return odivUp(r, y, r); + } + function _gt(uint256 x_ex, uint256 x_hi, uint256 x_lo, uint256 y_ex, uint256 y_hi, uint256 y_lo) private pure returns (bool r) { assembly ("memory-safe") { - r := - or( - or(gt(x_ex, y_ex), and(eq(x_ex, y_ex), gt(x_hi, y_hi))), - and(and(eq(x_ex, y_ex), eq(x_hi, y_hi)), gt(x_lo, y_lo)) - ) + r := or( + or(gt(x_ex, y_ex), and(eq(x_ex, y_ex), gt(x_hi, y_hi))), + and(and(eq(x_ex, y_ex), eq(x_hi, y_hi)), gt(x_lo, y_lo)) + ) } } @@ -963,16 +1163,13 @@ library Lib512MathArithmetic { // y is 4 limbs, x is 4 limbs, q is 1 limb // Normalize. Ensure the uppermost limb of y ≥ 2¹²⁷ (equivalently - // y_hi >= 2**255). This is step D1 of Algorithm D - // The author's copy of TAOCP (3rd edition) states to set `d = (2 ** - // 128 - 1) // y_hi`, however this is incorrect. Setting `d` in this - // fashion may result in overflow in the subsequent `_mul`. Setting - // `d` as implemented below still satisfies the postcondition (`y_hi - // >> 128 >= 1 << 127`) but never results in overflow. - uint256 d = uint256(1 << 128).unsafeDiv((y_hi >> 128).unsafeInc()); + // y_hi >= 2**255). This is step D1 of Algorithm D. We use `CLZ` to + // find the shift amount, then shift both `x` and `y` left. This is + // more gas-efficient than multiplication-based normalization. 
+ uint256 s = y_hi.clz(); uint256 x_ex; - (x_ex, x_hi, x_lo) = _mul768(x_hi, x_lo, d); - (y_hi, y_lo) = _mul(y_hi, y_lo, d); + (x_ex, x_hi, x_lo) = _shl256(x_hi, x_lo, s); + (, y_hi, y_lo) = _shl256(y_hi, y_lo, s); // `n_approx` is the 2 most-significant limbs of x, after // normalization @@ -1006,9 +1203,9 @@ library Lib512MathArithmetic { // y is 3 limbs // Normalize. Ensure the most significant limb of y ≥ 2¹²⁷ (step D1) - // See above comment about the error in TAOCP. - uint256 d = uint256(1 << 128).unsafeDiv(y_hi.unsafeInc()); - (y_hi, y_lo) = _mul(y_hi, y_lo, d); + // We use `CLZ` to find the shift amount for normalization + uint256 s = (y_hi << 128).clz(); + (, y_hi, y_lo) = _shl256(y_hi, y_lo, s); // `y_next` is the second-most-significant, nonzero, normalized limb // of y uint256 y_next = y_lo >> 128; @@ -1021,7 +1218,7 @@ library Lib512MathArithmetic { // Finish normalizing (step D1) uint256 x_ex; - (x_ex, x_hi, x_lo) = _mul768(x_hi, x_lo, d); + (x_ex, x_hi, x_lo) = _shl256(x_hi, x_lo, s); uint256 n_approx = (x_ex << 128) | (x_hi >> 128); // As before, `q_hat` is the most significant limb of the @@ -1076,7 +1273,7 @@ library Lib512MathArithmetic { // x is 3 limbs, q is 1 limb // Finish normalizing (step D1) - (x_hi, x_lo) = _mul(x_hi, x_lo, d); + (, x_hi, x_lo) = _shl256(x_hi, x_lo, s); // `q` is the most significant (and only) limb of the quotient // and too high by at most 3 (step D3) @@ -1119,6 +1316,20 @@ library Lib512MathArithmetic { } } + function _shl(uint256 x_lo, uint256 s) private pure returns (uint256 r_hi, uint256 r_lo) { + (r_hi, r_lo) = _shl256(x_lo, s); + unchecked { + r_hi |= x_lo << s - 256; + } + } + + function _shl(uint256 x_hi, uint256 x_lo, uint256 s) private pure returns (uint256 r_hi, uint256 r_lo) { + (, r_hi, r_lo) = _shl256(x_hi, x_lo, s); + unchecked { + r_hi |= x_lo << s - 256; + } + } + function _shr256(uint256 x_hi, uint256 x_lo, uint256 s) private pure returns (uint256 r_hi, uint256 r_lo) { assembly ("memory-safe") { r_hi 
:= shr(s, x_hi) @@ -1371,7 +1582,68 @@ library Lib512MathArithmetic { // At this point, we know that both `x` and `y` are fully represented by // 2 words. There is no simpler representation for the problem. We must // use Knuth's Algorithm D. - return _algorithmD(x_hi, x_lo, y_hi, y_lo); + uint256 q = _algorithmD(x_hi, x_lo, y_hi, y_lo); + return q; + } + + function divUpAlt(uint512 x, uint512 y) internal pure returns (uint256) { + (uint256 y_hi, uint256 y_lo) = y.into(); + if (y_hi == 0) { + return divUp(x, y_lo); + } + (uint256 x_hi, uint256 x_lo) = x.into(); + if (y_lo == 0) { + return x_hi.unsafeDiv(y_hi).unsafeInc(0 < (x_lo | x_hi.unsafeMod(y_hi))); + } + if (_gt(y_hi, y_lo, x_hi, x_lo)) { + return (0 < (x_hi | x_lo)).toUint(); + } + + // At this point, we know that both `x` and `y` are fully represented by + // 2 words. There is no simpler representation for the problem. We must + // use Knuth's Algorithm D. + uint256 q = _algorithmD(x_hi, x_lo, y_hi, y_lo); + + // If the division was not exact, then we must round up. This is more + // efficient than explicitly computing whether the remainder is nonzero + // inside `_algorithmD`. + (uint256 prod_hi, uint256 prod_lo) = _mul(y_hi, y_lo, q); + return q.unsafeInc(0 < (prod_hi ^ x_hi) | (prod_lo ^ x_lo)); + } + + function odivUpAlt(uint512 r, uint512 x, uint512 y) internal pure returns (uint512) { + (uint256 y_hi, uint256 y_lo) = y.into(); + if (y_hi == 0) { + return odivUp(r, x, y_lo); + } + (uint256 x_hi, uint256 x_lo) = x.into(); + if (y_lo == 0) { + (uint256 r_hi_, uint256 r_lo_) = _add(0, x_hi.unsafeDiv(y_hi), (0 < (x_lo | x_hi.unsafeMod(y_hi))).toUint()); + return r.from(r_hi_, r_lo_); + } + if (_gt(y_hi, y_lo, x_hi, x_lo)) { + return r.from(0, (0 < (x_hi | x_lo)).toUint()); + } + + // At this point, we know that both `x` and `y` are fully represented by + // 2 words. There is no simpler representation for the problem. We must + // use Knuth's Algorithm D. 
+ uint256 q = _algorithmD(x_hi, x_lo, y_hi, y_lo); + + // If the division was not exact, then we must round up. This is more + // efficient than explicitly computing whether the remainder is nonzero + // inside `_algorithmD`. + (uint256 prod_hi, uint256 prod_lo) = _mul(y_hi, y_lo, q); + (uint256 r_hi, uint256 r_lo) = _add(0, q, (0 < (prod_hi ^ x_hi) | (prod_lo ^ x_lo)).toUint()); + return r.from(r_hi, r_lo); + } + + function idivUpAlt(uint512 r, uint512 y) internal pure returns (uint512) { + return odivUpAlt(r, r, y); + } + + function irdivUpAlt(uint512 r, uint512 y) internal pure returns (uint512) { + return odivUpAlt(r, y, r); } function omodAlt(uint512 r, uint512 x, uint512 y) internal pure returns (uint512) { @@ -1414,13 +1686,7 @@ library Lib512MathArithmetic { } // gas benchmark 2025/09/20: ~1425 gas - function sqrt(uint512 x) internal pure returns (uint256 r) { - (uint256 x_hi, uint256 x_lo) = x.into(); - - if (x_hi == 0) { - return x_lo.sqrt(); - } - + function _sqrt(uint256 x_hi, uint256 x_lo) private pure returns (uint256 r) { /// Our general approach here is to compute the inverse of the square root of the argument /// using Newton-Raphson iterations. Then we combine (multiply) this inverse square root /// approximation with the argument to approximate the square root of the argument. After @@ -1516,9 +1782,9 @@ library Lib512MathArithmetic { // Generally speaking, for relatively smaller `e` (lower values of `x`) and for // relatively larger `M`, we can skip the 5th N-R iteration. The constant `95` is // derived by extensive fuzzing. Attempting a higher-order approximation of the - // relationship between `M` and `invE` consumes, on average, more gas. The correct - // bits that this iteration would obtain are shifted away during the denormalization - // step. This branch is net gas-optimizing. + // relationship between `M` and `invE` consumes, on average, more gas. 
When this + // branch is not taken, the correct bits that this iteration would obtain are + // shifted away during the denormalization step. This branch is net gas-optimizing. uint256 Y2 = Y * Y; // scale: 2²⁵⁴ uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2²⁵⁴ uint256 T = 1.5 * 2 ** 254 - MY2; // scale: 2²⁵⁴ @@ -1549,7 +1815,7 @@ library Lib512MathArithmetic { /// `r0` is only an approximation of √x, so we perform a single Babylonian step to fully /// converge on ⌊√x⌋ or ⌈√x⌉. The Babylonian step is: - /// r1 = ⌊(r0 + ⌊x/r0⌋) / 2⌋ + /// r = ⌊(r0 + ⌊x/r0⌋) / 2⌋ // Rather than use the more-expensive division routine that returns a 512-bit result, // because the value the upper word of the quotient can take is highly constrained, we // can compute the quotient mod 2²⁵⁶ and recover the high word separately. Although @@ -1561,18 +1827,100 @@ library Lib512MathArithmetic { uint256 q_lo = _div(x_hi, x_lo, r0); uint256 q_hi = (r0 <= x_hi).toUint(); (uint256 s_hi, uint256 s_lo) = _add(q_hi, q_lo, r0); - // `oflo` here is either 0 or 1. When `oflo == 1`, `r1 == 0`, and the correct value for - // `r1` is `type(uint256).max`. - (uint256 oflo, uint256 r1) = _shr256(s_hi, s_lo, 1); - r1 -= oflo; // underflow is desired + // `oflo` here is either 0 or 1. When `oflo == 1`, `r == 0`, and the correct value for + // `r` is `type(uint256).max`. + uint256 oflo; + (oflo, r) = _shr256(s_hi, s_lo, 1); + r -= oflo; // underflow is desired + } + } + + function sqrt(uint512 x) internal pure returns (uint256) { + (uint256 x_hi, uint256 x_lo) = x.into(); + + if (x_hi == 0) { + return x_lo.sqrt(); + } + + uint256 r = _sqrt(x_hi, x_lo); + + // Because the Babylonian step can give ⌈√x⌉ if x+1 is a perfect square, we have to + // check whether we've overstepped by 1 and clamp as appropriate. 
ref: + // https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division + (uint256 r2_hi, uint256 r2_lo) = _mul(r, r); + return r.unsafeDec(_gt(r2_hi, r2_lo, x_hi, x_lo)); + } + + function osqrtUp(uint512 r, uint512 x) internal pure returns (uint512) { + (uint256 x_hi, uint256 x_lo) = x.into(); + + if (x_hi == 0) { + return r.from(0, x_lo.sqrtUp()); + } + + uint256 r_lo = _sqrt(x_hi, x_lo); + + // The Babylonian step can give ⌈√x⌉ if x+1 is a perfect square. This is + // fine. If the Babylonian step gave ⌊√x⌋ != √x, we have to round up. + (uint256 r2_hi, uint256 r2_lo) = _mul(r_lo, r_lo); + uint256 r_hi; + (r_hi, r_lo) = _add(0, r_lo, _gt(x_hi, x_lo, r2_hi, r2_lo).toUint()); + return r.from(r_hi, r_lo); + } + + function isqrtUp(uint512 r) internal pure returns (uint512) { + return osqrtUp(r, r); + } + + function oshr(uint512 r, uint512 x, uint256 s) internal pure returns (uint512) { + (uint256 x_hi, uint256 x_lo) = x.into(); + (uint256 r_hi, uint256 r_lo) = _shr(x_hi, x_lo, s); + return r.from(r_hi, r_lo); + } + + function ishr(uint512 r, uint256 s) internal pure returns (uint512) { + return oshr(r, r, s); + } + + function _shrUp(uint256 x_hi, uint256 x_lo, uint256 s) internal pure returns (uint256 r_hi, uint256 r_lo) { + assembly ("memory-safe") { + let neg_s := sub(0x100, s) + let s_256 := sub(s, 0x100) + + // compute `(x_hi, x_lo) >> s`, retaining intermediate values + let x_lo_shr := shr(s, x_lo) + let x_hi_shr := shr(s_256, x_hi) + r_hi := shr(s, x_hi) + r_lo := or(or(shl(neg_s, x_hi), x_lo_shr), x_hi_shr) + + // detect if nonzero bits were truncated + let inc := lt(0x00, or(xor(x_lo, shl(s, x_lo_shr)), mul(xor(x_hi, shl(s_256, x_hi_shr)), lt(0x100, neg_s)))) - /// Because the Babylonian step can give ⌈√x⌉ if x+1 is a perfect square, we have to - /// check whether we've overstepped by 1 and clamp as appropriate. 
ref: - /// https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division - (uint256 r2_hi, uint256 r2_lo) = _mul(r1, r1); - r = r1.unsafeDec(_gt(r2_hi, r2_lo, x_hi, x_lo)); + // conditionally increment the result + r_lo := add(inc, r_lo) + r_hi := add(lt(r_lo, inc), r_hi) } } + + function oshrUp(uint512 r, uint512 x, uint256 s) internal pure returns (uint512) { + (uint256 x_hi, uint256 x_lo) = x.into(); + (uint256 r_hi, uint256 r_lo) = _shrUp(x_hi, x_lo, s); + return r.from(r_hi, r_lo); + } + + function ishrUp(uint512 r, uint256 s) internal pure returns (uint512) { + return oshrUp(r, r, s); + } + + function oshl(uint512 r, uint512 x, uint256 s) internal pure returns (uint512) { + (uint256 x_hi, uint256 x_lo) = x.into(); + (uint256 r_hi, uint256 r_lo) = _shl(x_hi, x_lo, s); + return r.from(r_hi, r_lo); + } + + function ishl(uint512 r, uint256 s) internal pure returns (uint512) { + return oshl(r, r, s); + } } using Lib512MathArithmetic for uint512 global; @@ -1587,10 +1935,10 @@ library Lib512MathUserDefinedHelpers { } } - function smuggleToPure(function (uint512, uint512, uint512) internal view returns (uint512) f) + function smuggleToPure(function(uint512, uint512, uint512) internal view returns (uint512) f) internal pure - returns (function (uint512, uint512, uint512) internal pure returns (uint512) r) + returns (function(uint512, uint512, uint512) internal pure returns (uint512) r) { assembly ("memory-safe") { r := f @@ -1631,7 +1979,7 @@ function __div(uint512 x, uint512 y) pure returns (uint512 r) { Lib512MathUserDefinedHelpers.smuggleToPure(Lib512MathUserDefinedHelpers.odiv)(r, x, y); } -using {__add as +, __sub as -, __mul as *, __mod as %, __div as / } for uint512 global; +using {__add as +, __sub as -, __mul as *, __mod as %, __div as /} for uint512 global; struct uint512_external { uint256 hi; diff --git a/src/vendor/Sqrt.sol b/src/vendor/Sqrt.sol index ea757859a..f220d8162 100644 --- a/src/vendor/Sqrt.sol +++ b/src/vendor/Sqrt.sol @@ -1,46 
+1,22 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.25; -// @author Modified from Solady by Vectorized https://github.com/Vectorized/solady/blob/701406e8126cfed931645727b274df303fbcd94d/src/utils/FixedPointMathLib.sol#L774-L826 under the MIT license. +// @author Modified from Solady by Vectorized and Akshay Tarpara https://github.com/Vectorized/solady/blob/1198c9f70b30d472a7d0ec021bec080622191b03/src/utils/clz/FixedPointMathLib.sol#L769-L797 under the MIT license. library Sqrt { - /// @dev Returns the square root of `x`, rounded down. + /// @dev Returns the square root of `x`, rounded maybe-up maybe-down. For expert use only. function _sqrt(uint256 x) private pure returns (uint256 z) { assembly ("memory-safe") { - // `floor(sqrt(2**15)) = 181`. `sqrt(2**15) - 181 = 2.84`. - z := 181 // The "correct" value is 1, but this saves a multiplication later. - - // This segment is to get a reasonable initial estimate for the Babylonian method. With a bad - // start, the correct # of bits increases ~linearly each iteration instead of ~quadratically. - - // Let `y = x / 2**r`. We check `y >= 2**(k + 8)` - // but shift right by `k` bits to ensure that if `x >= 256`, then `y >= 256`. - let r := shl(7, lt(0xffffffffffffffffffffffffffffffffff, x)) - r := or(r, shl(6, lt(0xffffffffffffffffff, shr(r, x)))) - r := or(r, shl(5, lt(0xffffffffff, shr(r, x)))) - r := or(r, shl(4, lt(0xffffff, shr(r, x)))) - z := shl(shr(1, r), z) - - // Goal was to get `z*z*y` within a small factor of `x`. More iterations could - // get y in a tighter range. Currently, we will have y in `[256, 256*(2**16))`. - // We ensured `y >= 256` so that the relative difference between `y` and `y+1` is small. - // That's not possible if `x < 256` but we can just verify those cases exhaustively. - - // Now, `z*z*y <= x < z*z*(y+1)`, and `y <= 2**(16+8)`, and either `y >= 256`, or `x < 256`. - // Correctness can be checked exhaustively for `x < 256`, so we assume `y >= 256`. 
- // Then `z*sqrt(y)` is within `sqrt(257)/sqrt(256)` of `sqrt(x)`, or about 20bps. - - // For `s` in the range `[1/256, 256]`, the estimate `f(s) = (181/1024) * (s+1)` - // is in the range `(1/2.84 * sqrt(s), 2.84 * sqrt(s))`, - // with largest error when `s = 1` and when `s = 256` or `1/256`. - - // Since `y` is in `[256, 256*(2**16))`, let `a = y/65536`, so that `a` is in `[1/256, 256)`. - // Then we can estimate `sqrt(y)` using - // `sqrt(65536) * 181/1024 * (a + 1) = 181/4 * (y + 65536)/65536 = 181 * (y + 65536)/2**18`. - - // There is no overflow risk here since `y < 2**136` after the first branch above. - z := shr(18, mul(z, add(shr(r, x), 65536))) // A `mul()` is saved from starting `z` at 181. - - // Given the worst case multiplicative error of 2.84 above, 7 iterations should be enough. + // Step 1: Get the bit position of the most significant bit + // n = floor(log2(x)) + // For x ≈ 2^n, we know sqrt(x) ≈ 2^(n/2) + // We use (n+1)/2 instead of n/2 to round up slightly + // This gives a better initial approximation + // + // Formula: z = 2^((n+1)/2) = 2^(floor((n+1)/2)) + // Implemented as: z = 1 << ((n+1) >> 1) + z := shl(shr(1, sub(256, clz(x))), 1) + + /// (x/z + z) / 2 z := shr(1, add(z, div(x, z))) z := shr(1, add(z, div(x, z))) z := shr(1, add(z, div(x, z))) @@ -51,6 +27,7 @@ library Sqrt { } } + /// @dev Returns the square root of `x`, rounded down. function sqrt(uint256 x) internal pure returns (uint256 z) { z = _sqrt(x); assembly ("memory-safe") { @@ -61,9 +38,14 @@ library Sqrt { } } + /// @dev Returns the square root of `x`, rounded up. function sqrtUp(uint256 x) internal pure returns (uint256 z) { z = _sqrt(x); assembly ("memory-safe") { + // If `x == type(uint256).max`, then according to its contract `_sqrt(x)` could return + // `2**128`. This would cause `mul(z, z)` to overflow and `sqrtUp` to return `2**128 + + // 1`. However, for this specific input in practice, `_sqrt` returns `2**128 - 1`, + // defusing this scenario. 
z := add(lt(mul(z, z), x), z) } } diff --git a/test/0.8.25/512Math.t.sol b/test/0.8.25/512Math.t.sol index c7cbb8ded..d202c4518 100644 --- a/test/0.8.25/512Math.t.sol +++ b/test/0.8.25/512Math.t.sol @@ -238,4 +238,208 @@ contract Lib512MathTest is Test { assertTrue((r2_hi > x_hi) || (r2_hi == x_hi && r2_lo > x_lo), "sqrt too low"); } } + + function test512Math_divUpAlt(uint256 x_hi, uint256 x_lo, uint256 y_hi, uint256 y_lo) external view { + vm.assume(y_hi != 0); + + uint512 x = alloc().from(x_hi, x_lo); + uint512 y = alloc().from(y_hi, y_lo); + + uint256 ceil_q = x.divUpAlt(y); + uint256 floor_q = x.div(y); + + (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y_lo, y_hi, floor_q, 0); + uint512 e = alloc().from(e_hi, e_lo); + + assertTrue( + ceil_q == floor_q || (floor_q == type(uint256).max && ceil_q == 0) || (e != x && ceil_q == floor_q + 1) + ); + } + + function test512Math_odivUpAlt(uint256 x_hi, uint256 x_lo, uint256 y_hi, uint256 y_lo) external view { + vm.assume(y_hi != 0); + + uint512 x = alloc().from(x_hi, x_lo); + uint512 y = alloc().from(y_hi, y_lo); + + uint512 ceil_q = alloc().odivUpAlt(x, y); + uint512 floor_q = alloc().odiv(x, y); + + (uint256 floor_q_hi, uint256 floor_q_lo) = floor_q.into(); + (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y_lo, y_hi, floor_q_lo, floor_q_hi); + uint512 e = alloc().from(e_hi, e_lo); + + assertTrue( + ceil_q == floor_q + || (floor_q == tmp().from(type(uint256).max, type(uint256).max) && ceil_q == tmp().from(0, 0)) + || (e != x && ceil_q == tmp().oadd(floor_q, 1)) + ); + } + + function test512Math_divUpForeign(uint256 x_hi, uint256 x_lo, uint256 y) external pure { + vm.assume(y != 0); + + uint512 x = alloc().from(x_hi, x_lo); + + uint256 ceil_q = x.divUp(y); + uint256 floor_q = x.div(y); + + (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y, floor_q); + uint512 e = alloc().from(e_hi, e_lo); + + assertTrue( + ceil_q == floor_q || (floor_q == type(uint256).max && ceil_q == 0) || (e != x && ceil_q == floor_q + 1) + ); + } 
+    /// Property test: with `y_hi != 0`, compares `x.divUp(y)` (uint256 result) against `x.div(y)`.
+    function test512Math_divUpNative(uint256 x_hi, uint256 x_lo, uint256 y_hi, uint256 y_lo) external view {
+        vm.assume(y_hi != 0);
+
+        uint512 x = alloc().from(x_hi, x_lo);
+        uint512 y = alloc().from(y_hi, y_lo);
+
+        uint256 ceil_q = x.divUp(y);
+        uint256 floor_q = x.div(y);
+        // `e` is whatever SlowMath.fullMul yields for (y, floor_q); presumably the 512-bit product -- confirm.
+        (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y_lo, y_hi, floor_q, 0);
+        uint512 e = alloc().from(e_hi, e_lo);
+        // NOTE(review): the first disjunct accepts ceil_q == floor_q even when e != x; confirm that is intended.
+        assertTrue(
+            ceil_q == floor_q || (floor_q == type(uint256).max && ceil_q == 0) || (e != x && ceil_q == floor_q + 1)
+        );
+    }
+    /// Property test: for a nonzero uint256 `y`, compares the uint512 result of `odivUp(x, y)` against `odiv(x, y)`.
+    function test512Math_odivUpForeign(uint256 x_hi, uint256 x_lo, uint256 y) external pure {
+        vm.assume(y != 0);
+
+        uint512 x = alloc().from(x_hi, x_lo);
+
+        uint512 ceil_q = alloc().odivUp(x, y);
+        uint512 floor_q = alloc().odiv(x, y);
+
+        (uint256 floor_q_hi, uint256 floor_q_lo) = floor_q.into();
+        (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y, 0, floor_q_lo, floor_q_hi);
+        uint512 e = alloc().from(e_hi, e_lo);
+        // Accepted: equal; floor all-ones with ceil zero; or e != x with ceil == floor + 1 (same leniency as above).
+        assertTrue(
+            ceil_q == floor_q
+                || (floor_q == tmp().from(type(uint256).max, type(uint256).max) && ceil_q == tmp().from(0, 0))
+                || (e != x && ceil_q == tmp().oadd(floor_q, 1))
+        );
+    }
+    /// Property test: for a nonzero uint256 `y`, compares `tmp().from(x).idivUp(y)` against `tmp().from(x).idiv(y)`.
+    function test512Math_idivUpForeign(uint256 x_hi, uint256 x_lo, uint256 y) external pure {
+        vm.assume(y != 0);
+
+        uint512 x = alloc().from(x_hi, x_lo);
+        // Both results are unpacked into stack words right away and re-wrapped via alloc() further down.
+        (uint256 ceil_q_hi, uint256 ceil_q_lo) = tmp().from(x).idivUp(y).into();
+        (uint256 floor_q_hi, uint256 floor_q_lo) = tmp().from(x).idiv(y).into();
+
+        (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y, 0, floor_q_lo, floor_q_hi);
+        uint512 e = alloc().from(e_hi, e_lo);
+
+        uint512 ceil_q = alloc().from(ceil_q_hi, ceil_q_lo);
+        uint512 floor_q = alloc().from(floor_q_hi, floor_q_lo);
+        // Accepted: equal; floor all-ones with ceil zero; or e != x with ceil == floor + 1 (same leniency as above).
+        assertTrue(
+            ceil_q == floor_q
+                || (floor_q == tmp().from(type(uint256).max, type(uint256).max) && ceil_q == tmp().from(0, 0))
+                || (e != x && ceil_q == tmp().oadd(floor_q, 1))
+        );
+    }
+    /// Property test: with `y_hi != 0`, compares the uint512 result of `odivUp(x, y)` against `odiv(x, y)`.
+    function test512Math_odivUpNative(uint256 x_hi, uint256 x_lo, uint256 y_hi, uint256 y_lo) external view {
+        vm.assume(y_hi != 0);
+
+        uint512 x = alloc().from(x_hi, x_lo);
+        uint512 y = alloc().from(y_hi, y_lo);
+
+        uint512 ceil_q = alloc().odivUp(x, y);
+        uint512 floor_q = alloc().odiv(x, y);
+
+        (uint256 floor_q_hi, uint256 floor_q_lo) = floor_q.into();
+        (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y_lo, y_hi, floor_q_lo, floor_q_hi);
+        uint512 e = alloc().from(e_hi, e_lo);
+        // Accepted: equal; floor all-ones with ceil zero; or e != x with ceil == floor + 1 (same leniency as above).
+        assertTrue(
+            ceil_q == floor_q
+                || (floor_q == tmp().from(type(uint256).max, type(uint256).max) && ceil_q == tmp().from(0, 0))
+                || (e != x && ceil_q == tmp().oadd(floor_q, 1))
+        );
+    }
+    /// Property test: with `y_hi != 0`, compares `tmp().from(x).idivUp(y)` against `tmp().from(x).idiv(y)`.
+    function test512Math_idivUpNative(uint256 x_hi, uint256 x_lo, uint256 y_hi, uint256 y_lo) external view {
+        vm.assume(y_hi != 0);
+
+        uint512 x = alloc().from(x_hi, x_lo);
+        uint512 y = alloc().from(y_hi, y_lo);
+        // Both results are unpacked into stack words right away and re-wrapped via alloc() further down.
+        (uint256 ceil_q_hi, uint256 ceil_q_lo) = tmp().from(x).idivUp(y).into();
+        (uint256 floor_q_hi, uint256 floor_q_lo) = tmp().from(x).idiv(y).into();
+
+        (uint256 e_lo, uint256 e_hi) = SlowMath.fullMul(y_lo, y_hi, floor_q_lo, floor_q_hi);
+        uint512 e = alloc().from(e_hi, e_lo);
+
+        uint512 ceil_q = alloc().from(ceil_q_hi, ceil_q_lo);
+        uint512 floor_q = alloc().from(floor_q_hi, floor_q_lo);
+        // Accepted: equal; floor all-ones with ceil zero; or e != x with ceil == floor + 1 (same leniency as above).
+        assertTrue(
+            ceil_q == floor_q
+                || (floor_q == tmp().from(type(uint256).max, type(uint256).max) && ceil_q == tmp().from(0, 0))
+                || (e != x && ceil_q == tmp().oadd(floor_q, 1))
+        );
+    }
+    /// Property test: `osqrtUp(x)` must yield r with r^2 >= x and (r - 1)^2 < x; r == 0 and r_hi != 0 are special-cased.
+    function test512Math_osqrtUp(uint256 x_hi, uint256 x_lo) external pure {
+        uint512 x = alloc().from(x_hi, x_lo);
+        (uint256 r_hi, uint256 r_lo) = alloc().osqrtUp(x).into();
+
+        if (r_hi == 0 && r_lo == 0) {
+            // sqrtUp(0) = 0
+            assertTrue(x_hi == 0 && x_lo == 0, "sqrtUp of nonzero is zero");
+        } else if (r_hi != 0) {
+            // r >= 2^256, which means r must be exactly 2^256 (since sqrt of 512-bit is at most 2^256)
+            // r^2 = 2^512 exceeds max 512-bit value, so r^2 >= x is trivially true
+            // We only need to verify (r-1)^2 < x, i.e., (type(uint256).max)^2 < x
+            assertTrue(r_hi == 1 && r_lo == 0, "overflow result must be exactly 2^256");
+            (uint256 r_dec2_lo, uint256 r_dec2_hi) = SlowMath.fullMul(type(uint256).max, type(uint256).max);
+            assertTrue((r_dec2_hi < x_hi) || (r_dec2_hi == x_hi && r_dec2_lo < x_lo), "sqrtUp too high");
+        } else {
+            // Normal case: r fits in 256 bits
+            (uint256 r2_lo, uint256 r2_hi) = SlowMath.fullMul(r_lo, r_lo);
+            assertTrue((r2_hi > x_hi) || (r2_hi == x_hi && r2_lo >= x_lo), "sqrtUp too low");
+
+            // Check (r-1)^2 < x
+            if (r_lo == 1) {
+                // (r-1)^2 = 0, which is below x only when x != 0. NOTE(review): no assertion here or above
+                // establishes x != 0 when r == 1 (the r == 0 branch only covers r == 0) -- confirm this gap.
+            } else {
+                uint256 r_dec_lo = r_lo - 1;
+                (r2_lo, r2_hi) = SlowMath.fullMul(r_dec_lo, r_dec_lo);
+                assertTrue((r2_hi < x_hi) || (r2_hi == x_hi && r2_lo < x_lo), "sqrtUp too high");
+            }
+        }
+    }
+    /// Property test: `oshrUp(x, s)` with `s` passed through `bound(s, 0, 512)` must match `SlowMath.fullShrUp`.
+    function test512Math_oshrUp(uint256 x_hi, uint256 x_lo, uint256 s) external pure {
+        s = bound(s, 0, 512);
+
+        uint512 x = alloc().from(x_hi, x_lo);
+        (uint256 r_hi, uint256 r_lo) = tmp().oshrUp(x, s).into();
+        // The reference takes and returns words in (lo, hi) order, the reverse of `from`/`into` as used here.
+        (uint256 e_lo, uint256 e_hi) = SlowMath.fullShrUp(x_lo, x_hi, s);
+        assertEq(r_hi, e_hi);
+        assertEq(r_lo, e_lo);
+    }
+    /// Property test: `ishrUp(s)` with `s` passed through `bound(s, 0, 512)` must match `SlowMath.fullShrUp`.
+    function test512Math_ishrUp(uint256 x_hi, uint256 x_lo, uint256 s) external pure {
+        s = bound(s, 0, 512);
+
+        (uint256 r_hi, uint256 r_lo) = tmp().from(x_hi, x_lo).ishrUp(s).into();
+        // The reference takes and returns words in (lo, hi) order, the reverse of `from`/`into` as used here.
+        (uint256 e_lo, uint256 e_hi) = SlowMath.fullShrUp(x_lo, x_hi, s);
+        assertEq(r_hi, e_hi);
+        assertEq(r_lo, e_lo);
+    }
 }
diff --git a/test/0.8.25/SlowMath.sol b/test/0.8.25/SlowMath.sol
index f70604598..a06d39bfe 100644
--- a/test/0.8.25/SlowMath.sol
+++ b/test/0.8.25/SlowMath.sol
@@ -128,4 +128,44 @@ library SlowMath {
         }
         revert("Not converged");
     }
+    /// Reference right shift of the two-word value (x_hi, x_lo) by `s` bits, rounded up; returns (r_lo, r_hi).
+    function fullShrUp(uint256 x_lo, uint256 x_hi, uint256 s) internal pure returns (uint256 r_lo, uint256 r_hi) {
+        if (s == 0) {
+            return (x_lo, x_hi);
+        } else if (s >= 512) {
+            // ceil(x / 2^512) = 0 if x == 0, else 1
+            return ((x_hi | x_lo) != 0 ? 1 : 0, 0);
+        } else if (s >= 256) {
+            // floor = x_hi >> (s - 256), fits in r_lo
+            uint256 shift = s - 256;
+            r_lo = x_hi >> shift;
+            // Remainder exists if x_lo != 0 OR x_hi has any of bottom 'shift' bits set
+            bool hasRemainder = (x_lo != 0) || (shift != 0 && (x_hi & ((uint256(1) << shift) - 1)) != 0);
+            if (hasRemainder) {
+                unchecked {
+                    r_lo += 1;
+                }
+                if (r_lo == 0) {
+                    r_hi = 1; // r_lo wrapped to zero, so carry into the (so far unassigned) high word
+                }
+            }
+        } else {
+            // s < 256: use fullDiv with divisor 2^s
+            uint256 d = uint256(1) << s;
+            (r_lo, r_hi) = fullDiv(x_lo, x_hi, d);
+            // Check remainder: multiply back and compare
+            (uint256 prod_lo, uint256 prod_hi) = fullMul(r_lo, r_hi, d, 0);
+            if (prod_lo != x_lo || prod_hi != x_hi) {
+                // Has remainder, add 1
+                unchecked {
+                    r_lo += 1;
+                }
+                if (r_lo == 0) {
+                    unchecked {
+                        r_hi += 1; // propagate the carry out of the low word
+                    }
+                }
+            }
+        }
+    }
 }