001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
005 * in compliance with the License. You may obtain a copy of the License at
006 *
007 * http://www.apache.org/licenses/LICENSE-2.0
008 *
009 * Unless required by applicable law or agreed to in writing, software distributed under the License
010 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
011 * or implied. See the License for the specific language governing permissions and limitations under
012 * the License.
013 */
014
015package com.google.common.math;
016
017import static com.google.common.base.Preconditions.checkArgument;
018import static com.google.common.base.Preconditions.checkNotNull;
019import static com.google.common.math.MathPreconditions.checkNoOverflow;
020import static com.google.common.math.MathPreconditions.checkNonNegative;
021import static com.google.common.math.MathPreconditions.checkPositive;
022import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
023import static java.lang.Math.abs;
024import static java.lang.Math.min;
025import static java.math.RoundingMode.HALF_EVEN;
026import static java.math.RoundingMode.HALF_UP;
027
028import com.google.common.annotations.Beta;
029import com.google.common.annotations.GwtCompatible;
030import com.google.common.annotations.GwtIncompatible;
031import com.google.common.annotations.VisibleForTesting;
032import com.google.common.primitives.UnsignedLongs;
033import java.math.BigInteger;
034import java.math.RoundingMode;
035
036/**
037 * A class for arithmetic on values of type {@code long}. Where possible, methods are defined and
038 * named analogously to their {@code BigInteger} counterparts.
039 *
040 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
041 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
042 *
043 * <p>Similar functionality for {@code int} and for {@link BigInteger} can be found in {@link
044 * IntMath} and {@link BigIntegerMath} respectively. For other common operations on {@code long}
045 * values, see {@link com.google.common.primitives.Longs}.
046 *
047 * @author Louis Wasserman
048 * @since 11.0
049 */
050@GwtCompatible(emulated = true)
051public final class LongMath {
052  // NOTE: Whenever both tests are cheap and functional, it's faster to use &, | instead of &&, ||
053
  /**
   * The largest power of two representable as a positive signed {@code long}, namely 2^62 (2^63
   * would overflow). {@link #ceilingPowerOfTwo} rejects any argument above this value.
   */
  @VisibleForTesting static final long MAX_SIGNED_POWER_OF_TWO = 1L << (Long.SIZE - 2);
055
056  /**
057   * Returns the smallest power of two greater than or equal to {@code x}. This is equivalent to
058   * {@code checkedPow(2, log2(x, CEILING))}.
059   *
060   * @throws IllegalArgumentException if {@code x <= 0}
061   * @throws ArithmeticException of the next-higher power of two is not representable as a {@code
062   *     long}, i.e. when {@code x > 2^62}
063   * @since 20.0
064   */
065  @Beta
066  public static long ceilingPowerOfTwo(long x) {
067    checkPositive("x", x);
068    if (x > MAX_SIGNED_POWER_OF_TWO) {
069      throw new ArithmeticException("ceilingPowerOfTwo(" + x + ") is not representable as a long");
070    }
071    return 1L << -Long.numberOfLeadingZeros(x - 1);
072  }
073
074  /**
075   * Returns the largest power of two less than or equal to {@code x}. This is equivalent to {@code
076   * checkedPow(2, log2(x, FLOOR))}.
077   *
078   * @throws IllegalArgumentException if {@code x <= 0}
079   * @since 20.0
080   */
081  @Beta
082  public static long floorPowerOfTwo(long x) {
083    checkPositive("x", x);
084
085    // Long.highestOneBit was buggy on GWT.  We've fixed it, but I'm not certain when the fix will
086    // be released.
087    return 1L << ((Long.SIZE - 1) - Long.numberOfLeadingZeros(x));
088  }
089
090  /**
091   * Returns {@code true} if {@code x} represents a power of two.
092   *
093   * <p>This differs from {@code Long.bitCount(x) == 1}, because {@code
094   * Long.bitCount(Long.MIN_VALUE) == 1}, but {@link Long#MIN_VALUE} is not a power of two.
095   */
096  public static boolean isPowerOfTwo(long x) {
097    return x > 0 & (x & (x - 1)) == 0;
098  }
099
100  /**
101   * Returns 1 if {@code x < y} as unsigned longs, and 0 otherwise. Assumes that x - y fits into a
102   * signed long. The implementation is branch-free, and benchmarks suggest it is measurably faster
103   * than the straightforward ternary expression.
104   */
105  @VisibleForTesting
106  static int lessThanBranchFree(long x, long y) {
107    // Returns the sign bit of x - y.
108    return (int) (~~(x - y) >>> (Long.SIZE - 1));
109  }
110
111  /**
112   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
113   *
114   * @throws IllegalArgumentException if {@code x <= 0}
115   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
116   *     is not a power of two
117   */
118  @SuppressWarnings("fallthrough")
119  // TODO(kevinb): remove after this warning is disabled globally
120  public static int log2(long x, RoundingMode mode) {
121    checkPositive("x", x);
122    switch (mode) {
123      case UNNECESSARY:
124        checkRoundingUnnecessary(isPowerOfTwo(x));
125        // fall through
126      case DOWN:
127      case FLOOR:
128        return (Long.SIZE - 1) - Long.numberOfLeadingZeros(x);
129
130      case UP:
131      case CEILING:
132        return Long.SIZE - Long.numberOfLeadingZeros(x - 1);
133
134      case HALF_DOWN:
135      case HALF_UP:
136      case HALF_EVEN:
137        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
138        int leadingZeros = Long.numberOfLeadingZeros(x);
139        long cmp = MAX_POWER_OF_SQRT2_UNSIGNED >>> leadingZeros;
140        // floor(2^(logFloor + 0.5))
141        int logFloor = (Long.SIZE - 1) - leadingZeros;
142        return logFloor + lessThanBranchFree(cmp, x);
143
144      default:
145        throw new AssertionError("impossible");
146    }
147  }
148
  /**
   * The biggest half power of two that fits into an unsigned long: floor(sqrt(2) * 2^63), i.e.
   * floor(2^63.5). {@link #log2} shifts it right by the leading-zero count of its argument to
   * obtain floor(2^(k + 0.5)) for the relevant k.
   */
  @VisibleForTesting static final long MAX_POWER_OF_SQRT2_UNSIGNED = 0xB504F333F9DE6484L;
151
152  /**
153   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
154   *
155   * @throws IllegalArgumentException if {@code x <= 0}
156   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
157   *     is not a power of ten
158   */
159  @GwtIncompatible // TODO
160  @SuppressWarnings("fallthrough")
161  // TODO(kevinb): remove after this warning is disabled globally
162  public static int log10(long x, RoundingMode mode) {
163    checkPositive("x", x);
164    int logFloor = log10Floor(x);
165    long floorPow = powersOf10[logFloor];
166    switch (mode) {
167      case UNNECESSARY:
168        checkRoundingUnnecessary(x == floorPow);
169        // fall through
170      case FLOOR:
171      case DOWN:
172        return logFloor;
173      case CEILING:
174      case UP:
175        return logFloor + lessThanBranchFree(floorPow, x);
176      case HALF_DOWN:
177      case HALF_UP:
178      case HALF_EVEN:
179        // sqrt(10) is irrational, so log10(x)-logFloor is never exactly 0.5
180        return logFloor + lessThanBranchFree(halfPowersOf10[logFloor], x);
181      default:
182        throw new AssertionError();
183    }
184  }
185
186  @GwtIncompatible // TODO
187  static int log10Floor(long x) {
188    /*
189     * Based on Hacker's Delight Fig. 11-5, the two-table-lookup, branch-free implementation.
190     *
191     * The key idea is that based on the number of leading zeros (equivalently, floor(log2(x))), we
192     * can narrow the possible floor(log10(x)) values to two. For example, if floor(log2(x)) is 6,
193     * then 64 <= x < 128, so floor(log10(x)) is either 1 or 2.
194     */
195    int y = maxLog10ForLeadingZeros[Long.numberOfLeadingZeros(x)];
196    /*
197     * y is the higher of the two possible values of floor(log10(x)). If x < 10^y, then we want the
198     * lower of the two possible values, or y - 1, otherwise, we want y.
199     */
200    return y - lessThanBranchFree(x, powersOf10[y]);
201  }
202
  // maxLog10ForLeadingZeros[i] == floor(log10(2^(Long.SIZE - i)))
  // Indexed by Long.numberOfLeadingZeros(x) in log10Floor; each entry is the larger of the two
  // possible floor(log10(x)) values for an x with that many leading zeros (64 entries, i in 0..63).
  @VisibleForTesting
  static final byte[] maxLog10ForLeadingZeros = {
    19, 18, 18, 18, 18, 17, 17, 17, 16, 16, 16, 15, 15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12,
    12, 11, 11, 11, 10, 10, 10, 9, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 4, 4, 4, 3, 3, 3,
    3, 2, 2, 2, 1, 1, 1, 0, 0, 0
  };
210
  // powersOf10[i] == 10^i for i in [0, 18]; 10^18 is the largest power of ten that fits in a long.
  @GwtIncompatible // TODO
  @VisibleForTesting
  static final long[] powersOf10 = {
    1L,
    10L,
    100L,
    1000L,
    10000L,
    100000L,
    1000000L,
    10000000L,
    100000000L,
    1000000000L,
    10000000000L,
    100000000000L,
    1000000000000L,
    10000000000000L,
    100000000000000L,
    1000000000000000L,
    10000000000000000L,
    100000000000000000L,
    1000000000000000000L
  };
234
  // halfPowersOf10[i] = largest long less than 10^(i + 0.5), for i in [0, 18].
  // log10 compares against these to decide HALF_* rounding; 10^(i + 0.5) is irrational, so no long
  // lies exactly on the halfway point.
  @GwtIncompatible // TODO
  @VisibleForTesting
  static final long[] halfPowersOf10 = {
    3L,
    31L,
    316L,
    3162L,
    31622L,
    316227L,
    3162277L,
    31622776L,
    316227766L,
    3162277660L,
    31622776601L,
    316227766016L,
    3162277660168L,
    31622776601683L,
    316227766016837L,
    3162277660168379L,
    31622776601683793L,
    316227766016837933L,
    3162277660168379331L
  };
259
260  /**
261   * Returns {@code b} to the {@code k}th power. Even if the result overflows, it will be equal to
262   * {@code BigInteger.valueOf(b).pow(k).longValue()}. This implementation runs in {@code O(log k)}
263   * time.
264   *
265   * @throws IllegalArgumentException if {@code k < 0}
266   */
267  @GwtIncompatible // TODO
268  public static long pow(long b, int k) {
269    checkNonNegative("exponent", k);
270    if (-2 <= b && b <= 2) {
271      switch ((int) b) {
272        case 0:
273          return (k == 0) ? 1 : 0;
274        case 1:
275          return 1;
276        case (-1):
277          return ((k & 1) == 0) ? 1 : -1;
278        case 2:
279          return (k < Long.SIZE) ? 1L << k : 0;
280        case (-2):
281          if (k < Long.SIZE) {
282            return ((k & 1) == 0) ? 1L << k : -(1L << k);
283          } else {
284            return 0;
285          }
286        default:
287          throw new AssertionError();
288      }
289    }
290    for (long accum = 1; ; k >>= 1) {
291      switch (k) {
292        case 0:
293          return accum;
294        case 1:
295          return accum * b;
296        default:
297          accum *= ((k & 1) == 0) ? 1 : b;
298          b *= b;
299      }
300    }
301  }
302
303  /**
304   * Returns the square root of {@code x}, rounded with the specified rounding mode.
305   *
306   * @throws IllegalArgumentException if {@code x < 0}
307   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code
308   *     sqrt(x)} is not an integer
309   */
310  @GwtIncompatible // TODO
311  @SuppressWarnings("fallthrough")
312  public static long sqrt(long x, RoundingMode mode) {
313    checkNonNegative("x", x);
314    if (fitsInInt(x)) {
315      return IntMath.sqrt((int) x, mode);
316    }
317    /*
318     * Let k be the true value of floor(sqrt(x)), so that
319     *
320     *            k * k <= x          <  (k + 1) * (k + 1)
321     * (double) (k * k) <= (double) x <= (double) ((k + 1) * (k + 1))
322     *          since casting to double is nondecreasing.
323     *          Note that the right-hand inequality is no longer strict.
324     * Math.sqrt(k * k) <= Math.sqrt(x) <= Math.sqrt((k + 1) * (k + 1))
325     *          since Math.sqrt is monotonic.
326     * (long) Math.sqrt(k * k) <= (long) Math.sqrt(x) <= (long) Math.sqrt((k + 1) * (k + 1))
327     *          since casting to long is monotonic
328     * k <= (long) Math.sqrt(x) <= k + 1
329     *          since (long) Math.sqrt(k * k) == k, as checked exhaustively in
330     *          {@link LongMathTest#testSqrtOfPerfectSquareAsDoubleIsPerfect}
331     */
332    long guess = (long) Math.sqrt(x);
333    // Note: guess is always <= FLOOR_SQRT_MAX_LONG.
334    long guessSquared = guess * guess;
335    // Note (2013-2-26): benchmarks indicate that, inscrutably enough, using if statements is
336    // faster here than using lessThanBranchFree.
337    switch (mode) {
338      case UNNECESSARY:
339        checkRoundingUnnecessary(guessSquared == x);
340        return guess;
341      case FLOOR:
342      case DOWN:
343        if (x < guessSquared) {
344          return guess - 1;
345        }
346        return guess;
347      case CEILING:
348      case UP:
349        if (x > guessSquared) {
350          return guess + 1;
351        }
352        return guess;
353      case HALF_DOWN:
354      case HALF_UP:
355      case HALF_EVEN:
356        long sqrtFloor = guess - ((x < guessSquared) ? 1 : 0);
357        long halfSquare = sqrtFloor * sqrtFloor + sqrtFloor;
358        /*
359         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both x
360         * and halfSquare are integers, this is equivalent to testing whether or not x <=
361         * halfSquare. (We have to deal with overflow, though.)
362         *
363         * If we treat halfSquare as an unsigned long, we know that
364         *            sqrtFloor^2 <= x < (sqrtFloor + 1)^2
365         * halfSquare - sqrtFloor <= x < halfSquare + sqrtFloor + 1
366         * so |x - halfSquare| <= sqrtFloor.  Therefore, it's safe to treat x - halfSquare as a
367         * signed long, so lessThanBranchFree is safe for use.
368         */
369        return sqrtFloor + lessThanBranchFree(halfSquare, x);
370      default:
371        throw new AssertionError();
372    }
373  }
374
375  /**
376   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified {@code
377   * RoundingMode}.
378   *
379   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
380   *     is not an integer multiple of {@code b}
381   */
382  @GwtIncompatible // TODO
383  @SuppressWarnings("fallthrough")
384  public static long divide(long p, long q, RoundingMode mode) {
385    checkNotNull(mode);
386    long div = p / q; // throws if q == 0
387    long rem = p - q * div; // equals p % q
388
389    if (rem == 0) {
390      return div;
391    }
392
393    /*
394     * Normal Java division rounds towards 0, consistently with RoundingMode.DOWN. We just have to
395     * deal with the cases where rounding towards 0 is wrong, which typically depends on the sign of
396     * p / q.
397     *
398     * signum is 1 if p and q are both nonnegative or both negative, and -1 otherwise.
399     */
400    int signum = 1 | (int) ((p ^ q) >> (Long.SIZE - 1));
401    boolean increment;
402    switch (mode) {
403      case UNNECESSARY:
404        checkRoundingUnnecessary(rem == 0);
405        // fall through
406      case DOWN:
407        increment = false;
408        break;
409      case UP:
410        increment = true;
411        break;
412      case CEILING:
413        increment = signum > 0;
414        break;
415      case FLOOR:
416        increment = signum < 0;
417        break;
418      case HALF_EVEN:
419      case HALF_DOWN:
420      case HALF_UP:
421        long absRem = abs(rem);
422        long cmpRemToHalfDivisor = absRem - (abs(q) - absRem);
423        // subtracting two nonnegative longs can't overflow
424        // cmpRemToHalfDivisor has the same sign as compare(abs(rem), abs(q) / 2).
425        if (cmpRemToHalfDivisor == 0) { // exactly on the half mark
426          increment = (mode == HALF_UP | (mode == HALF_EVEN & (div & 1) != 0));
427        } else {
428          increment = cmpRemToHalfDivisor > 0; // closer to the UP value
429        }
430        break;
431      default:
432        throw new AssertionError();
433    }
434    return increment ? div + signum : div;
435  }
436
437  /**
438   * Returns {@code x mod m}, a non-negative value less than {@code m}. This differs from {@code x %
439   * m}, which might be negative.
440   *
441   * <p>For example:
442   *
443   * <pre>{@code
444   * mod(7, 4) == 3
445   * mod(-7, 4) == 1
446   * mod(-1, 4) == 3
447   * mod(-8, 4) == 0
448   * mod(8, 4) == 0
449   * }</pre>
450   *
451   * @throws ArithmeticException if {@code m <= 0}
452   * @see <a href="http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.17.3">
453   *     Remainder Operator</a>
454   */
455  @GwtIncompatible // TODO
456  public static int mod(long x, int m) {
457    // Cast is safe because the result is guaranteed in the range [0, m)
458    return (int) mod(x, (long) m);
459  }
460
461  /**
462   * Returns {@code x mod m}, a non-negative value less than {@code m}. This differs from {@code x %
463   * m}, which might be negative.
464   *
465   * <p>For example:
466   *
467   * <pre>{@code
468   * mod(7, 4) == 3
469   * mod(-7, 4) == 1
470   * mod(-1, 4) == 3
471   * mod(-8, 4) == 0
472   * mod(8, 4) == 0
473   * }</pre>
474   *
475   * @throws ArithmeticException if {@code m <= 0}
476   * @see <a href="http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.17.3">
477   *     Remainder Operator</a>
478   */
479  @GwtIncompatible // TODO
480  public static long mod(long x, long m) {
481    if (m <= 0) {
482      throw new ArithmeticException("Modulus must be positive");
483    }
484    long result = x % m;
485    return (result >= 0) ? result : result + m;
486  }
487
488  /**
489   * Returns the greatest common divisor of {@code a, b}. Returns {@code 0} if {@code a == 0 && b ==
490   * 0}.
491   *
492   * @throws IllegalArgumentException if {@code a < 0} or {@code b < 0}
493   */
494  public static long gcd(long a, long b) {
495    /*
496     * The reason we require both arguments to be >= 0 is because otherwise, what do you return on
497     * gcd(0, Long.MIN_VALUE)? BigInteger.gcd would return positive 2^63, but positive 2^63 isn't an
498     * int.
499     */
500    checkNonNegative("a", a);
501    checkNonNegative("b", b);
502    if (a == 0) {
503      // 0 % b == 0, so b divides a, but the converse doesn't hold.
504      // BigInteger.gcd is consistent with this decision.
505      return b;
506    } else if (b == 0) {
507      return a; // similar logic
508    }
509    /*
510     * Uses the binary GCD algorithm; see http://en.wikipedia.org/wiki/Binary_GCD_algorithm. This is
511     * >60% faster than the Euclidean algorithm in benchmarks.
512     */
513    int aTwos = Long.numberOfTrailingZeros(a);
514    a >>= aTwos; // divide out all 2s
515    int bTwos = Long.numberOfTrailingZeros(b);
516    b >>= bTwos; // divide out all 2s
517    while (a != b) { // both a, b are odd
518      // The key to the binary GCD algorithm is as follows:
519      // Both a and b are odd. Assume a > b; then gcd(a - b, b) = gcd(a, b).
520      // But in gcd(a - b, b), a - b is even and b is odd, so we can divide out powers of two.
521
522      // We bend over backwards to avoid branching, adapting a technique from
523      // http://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
524
525      long delta = a - b; // can't overflow, since a and b are nonnegative
526
527      long minDeltaOrZero = delta & (delta >> (Long.SIZE - 1));
528      // equivalent to Math.min(delta, 0)
529
530      a = delta - minDeltaOrZero - minDeltaOrZero; // sets a to Math.abs(a - b)
531      // a is now nonnegative and even
532
533      b += minDeltaOrZero; // sets b to min(old a, b)
534      a >>= Long.numberOfTrailingZeros(a); // divide out all 2s, since 2 doesn't divide b
535    }
536    return a << min(aTwos, bTwos);
537  }
538
539  /**
540   * Returns the sum of {@code a} and {@code b}, provided it does not overflow.
541   *
542   * @throws ArithmeticException if {@code a + b} overflows in signed {@code long} arithmetic
543   */
544  @GwtIncompatible // TODO
545  public static long checkedAdd(long a, long b) {
546    long result = a + b;
547    checkNoOverflow((a ^ b) < 0 | (a ^ result) >= 0);
548    return result;
549  }
550
551  /**
552   * Returns the difference of {@code a} and {@code b}, provided it does not overflow.
553   *
554   * @throws ArithmeticException if {@code a - b} overflows in signed {@code long} arithmetic
555   */
556  @GwtIncompatible // TODO
557  public static long checkedSubtract(long a, long b) {
558    long result = a - b;
559    checkNoOverflow((a ^ b) >= 0 | (a ^ result) >= 0);
560    return result;
561  }
562
563  /**
564   * Returns the product of {@code a} and {@code b}, provided it does not overflow.
565   *
566   * @throws ArithmeticException if {@code a * b} overflows in signed {@code long} arithmetic
567   */
568  @GwtIncompatible // TODO
569  public static long checkedMultiply(long a, long b) {
570    // Hacker's Delight, Section 2-12
571    int leadingZeros =
572        Long.numberOfLeadingZeros(a)
573            + Long.numberOfLeadingZeros(~a)
574            + Long.numberOfLeadingZeros(b)
575            + Long.numberOfLeadingZeros(~b);
576    /*
577     * If leadingZeros > Long.SIZE + 1 it's definitely fine, if it's < Long.SIZE it's definitely
578     * bad. We do the leadingZeros check to avoid the division below if at all possible.
579     *
580     * Otherwise, if b == Long.MIN_VALUE, then the only allowed values of a are 0 and 1. We take
581     * care of all a < 0 with their own check, because in particular, the case a == -1 will
582     * incorrectly pass the division check below.
583     *
584     * In all other cases, we check that either a is 0 or the result is consistent with division.
585     */
586    if (leadingZeros > Long.SIZE + 1) {
587      return a * b;
588    }
589    checkNoOverflow(leadingZeros >= Long.SIZE);
590    checkNoOverflow(a >= 0 | b != Long.MIN_VALUE);
591    long result = a * b;
592    checkNoOverflow(a == 0 || result / a == b);
593    return result;
594  }
595
596  /**
597   * Returns the {@code b} to the {@code k}th power, provided it does not overflow.
598   *
599   * @throws ArithmeticException if {@code b} to the {@code k}th power overflows in signed {@code
600   *     long} arithmetic
601   */
602  @GwtIncompatible // TODO
603  public static long checkedPow(long b, int k) {
604    checkNonNegative("exponent", k);
605    if (b >= -2 & b <= 2) {
606      switch ((int) b) {
607        case 0:
608          return (k == 0) ? 1 : 0;
609        case 1:
610          return 1;
611        case (-1):
612          return ((k & 1) == 0) ? 1 : -1;
613        case 2:
614          checkNoOverflow(k < Long.SIZE - 1);
615          return 1L << k;
616        case (-2):
617          checkNoOverflow(k < Long.SIZE);
618          return ((k & 1) == 0) ? (1L << k) : (-1L << k);
619        default:
620          throw new AssertionError();
621      }
622    }
623    long accum = 1;
624    while (true) {
625      switch (k) {
626        case 0:
627          return accum;
628        case 1:
629          return checkedMultiply(accum, b);
630        default:
631          if ((k & 1) != 0) {
632            accum = checkedMultiply(accum, b);
633          }
634          k >>= 1;
635          if (k > 0) {
636            checkNoOverflow(-FLOOR_SQRT_MAX_LONG <= b && b <= FLOOR_SQRT_MAX_LONG);
637            b *= b;
638          }
639      }
640    }
641  }
642
643  /**
644   * Returns the sum of {@code a} and {@code b} unless it would overflow or underflow in which case
645   * {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
646   *
647   * @since 20.0
648   */
649  @Beta
650  public static long saturatedAdd(long a, long b) {
651    long naiveSum = a + b;
652    if ((a ^ b) < 0 | (a ^ naiveSum) >= 0) {
653      // If a and b have different signs or a has the same sign as the result then there was no
654      // overflow, return.
655      return naiveSum;
656    }
657    // we did over/under flow, if the sign is negative we should return MAX otherwise MIN
658    return Long.MAX_VALUE + ((naiveSum >>> (Long.SIZE - 1)) ^ 1);
659  }
660
661  /**
662   * Returns the difference of {@code a} and {@code b} unless it would overflow or underflow in
663   * which case {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
664   *
665   * @since 20.0
666   */
667  @Beta
668  public static long saturatedSubtract(long a, long b) {
669    long naiveDifference = a - b;
670    if ((a ^ b) >= 0 | (a ^ naiveDifference) >= 0) {
671      // If a and b have the same signs or a has the same sign as the result then there was no
672      // overflow, return.
673      return naiveDifference;
674    }
675    // we did over/under flow
676    return Long.MAX_VALUE + ((naiveDifference >>> (Long.SIZE - 1)) ^ 1);
677  }
678
679  /**
680   * Returns the product of {@code a} and {@code b} unless it would overflow or underflow in which
681   * case {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
682   *
683   * @since 20.0
684   */
685  @Beta
686  public static long saturatedMultiply(long a, long b) {
687    // see checkedMultiply for explanation
688    int leadingZeros =
689        Long.numberOfLeadingZeros(a)
690            + Long.numberOfLeadingZeros(~a)
691            + Long.numberOfLeadingZeros(b)
692            + Long.numberOfLeadingZeros(~b);
693    if (leadingZeros > Long.SIZE + 1) {
694      return a * b;
695    }
696    // the return value if we will overflow (which we calculate by overflowing a long :) )
697    long limit = Long.MAX_VALUE + ((a ^ b) >>> (Long.SIZE - 1));
698    if (leadingZeros < Long.SIZE | (a < 0 & b == Long.MIN_VALUE)) {
699      // overflow
700      return limit;
701    }
702    long result = a * b;
703    if (a == 0 || result / a == b) {
704      return result;
705    }
706    return limit;
707  }
708
709  /**
710   * Returns the {@code b} to the {@code k}th power, unless it would overflow or underflow in which
711   * case {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
712   *
713   * @since 20.0
714   */
715  @Beta
716  public static long saturatedPow(long b, int k) {
717    checkNonNegative("exponent", k);
718    if (b >= -2 & b <= 2) {
719      switch ((int) b) {
720        case 0:
721          return (k == 0) ? 1 : 0;
722        case 1:
723          return 1;
724        case (-1):
725          return ((k & 1) == 0) ? 1 : -1;
726        case 2:
727          if (k >= Long.SIZE - 1) {
728            return Long.MAX_VALUE;
729          }
730          return 1L << k;
731        case (-2):
732          if (k >= Long.SIZE) {
733            return Long.MAX_VALUE + (k & 1);
734          }
735          return ((k & 1) == 0) ? (1L << k) : (-1L << k);
736        default:
737          throw new AssertionError();
738      }
739    }
740    long accum = 1;
741    // if b is negative and k is odd then the limit is MIN otherwise the limit is MAX
742    long limit = Long.MAX_VALUE + ((b >>> Long.SIZE - 1) & (k & 1));
743    while (true) {
744      switch (k) {
745        case 0:
746          return accum;
747        case 1:
748          return saturatedMultiply(accum, b);
749        default:
750          if ((k & 1) != 0) {
751            accum = saturatedMultiply(accum, b);
752          }
753          k >>= 1;
754          if (k > 0) {
755            if (-FLOOR_SQRT_MAX_LONG > b | b > FLOOR_SQRT_MAX_LONG) {
756              return limit;
757            }
758            b *= b;
759          }
760      }
761    }
762  }
763
  /**
   * floor(sqrt(Long.MAX_VALUE)). Squaring any {@code b} with {@code |b| <= FLOOR_SQRT_MAX_LONG}
   * cannot overflow a {@code long}.
   */
  @VisibleForTesting static final long FLOOR_SQRT_MAX_LONG = 3037000499L;
765
766  /**
767   * Returns {@code n!}, that is, the product of the first {@code n} positive integers, {@code 1} if
768   * {@code n == 0}, or {@link Long#MAX_VALUE} if the result does not fit in a {@code long}.
769   *
770   * @throws IllegalArgumentException if {@code n < 0}
771   */
772  @GwtIncompatible // TODO
773  public static long factorial(int n) {
774    checkNonNegative("n", n);
775    return (n < factorials.length) ? factorials[n] : Long.MAX_VALUE;
776  }
777
  // factorials[i] == i! for i in [0, 20]; 21! exceeds Long.MAX_VALUE.
  static final long[] factorials = {
    1L,
    1L,
    1L * 2,
    1L * 2 * 3,
    1L * 2 * 3 * 4,
    1L * 2 * 3 * 4 * 5,
    1L * 2 * 3 * 4 * 5 * 6,
    1L * 2 * 3 * 4 * 5 * 6 * 7,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19 * 20
  };
801
802  /**
803   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
804   * {@code k}, or {@link Long#MAX_VALUE} if the result does not fit in a {@code long}.
805   *
806   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
807   */
808  public static long binomial(int n, int k) {
809    checkNonNegative("n", n);
810    checkNonNegative("k", k);
811    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
812    if (k > (n >> 1)) {
813      k = n - k;
814    }
815    switch (k) {
816      case 0:
817        return 1;
818      case 1:
819        return n;
820      default:
821        if (n < factorials.length) {
822          return factorials[n] / (factorials[k] * factorials[n - k]);
823        } else if (k >= biggestBinomials.length || n > biggestBinomials[k]) {
824          return Long.MAX_VALUE;
825        } else if (k < biggestSimpleBinomials.length && n <= biggestSimpleBinomials[k]) {
826          // guaranteed not to overflow
827          long result = n--;
828          for (int i = 2; i <= k; n--, i++) {
829            result *= n;
830            result /= i;
831          }
832          return result;
833        } else {
834          int nBits = LongMath.log2(n, RoundingMode.CEILING);
835
836          long result = 1;
837          long numerator = n--;
838          long denominator = 1;
839
840          int numeratorBits = nBits;
841          // This is an upper bound on log2(numerator, ceiling).
842
843          /*
844           * We want to do this in long math for speed, but want to avoid overflow. We adapt the
845           * technique previously used by BigIntegerMath: maintain separate numerator and
846           * denominator accumulators, multiplying the fraction into result when near overflow.
847           */
848          for (int i = 2; i <= k; i++, n--) {
849            if (numeratorBits + nBits < Long.SIZE - 1) {
850              // It's definitely safe to multiply into numerator and denominator.
851              numerator *= n;
852              denominator *= i;
853              numeratorBits += nBits;
854            } else {
855              // It might not be safe to multiply into numerator and denominator,
856              // so multiply (numerator / denominator) into result.
857              result = multiplyFraction(result, numerator, denominator);
858              numerator = n;
859              denominator = i;
860              numeratorBits = nBits;
861            }
862          }
863          return multiplyFraction(result, numerator, denominator);
864        }
865    }
866  }
867
868  /** Returns (x * numerator / denominator), which is assumed to come out to an integral value. */
869  static long multiplyFraction(long x, long numerator, long denominator) {
870    if (x == 1) {
871      return numerator / denominator;
872    }
873    long commonDivisor = gcd(x, denominator);
874    x /= commonDivisor;
875    denominator /= commonDivisor;
876    // We know gcd(x, denominator) = 1, and x * numerator / denominator is exact,
877    // so denominator must be a divisor of numerator.
878    return x * (numerator / denominator);
879  }
880
881  /*
882   * binomial(biggestBinomials[k], k) fits in a long, but not binomial(biggestBinomials[k] + 1, k).
883   */
884  static final int[] biggestBinomials = {
885    Integer.MAX_VALUE,
886    Integer.MAX_VALUE,
887    Integer.MAX_VALUE,
888    3810779,
889    121977,
890    16175,
891    4337,
892    1733,
893    887,
894    534,
895    361,
896    265,
897    206,
898    169,
899    143,
900    125,
901    111,
902    101,
903    94,
904    88,
905    83,
906    79,
907    76,
908    74,
909    72,
910    70,
911    69,
912    68,
913    67,
914    67,
915    66,
916    66,
917    66,
918    66
919  };
920
921  /*
922   * binomial(biggestSimpleBinomials[k], k) doesn't need to use the slower GCD-based impl, but
923   * binomial(biggestSimpleBinomials[k] + 1, k) does.
924   */
925  @VisibleForTesting
926  static final int[] biggestSimpleBinomials = {
927    Integer.MAX_VALUE,
928    Integer.MAX_VALUE,
929    Integer.MAX_VALUE,
930    2642246,
931    86251,
932    11724,
933    3218,
934    1313,
935    684,
936    419,
937    287,
938    214,
939    169,
940    139,
941    119,
942    105,
943    95,
944    87,
945    81,
946    76,
947    73,
948    70,
949    68,
950    66,
951    64,
952    63,
953    62,
954    62,
955    61,
956    61,
957    61
958  };
959  // These values were generated by using checkedMultiply to see when the simple multiply/divide
960  // algorithm would lead to an overflow.
961
962  static boolean fitsInInt(long x) {
963    return (int) x == x;
964  }
965
966  /**
967   * Returns the arithmetic mean of {@code x} and {@code y}, rounded toward negative infinity. This
968   * method is resilient to overflow.
969   *
970   * @since 14.0
971   */
972  public static long mean(long x, long y) {
973    // Efficient method for computing the arithmetic mean.
974    // The alternative (x + y) / 2 fails for large values.
975    // The alternative (x + y) >>> 1 fails for negative values.
976    return (x & y) + ((x ^ y) >> 1);
977  }
978
979  /*
980   * This bitmask is used as an optimization for cheaply testing for divisiblity by 2, 3, or 5.
981   * Each bit is set to 1 for all remainders that indicate divisibility by 2, 3, or 5, so
982   * 1, 7, 11, 13, 17, 19, 23, 29 are set to 0. 30 and up don't matter because they won't be hit.
983   */
984  private static final int SIEVE_30 =
985      ~((1 << 1) | (1 << 7) | (1 << 11) | (1 << 13) | (1 << 17) | (1 << 19) | (1 << 23)
986          | (1 << 29));
987
988  /**
989   * Returns {@code true} if {@code n} is a <a
990   * href="http://mathworld.wolfram.com/PrimeNumber.html">prime number</a>: an integer <i>greater
991   * than one</i> that cannot be factored into a product of <i>smaller</i> positive integers.
992   * Returns {@code false} if {@code n} is zero, one, or a composite number (one which <i>can</i> be
993   * factored into smaller positive integers).
994   *
995   * <p>To test larger numbers, use {@link BigInteger#isProbablePrime}.
996   *
997   * @throws IllegalArgumentException if {@code n} is negative
998   * @since 20.0
999   */
1000  @GwtIncompatible // TODO
1001  @Beta
1002  public static boolean isPrime(long n) {
1003    if (n < 2) {
1004      checkNonNegative("n", n);
1005      return false;
1006    }
1007    if (n == 2 || n == 3 || n == 5 || n == 7 || n == 11 || n == 13) {
1008      return true;
1009    }
1010
1011    if ((SIEVE_30 & (1 << (n % 30))) != 0) {
1012      return false;
1013    }
1014    if (n % 7 == 0 || n % 11 == 0 || n % 13 == 0) {
1015      return false;
1016    }
1017    if (n < 17 * 17) {
1018      return true;
1019    }
1020
1021    for (long[] baseSet : millerRabinBaseSets) {
1022      if (n <= baseSet[0]) {
1023        for (int i = 1; i < baseSet.length; i++) {
1024          if (!MillerRabinTester.test(baseSet[i], n)) {
1025            return false;
1026          }
1027        }
1028        return true;
1029      }
1030    }
1031    throw new AssertionError();
1032  }
1033
1034  /*
1035   * If n <= millerRabinBases[i][0], then testing n against bases millerRabinBases[i][1..] suffices
1036   * to prove its primality. Values from miller-rabin.appspot.com.
1037   *
1038   * NOTE: We could get slightly better bases that would be treated as unsigned, but benchmarks
1039   * showed negligible performance improvements.
1040   */
1041  private static final long[][] millerRabinBaseSets = {
1042    {291830, 126401071349994536L},
1043    {885594168, 725270293939359937L, 3569819667048198375L},
1044    {273919523040L, 15, 7363882082L, 992620450144556L},
1045    {47636622961200L, 2, 2570940, 211991001, 3749873356L},
1046    {
1047      7999252175582850L,
1048      2,
1049      4130806001517L,
1050      149795463772692060L,
1051      186635894390467037L,
1052      3967304179347715805L
1053    },
1054    {
1055      585226005592931976L,
1056      2,
1057      123635709730000L,
1058      9233062284813009L,
1059      43835965440333360L,
1060      761179012939631437L,
1061      1263739024124850375L
1062    },
1063    {Long.MAX_VALUE, 2, 325, 9375, 28178, 450775, 9780504, 1795265022}
1064  };
1065
  /**
   * Modular-arithmetic strategies for running a single Miller-Rabin strong-probable-prime test.
   *
   * <p>{@link #test} selects {@link #SMALL} when {@code n <= FLOOR_SQRT_MAX_LONG} and {@link
   * #LARGE} otherwise. The two constants differ only in how they implement {@code mulMod} and
   * {@code squareMod}; exponentiation and the witness check are shared.
   */
  private enum MillerRabinTester {
    /** Works for inputs ≤ FLOOR_SQRT_MAX_LONG. */
    SMALL {
      @Override
      long mulMod(long a, long b, long m) {
        /*
         * NOTE(lowasser, 2015-Feb-12): Benchmarks suggest that changing this to
         * UnsignedLongs.remainder and increasing the threshold to 2^32 doesn't pay for itself, and
         * adding another enum constant hurts performance further -- I suspect because bimorphic
         * implementation is a sweet spot for the JVM.
         */
        // Direct signed product. Assumes a and b are small enough that a * b cannot overflow;
        // presumably FLOOR_SQRT_MAX_LONG is floor(sqrt(Long.MAX_VALUE)) -- confirm at its
        // definition.
        return (a * b) % m;
      }

      @Override
      long squareMod(long a, long m) {
        // Same no-overflow assumption as mulMod above.
        return (a * a) % m;
      }
    },
    /** Works for all nonnegative signed longs. */
    LARGE {
      /** Returns (a + b) mod m. Precondition: {@code 0 <= a}, {@code b < m < 2^63}. */
      private long plusMod(long a, long b, long m) {
        // a + b may exceed Long.MAX_VALUE, so compare against m - b instead; when the sum does
        // wrap, a + b - m still yields the exact result because it lies in [0, m).
        return (a >= m - b) ? (a + b - m) : (a + b);
      }

      /** Returns (a * 2^32) mod m. a may be any unsigned long. */
      private long times2ToThe32Mod(long a, long m) {
        // Multiply by 2^32 in chunks: shift as far as possible without dropping high bits, reduce
        // (unsigned) mod m to make room again, and repeat until all 32 doublings are applied.
        int remainingPowersOf2 = 32;
        do {
          int shift = Math.min(remainingPowersOf2, Long.numberOfLeadingZeros(a));
          // shift is either the number of powers of 2 left to multiply a by, or the biggest shift
          // possible while keeping a in an unsigned long.
          a = UnsignedLongs.remainder(a << shift, m);
          remainingPowersOf2 -= shift;
        } while (remainingPowersOf2 > 0);
        return a;
      }

      @Override
      long mulMod(long a, long b, long m) {
        // Splits a and b into 32-bit halves so no partial product needs more than 64 bits.
        long aHi = a >>> 32; // < 2^31
        long bHi = b >>> 32; // < 2^31
        long aLo = a & 0xFFFFFFFFL; // < 2^32
        long bLo = b & 0xFFFFFFFFL; // < 2^32

        /*
         * a * b == aHi * bHi * 2^64 + (aHi * bLo + aLo * bHi) * 2^32 + aLo * bLo.
         *       == (aHi * bHi * 2^32 + aHi * bLo + aLo * bHi) * 2^32 + aLo * bLo
         *
         * We carry out this computation in modular arithmetic. Since times2ToThe32Mod accepts any
         * unsigned long, we don't have to do a mod on every operation, only when intermediate
         * results can exceed 2^63.
         */
        long result = times2ToThe32Mod(aHi * bHi /* < 2^62 */, m); // < m < 2^63
        result += aHi * bLo; // aHi * bLo < 2^63, result < 2^64
        if (result < 0) {
          // The sum crossed 2^63 (negative as a signed long); reduce it as an unsigned value.
          result = UnsignedLongs.remainder(result, m);
        }
        // result < 2^63 again
        result += aLo * bHi; // aLo * bHi < 2^63, result < 2^64
        result = times2ToThe32Mod(result, m); // result < m < 2^63
        return plusMod(result, UnsignedLongs.remainder(aLo * bLo /* < 2^64 */, m), m);
      }

      @Override
      long squareMod(long a, long m) {
        // Specialization of mulMod for b == a: the two cross terms merge into one, times 2.
        long aHi = a >>> 32; // < 2^31
        long aLo = a & 0xFFFFFFFFL; // < 2^32

        /*
         * a^2 == aHi^2 * 2^64 + aHi * aLo * 2^33 + aLo^2
         *     == (aHi^2 * 2^32 + aHi * aLo * 2) * 2^32 + aLo^2
         * We carry out this computation in modular arithmetic.  Since times2ToThe32Mod accepts any
         * unsigned long, we don't have to do a mod on every operation, only when intermediate
         * results can exceed 2^63.
         */
        long result = times2ToThe32Mod(aHi * aHi /* < 2^62 */, m); // < m < 2^63
        long hiLo = aHi * aLo * 2;
        if (hiLo < 0) {
          // hiLo >= 2^63 as an unsigned value; reduce it so the addition below cannot wrap.
          hiLo = UnsignedLongs.remainder(hiLo, m);
        }
        // hiLo < 2^63
        result += hiLo; // result < 2^64
        result = times2ToThe32Mod(result, m); // result < m < 2^63
        return plusMod(result, UnsignedLongs.remainder(aLo * aLo /* < 2^64 */, m), m);
      }
    };

    /**
     * Returns whether {@code n} is a strong probable prime to {@code base}, using {@link #SMALL}
     * arithmetic when {@code n <= FLOOR_SQRT_MAX_LONG} and {@link #LARGE} arithmetic otherwise.
     */
    static boolean test(long base, long n) {
      // Since base will be considered % n, it's okay if base > FLOOR_SQRT_MAX_LONG,
      // so long as n <= FLOOR_SQRT_MAX_LONG.
      return ((n <= FLOOR_SQRT_MAX_LONG) ? SMALL : LARGE).testWitness(base, n);
    }

    /** Returns a * b mod m. */
    abstract long mulMod(long a, long b, long m);

    /** Returns a^2 mod m. */
    abstract long squareMod(long a, long m);

    /**
     * Returns a^p mod m, by right-to-left binary exponentiation over the bits of p. p is expected
     * to be nonnegative.
     */
    private long powMod(long a, long p, long m) {
      long res = 1;
      for (; p != 0; p >>= 1) {
        if ((p & 1) != 0) {
          res = mulMod(res, a, m);
        }
        a = squareMod(a, m);
      }
      return res;
    }

    /** Returns true if n is a strong probable prime relative to the specified base. */
    private boolean testWitness(long base, long n) {
      // Write n - 1 as d * 2^r with d odd (assumes n > 1).
      int r = Long.numberOfTrailingZeros(n - 1);
      long d = (n - 1) >> r;
      base %= n;
      if (base == 0) {
        // A multiple of n cannot witness compositeness, so the test counts as passed.
        return true;
      }
      // Calculate a := base^d mod n.
      long a = powMod(base, d, n);
      // n passes this test if
      //    base^d = 1 (mod n)
      // or base^(2^j * d) = -1 (mod n) for some 0 <= j < r.
      if (a == 1) {
        return true;
      }
      // Square up to r - 1 more times, looking for n - 1 (i.e. -1 mod n).
      int j = 0;
      while (a != n - 1) {
        if (++j == r) {
          return false;
        }
        a = squareMod(a, n);
      }
      return true;
    }
  }
1205
1206  private LongMath() {}
1207}