001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
005 * in compliance with the License. You may obtain a copy of the License at
006 *
007 * http://www.apache.org/licenses/LICENSE-2.0
008 *
009 * Unless required by applicable law or agreed to in writing, software distributed under the License
010 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
011 * or implied. See the License for the specific language governing permissions and limitations under
012 * the License.
013 */
014
015package com.google.common.math;
016
017import static com.google.common.base.Preconditions.checkArgument;
018import static com.google.common.base.Preconditions.checkNotNull;
019import static com.google.common.math.MathPreconditions.checkNoOverflow;
020import static com.google.common.math.MathPreconditions.checkNonNegative;
021import static com.google.common.math.MathPreconditions.checkPositive;
022import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
023import static java.lang.Math.abs;
024import static java.lang.Math.min;
025import static java.math.RoundingMode.HALF_EVEN;
026import static java.math.RoundingMode.HALF_UP;
027
028import com.google.common.annotations.Beta;
029import com.google.common.annotations.GwtCompatible;
030import com.google.common.annotations.GwtIncompatible;
031import com.google.common.annotations.VisibleForTesting;
032import com.google.common.primitives.UnsignedLongs;
033import java.math.BigInteger;
034import java.math.RoundingMode;
035
036/**
037 * A class for arithmetic on values of type {@code long}. Where possible, methods are defined and
038 * named analogously to their {@code BigInteger} counterparts.
039 *
040 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
041 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
042 *
043 * <p>Similar functionality for {@code int} and for {@link BigInteger} can be found in
044 * {@link IntMath} and {@link BigIntegerMath} respectively. For other common operations on
045 * {@code long} values, see {@link com.google.common.primitives.Longs}.
046 *
047 * @author Louis Wasserman
048 * @since 11.0
049 */
050@GwtCompatible(emulated = true)
051public final class LongMath {
052  // NOTE: Whenever both tests are cheap and functional, it's faster to use &, | instead of &&, ||
053
  /** 2^62: the largest power of two that is representable as a positive {@code long}. */
  @VisibleForTesting static final long MAX_SIGNED_POWER_OF_TWO = 1L << (Long.SIZE - 2);
055
056  /**
057   * Returns the smallest power of two greater than or equal to {@code x}.  This is equivalent to
058   * {@code checkedPow(2, log2(x, CEILING))}.
059   *
060   * @throws IllegalArgumentException if {@code x <= 0}
061   * @throws ArithmeticException of the next-higher power of two is not representable as a
062   *         {@code long}, i.e. when {@code x > 2^62}
063   * @since 20.0
064   */
065  @Beta
066  public static long ceilingPowerOfTwo(long x) {
067    checkPositive("x", x);
068    if (x > MAX_SIGNED_POWER_OF_TWO) {
069      throw new ArithmeticException("ceilingPowerOfTwo(" + x + ") is not representable as a long");
070    }
071    return 1L << -Long.numberOfLeadingZeros(x - 1);
072  }
073
074  /**
075   * Returns the largest power of two less than or equal to {@code x}.  This is equivalent to
076   * {@code checkedPow(2, log2(x, FLOOR))}.
077   *
078   * @throws IllegalArgumentException if {@code x <= 0}
079   * @since 20.0
080   */
081  @Beta
082  public static long floorPowerOfTwo(long x) {
083    checkPositive("x", x);
084
085    // Long.highestOneBit was buggy on GWT.  We've fixed it, but I'm not certain when the fix will
086    // be released.
087    return 1L << ((Long.SIZE - 1) - Long.numberOfLeadingZeros(x));
088  }
089
090  /**
091   * Returns {@code true} if {@code x} represents a power of two.
092   *
093   * <p>This differs from {@code Long.bitCount(x) == 1}, because
094   * {@code Long.bitCount(Long.MIN_VALUE) == 1}, but {@link Long#MIN_VALUE} is not a power of two.
095   */
096  public static boolean isPowerOfTwo(long x) {
097    return x > 0 & (x & (x - 1)) == 0;
098  }
099
100  /**
101   * Returns 1 if {@code x < y} as unsigned longs, and 0 otherwise. Assumes that x - y fits into a
102   * signed long. The implementation is branch-free, and benchmarks suggest it is measurably faster
103   * than the straightforward ternary expression.
104   */
105  @VisibleForTesting
106  static int lessThanBranchFree(long x, long y) {
107    // Returns the sign bit of x - y.
108    return (int) (~~(x - y) >>> (Long.SIZE - 1));
109  }
110
111  /**
112   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
113   *
114   * @throws IllegalArgumentException if {@code x <= 0}
115   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
116   *     is not a power of two
117   */
118  @SuppressWarnings("fallthrough")
119  // TODO(kevinb): remove after this warning is disabled globally
120  public static int log2(long x, RoundingMode mode) {
121    checkPositive("x", x);
122    switch (mode) {
123      case UNNECESSARY:
124        checkRoundingUnnecessary(isPowerOfTwo(x));
125        // fall through
126      case DOWN:
127      case FLOOR:
128        return (Long.SIZE - 1) - Long.numberOfLeadingZeros(x);
129
130      case UP:
131      case CEILING:
132        return Long.SIZE - Long.numberOfLeadingZeros(x - 1);
133
134      case HALF_DOWN:
135      case HALF_UP:
136      case HALF_EVEN:
137        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
138        int leadingZeros = Long.numberOfLeadingZeros(x);
139        long cmp = MAX_POWER_OF_SQRT2_UNSIGNED >>> leadingZeros;
140        // floor(2^(logFloor + 0.5))
141        int logFloor = (Long.SIZE - 1) - leadingZeros;
142        return logFloor + lessThanBranchFree(cmp, x);
143
144      default:
145        throw new AssertionError("impossible");
146    }
147  }
148
  /**
   * The biggest half power of two that fits into an unsigned long: floor(2^63.5), whose bits are
   * the leading 64 bits of the binary fraction of sqrt(2)/2. It is negative when read as a signed
   * long. An unsigned right shift by {@code k} yields floor(2^(63.5 - k)).
   */
  @VisibleForTesting static final long MAX_POWER_OF_SQRT2_UNSIGNED = 0xB504F333F9DE6484L;
151
152  /**
153   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
154   *
155   * @throws IllegalArgumentException if {@code x <= 0}
156   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
157   *     is not a power of ten
158   */
159  @GwtIncompatible // TODO
160  @SuppressWarnings("fallthrough")
161  // TODO(kevinb): remove after this warning is disabled globally
162  public static int log10(long x, RoundingMode mode) {
163    checkPositive("x", x);
164    int logFloor = log10Floor(x);
165    long floorPow = powersOf10[logFloor];
166    switch (mode) {
167      case UNNECESSARY:
168        checkRoundingUnnecessary(x == floorPow);
169        // fall through
170      case FLOOR:
171      case DOWN:
172        return logFloor;
173      case CEILING:
174      case UP:
175        return logFloor + lessThanBranchFree(floorPow, x);
176      case HALF_DOWN:
177      case HALF_UP:
178      case HALF_EVEN:
179        // sqrt(10) is irrational, so log10(x)-logFloor is never exactly 0.5
180        return logFloor + lessThanBranchFree(halfPowersOf10[logFloor], x);
181      default:
182        throw new AssertionError();
183    }
184  }
185
186  @GwtIncompatible // TODO
187  static int log10Floor(long x) {
188    /*
189     * Based on Hacker's Delight Fig. 11-5, the two-table-lookup, branch-free implementation.
190     *
191     * The key idea is that based on the number of leading zeros (equivalently, floor(log2(x))), we
192     * can narrow the possible floor(log10(x)) values to two. For example, if floor(log2(x)) is 6,
193     * then 64 <= x < 128, so floor(log10(x)) is either 1 or 2.
194     */
195    int y = maxLog10ForLeadingZeros[Long.numberOfLeadingZeros(x)];
196    /*
197     * y is the higher of the two possible values of floor(log10(x)). If x < 10^y, then we want the
198     * lower of the two possible values, or y - 1, otherwise, we want y.
199     */
200    return y - lessThanBranchFree(x, powersOf10[y]);
201  }
202
  // maxLog10ForLeadingZeros[i] == floor(log10(2^(Long.SIZE - i)))
  // Indexed by Long.numberOfLeadingZeros(x). For such x, 2^(63 - i) <= x < 2^(64 - i), so the entry
  // is an upper bound on floor(log10(x)) that is off by at most one (a factor of 2 < 10).
  @VisibleForTesting
  static final byte[] maxLog10ForLeadingZeros = {
    19, 18, 18, 18, 18, 17, 17, 17, 16, 16, 16, 15, 15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12,
    12, 11, 11, 11, 10, 10, 10, 9, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 4, 4, 4, 3, 3, 3,
    3, 2, 2, 2, 1, 1, 1, 0, 0, 0
  };
210
  // powersOf10[i] == 10^i. 10^18 is the largest power of ten that fits in a long.
  @GwtIncompatible // TODO
  @VisibleForTesting
  static final long[] powersOf10 = {
    1L,
    10L,
    100L,
    1000L,
    10000L,
    100000L,
    1000000L,
    10000000L,
    100000000L,
    1000000000L,
    10000000000L,
    100000000000L,
    1000000000000L,
    10000000000000L,
    100000000000000L,
    1000000000000000L,
    10000000000000000L,
    100000000000000000L,
    1000000000000000000L
  };
234
  // halfPowersOf10[i] = largest long less than 10^(i + 0.5), i.e. floor(10^(i + 0.5)).
  // These are the rounding thresholds for the HALF_* modes of log10.
  @GwtIncompatible // TODO
  @VisibleForTesting
  static final long[] halfPowersOf10 = {
    3L,
    31L,
    316L,
    3162L,
    31622L,
    316227L,
    3162277L,
    31622776L,
    316227766L,
    3162277660L,
    31622776601L,
    316227766016L,
    3162277660168L,
    31622776601683L,
    316227766016837L,
    3162277660168379L,
    31622776601683793L,
    316227766016837933L,
    3162277660168379331L
  };
259
260  /**
261   * Returns {@code b} to the {@code k}th power. Even if the result overflows, it will be equal to
262   * {@code BigInteger.valueOf(b).pow(k).longValue()}. This implementation runs in {@code O(log k)}
263   * time.
264   *
265   * @throws IllegalArgumentException if {@code k < 0}
266   */
267  @GwtIncompatible // TODO
268  public static long pow(long b, int k) {
269    checkNonNegative("exponent", k);
270    if (-2 <= b && b <= 2) {
271      switch ((int) b) {
272        case 0:
273          return (k == 0) ? 1 : 0;
274        case 1:
275          return 1;
276        case (-1):
277          return ((k & 1) == 0) ? 1 : -1;
278        case 2:
279          return (k < Long.SIZE) ? 1L << k : 0;
280        case (-2):
281          if (k < Long.SIZE) {
282            return ((k & 1) == 0) ? 1L << k : -(1L << k);
283          } else {
284            return 0;
285          }
286        default:
287          throw new AssertionError();
288      }
289    }
290    for (long accum = 1; ; k >>= 1) {
291      switch (k) {
292        case 0:
293          return accum;
294        case 1:
295          return accum * b;
296        default:
297          accum *= ((k & 1) == 0) ? 1 : b;
298          b *= b;
299      }
300    }
301  }
302
303  /**
304   * Returns the square root of {@code x}, rounded with the specified rounding mode.
305   *
306   * @throws IllegalArgumentException if {@code x < 0}
307   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
308   *     {@code sqrt(x)} is not an integer
309   */
310  @GwtIncompatible // TODO
311  @SuppressWarnings("fallthrough")
312  public static long sqrt(long x, RoundingMode mode) {
313    checkNonNegative("x", x);
314    if (fitsInInt(x)) {
315      return IntMath.sqrt((int) x, mode);
316    }
317    /*
318     * Let k be the true value of floor(sqrt(x)), so that
319     *
320     *            k * k <= x          <  (k + 1) * (k + 1)
321     * (double) (k * k) <= (double) x <= (double) ((k + 1) * (k + 1))
322     *          since casting to double is nondecreasing.
323     *          Note that the right-hand inequality is no longer strict.
324     * Math.sqrt(k * k) <= Math.sqrt(x) <= Math.sqrt((k + 1) * (k + 1))
325     *          since Math.sqrt is monotonic.
326     * (long) Math.sqrt(k * k) <= (long) Math.sqrt(x) <= (long) Math.sqrt((k + 1) * (k + 1))
327     *          since casting to long is monotonic
328     * k <= (long) Math.sqrt(x) <= k + 1
329     *          since (long) Math.sqrt(k * k) == k, as checked exhaustively in
330     *          {@link LongMathTest#testSqrtOfPerfectSquareAsDoubleIsPerfect}
331     */
332    long guess = (long) Math.sqrt(x);
333    // Note: guess is always <= FLOOR_SQRT_MAX_LONG.
334    long guessSquared = guess * guess;
335    // Note (2013-2-26): benchmarks indicate that, inscrutably enough, using if statements is
336    // faster here than using lessThanBranchFree.
337    switch (mode) {
338      case UNNECESSARY:
339        checkRoundingUnnecessary(guessSquared == x);
340        return guess;
341      case FLOOR:
342      case DOWN:
343        if (x < guessSquared) {
344          return guess - 1;
345        }
346        return guess;
347      case CEILING:
348      case UP:
349        if (x > guessSquared) {
350          return guess + 1;
351        }
352        return guess;
353      case HALF_DOWN:
354      case HALF_UP:
355      case HALF_EVEN:
356        long sqrtFloor = guess - ((x < guessSquared) ? 1 : 0);
357        long halfSquare = sqrtFloor * sqrtFloor + sqrtFloor;
358        /*
359         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both x
360         * and halfSquare are integers, this is equivalent to testing whether or not x <=
361         * halfSquare. (We have to deal with overflow, though.)
362         *
363         * If we treat halfSquare as an unsigned long, we know that
364         *            sqrtFloor^2 <= x < (sqrtFloor + 1)^2
365         * halfSquare - sqrtFloor <= x < halfSquare + sqrtFloor + 1
366         * so |x - halfSquare| <= sqrtFloor.  Therefore, it's safe to treat x - halfSquare as a
367         * signed long, so lessThanBranchFree is safe for use.
368         */
369        return sqrtFloor + lessThanBranchFree(halfSquare, x);
370      default:
371        throw new AssertionError();
372    }
373  }
374
375  /**
376   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
377   * {@code RoundingMode}.
378   *
379   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
380   *     is not an integer multiple of {@code b}
381   */
382  @GwtIncompatible // TODO
383  @SuppressWarnings("fallthrough")
384  public static long divide(long p, long q, RoundingMode mode) {
385    checkNotNull(mode);
386    long div = p / q; // throws if q == 0
387    long rem = p - q * div; // equals p % q
388
389    if (rem == 0) {
390      return div;
391    }
392
393    /*
394     * Normal Java division rounds towards 0, consistently with RoundingMode.DOWN. We just have to
395     * deal with the cases where rounding towards 0 is wrong, which typically depends on the sign of
396     * p / q.
397     *
398     * signum is 1 if p and q are both nonnegative or both negative, and -1 otherwise.
399     */
400    int signum = 1 | (int) ((p ^ q) >> (Long.SIZE - 1));
401    boolean increment;
402    switch (mode) {
403      case UNNECESSARY:
404        checkRoundingUnnecessary(rem == 0);
405        // fall through
406      case DOWN:
407        increment = false;
408        break;
409      case UP:
410        increment = true;
411        break;
412      case CEILING:
413        increment = signum > 0;
414        break;
415      case FLOOR:
416        increment = signum < 0;
417        break;
418      case HALF_EVEN:
419      case HALF_DOWN:
420      case HALF_UP:
421        long absRem = abs(rem);
422        long cmpRemToHalfDivisor = absRem - (abs(q) - absRem);
423        // subtracting two nonnegative longs can't overflow
424        // cmpRemToHalfDivisor has the same sign as compare(abs(rem), abs(q) / 2).
425        if (cmpRemToHalfDivisor == 0) { // exactly on the half mark
426          increment = (mode == HALF_UP | (mode == HALF_EVEN & (div & 1) != 0));
427        } else {
428          increment = cmpRemToHalfDivisor > 0; // closer to the UP value
429        }
430        break;
431      default:
432        throw new AssertionError();
433    }
434    return increment ? div + signum : div;
435  }
436
437  /**
438   * Returns {@code x mod m}, a non-negative value less than {@code m}. This differs from
439   * {@code x % m}, which might be negative.
440   *
441   * <p>For example:
442   *
443   * <pre> {@code
444   *
445   * mod(7, 4) == 3
446   * mod(-7, 4) == 1
447   * mod(-1, 4) == 3
448   * mod(-8, 4) == 0
449   * mod(8, 4) == 0}</pre>
450   *
451   * @throws ArithmeticException if {@code m <= 0}
452   * @see <a href="http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.17.3">
453   *     Remainder Operator</a>
454   */
455  @GwtIncompatible // TODO
456  public static int mod(long x, int m) {
457    // Cast is safe because the result is guaranteed in the range [0, m)
458    return (int) mod(x, (long) m);
459  }
460
461  /**
462   * Returns {@code x mod m}, a non-negative value less than {@code m}. This differs from
463   * {@code x % m}, which might be negative.
464   *
465   * <p>For example:
466   *
467   * <pre> {@code
468   *
469   * mod(7, 4) == 3
470   * mod(-7, 4) == 1
471   * mod(-1, 4) == 3
472   * mod(-8, 4) == 0
473   * mod(8, 4) == 0}</pre>
474   *
475   * @throws ArithmeticException if {@code m <= 0}
476   * @see <a href="http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.17.3">
477   *     Remainder Operator</a>
478   */
479  @GwtIncompatible // TODO
480  public static long mod(long x, long m) {
481    if (m <= 0) {
482      throw new ArithmeticException("Modulus must be positive");
483    }
484    long result = x % m;
485    return (result >= 0) ? result : result + m;
486  }
487
488  /**
489   * Returns the greatest common divisor of {@code a, b}. Returns {@code 0} if
490   * {@code a == 0 && b == 0}.
491   *
492   * @throws IllegalArgumentException if {@code a < 0} or {@code b < 0}
493   */
494  public static long gcd(long a, long b) {
495    /*
496     * The reason we require both arguments to be >= 0 is because otherwise, what do you return on
497     * gcd(0, Long.MIN_VALUE)? BigInteger.gcd would return positive 2^63, but positive 2^63 isn't an
498     * int.
499     */
500    checkNonNegative("a", a);
501    checkNonNegative("b", b);
502    if (a == 0) {
503      // 0 % b == 0, so b divides a, but the converse doesn't hold.
504      // BigInteger.gcd is consistent with this decision.
505      return b;
506    } else if (b == 0) {
507      return a; // similar logic
508    }
509    /*
510     * Uses the binary GCD algorithm; see http://en.wikipedia.org/wiki/Binary_GCD_algorithm. This is
511     * >60% faster than the Euclidean algorithm in benchmarks.
512     */
513    int aTwos = Long.numberOfTrailingZeros(a);
514    a >>= aTwos; // divide out all 2s
515    int bTwos = Long.numberOfTrailingZeros(b);
516    b >>= bTwos; // divide out all 2s
517    while (a != b) { // both a, b are odd
518      // The key to the binary GCD algorithm is as follows:
519      // Both a and b are odd. Assume a > b; then gcd(a - b, b) = gcd(a, b).
520      // But in gcd(a - b, b), a - b is even and b is odd, so we can divide out powers of two.
521
522      // We bend over backwards to avoid branching, adapting a technique from
523      // http://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
524
525      long delta = a - b; // can't overflow, since a and b are nonnegative
526
527      long minDeltaOrZero = delta & (delta >> (Long.SIZE - 1));
528      // equivalent to Math.min(delta, 0)
529
530      a = delta - minDeltaOrZero - minDeltaOrZero; // sets a to Math.abs(a - b)
531      // a is now nonnegative and even
532
533      b += minDeltaOrZero; // sets b to min(old a, b)
534      a >>= Long.numberOfTrailingZeros(a); // divide out all 2s, since 2 doesn't divide b
535    }
536    return a << min(aTwos, bTwos);
537  }
538
539  /**
540   * Returns the sum of {@code a} and {@code b}, provided it does not overflow.
541   *
542   * @throws ArithmeticException if {@code a + b} overflows in signed {@code long} arithmetic
543   */
544  @GwtIncompatible // TODO
545  public static long checkedAdd(long a, long b) {
546    long result = a + b;
547    checkNoOverflow((a ^ b) < 0 | (a ^ result) >= 0);
548    return result;
549  }
550
551  /**
552   * Returns the difference of {@code a} and {@code b}, provided it does not overflow.
553   *
554   * @throws ArithmeticException if {@code a - b} overflows in signed {@code long} arithmetic
555   */
556  @GwtIncompatible // TODO
557  public static long checkedSubtract(long a, long b) {
558    long result = a - b;
559    checkNoOverflow((a ^ b) >= 0 | (a ^ result) >= 0);
560    return result;
561  }
562
563  /**
564   * Returns the product of {@code a} and {@code b}, provided it does not overflow.
565   *
566   * @throws ArithmeticException if {@code a * b} overflows in signed {@code long} arithmetic
567   */
568  @GwtIncompatible // TODO
569  public static long checkedMultiply(long a, long b) {
570    // Hacker's Delight, Section 2-12
571    int leadingZeros =
572        Long.numberOfLeadingZeros(a)
573            + Long.numberOfLeadingZeros(~a)
574            + Long.numberOfLeadingZeros(b)
575            + Long.numberOfLeadingZeros(~b);
576    /*
577     * If leadingZeros > Long.SIZE + 1 it's definitely fine, if it's < Long.SIZE it's definitely
578     * bad. We do the leadingZeros check to avoid the division below if at all possible.
579     *
580     * Otherwise, if b == Long.MIN_VALUE, then the only allowed values of a are 0 and 1. We take
581     * care of all a < 0 with their own check, because in particular, the case a == -1 will
582     * incorrectly pass the division check below.
583     *
584     * In all other cases, we check that either a is 0 or the result is consistent with division.
585     */
586    if (leadingZeros > Long.SIZE + 1) {
587      return a * b;
588    }
589    checkNoOverflow(leadingZeros >= Long.SIZE);
590    checkNoOverflow(a >= 0 | b != Long.MIN_VALUE);
591    long result = a * b;
592    checkNoOverflow(a == 0 || result / a == b);
593    return result;
594  }
595
596  /**
597   * Returns the {@code b} to the {@code k}th power, provided it does not overflow.
598   *
599   * @throws ArithmeticException if {@code b} to the {@code k}th power overflows in signed
600   *     {@code long} arithmetic
601   */
602  @GwtIncompatible // TODO
603  public static long checkedPow(long b, int k) {
604    checkNonNegative("exponent", k);
605    if (b >= -2 & b <= 2) {
606      switch ((int) b) {
607        case 0:
608          return (k == 0) ? 1 : 0;
609        case 1:
610          return 1;
611        case (-1):
612          return ((k & 1) == 0) ? 1 : -1;
613        case 2:
614          checkNoOverflow(k < Long.SIZE - 1);
615          return 1L << k;
616        case (-2):
617          checkNoOverflow(k < Long.SIZE);
618          return ((k & 1) == 0) ? (1L << k) : (-1L << k);
619        default:
620          throw new AssertionError();
621      }
622    }
623    long accum = 1;
624    while (true) {
625      switch (k) {
626        case 0:
627          return accum;
628        case 1:
629          return checkedMultiply(accum, b);
630        default:
631          if ((k & 1) != 0) {
632            accum = checkedMultiply(accum, b);
633          }
634          k >>= 1;
635          if (k > 0) {
636            checkNoOverflow(-FLOOR_SQRT_MAX_LONG <= b && b <= FLOOR_SQRT_MAX_LONG);
637            b *= b;
638          }
639      }
640    }
641  }
642
643  /**
644   * Returns the sum of {@code a} and {@code b} unless it would overflow or underflow in which case
645   * {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
646   *
647   * @since 20.0
648   */
649  @Beta
650  public static long saturatedAdd(long a, long b) {
651    long naiveSum = a + b;
652    if ((a ^ b) < 0 | (a ^ naiveSum) >= 0) {
653      // If a and b have different signs or a has the same sign as the result then there was no
654      // overflow, return.
655      return naiveSum;
656    }
657    // we did over/under flow, if the sign is negative we should return MAX otherwise MIN
658    return Long.MAX_VALUE + ((naiveSum >>> (Long.SIZE - 1)) ^ 1);
659  }
660
661  /**
662   * Returns the difference of {@code a} and {@code b} unless it would overflow or underflow in
663   * which case {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
664   *
665   * @since 20.0
666   */
667  @Beta
668  public static long saturatedSubtract(long a, long b) {
669    long naiveDifference = a - b;
670    if ((a ^ b) >= 0 | (a ^ naiveDifference) >= 0) {
671      // If a and b have the same signs or a has the same sign as the result then there was no
672      // overflow, return.
673      return naiveDifference;
674    }
675    // we did over/under flow
676    return Long.MAX_VALUE + ((naiveDifference >>> (Long.SIZE - 1)) ^ 1);
677  }
678
679  /**
680   * Returns the product of {@code a} and {@code b} unless it would overflow or underflow in which
681   * case {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
682   *
683   * @since 20.0
684   */
685  @Beta
686  public static long saturatedMultiply(long a, long b) {
687    // see checkedMultiply for explanation
688    int leadingZeros =
689        Long.numberOfLeadingZeros(a)
690            + Long.numberOfLeadingZeros(~a)
691            + Long.numberOfLeadingZeros(b)
692            + Long.numberOfLeadingZeros(~b);
693    if (leadingZeros > Long.SIZE + 1) {
694      return a * b;
695    }
696    // the return value if we will overflow (which we calculate by overflowing a long :) )
697    long limit = Long.MAX_VALUE + ((a ^ b) >>> (Long.SIZE - 1));
698    if (leadingZeros < Long.SIZE | (a < 0 & b == Long.MIN_VALUE)) {
699      // overflow
700      return limit;
701    }
702    long result = a * b;
703    if (a == 0 || result / a == b) {
704      return result;
705    }
706    return limit;
707  }
708
709  /**
710   * Returns the {@code b} to the {@code k}th power, unless it would overflow or underflow in which
711   * case {@code Long.MAX_VALUE} or {@code Long.MIN_VALUE} is returned, respectively.
712   *
713   * @since 20.0
714   */
715  @Beta
716  public static long saturatedPow(long b, int k) {
717    checkNonNegative("exponent", k);
718    if (b >= -2 & b <= 2) {
719      switch ((int) b) {
720        case 0:
721          return (k == 0) ? 1 : 0;
722        case 1:
723          return 1;
724        case (-1):
725          return ((k & 1) == 0) ? 1 : -1;
726        case 2:
727          if (k >= Long.SIZE - 1) {
728            return Long.MAX_VALUE;
729          }
730          return 1L << k;
731        case (-2):
732          if (k >= Long.SIZE) {
733            return Long.MAX_VALUE + (k & 1);
734          }
735          return ((k & 1) == 0) ? (1L << k) : (-1L << k);
736        default:
737          throw new AssertionError();
738      }
739    }
740    long accum = 1;
741    // if b is negative and k is odd then the limit is MIN otherwise the limit is MAX
742    long limit = Long.MAX_VALUE + ((b >>> Long.SIZE - 1) & (k & 1));
743    while (true) {
744      switch (k) {
745        case 0:
746          return accum;
747        case 1:
748          return saturatedMultiply(accum, b);
749        default:
750          if ((k & 1) != 0) {
751            accum = saturatedMultiply(accum, b);
752          }
753          k >>= 1;
754          if (k > 0) {
755            if (-FLOOR_SQRT_MAX_LONG > b | b > FLOOR_SQRT_MAX_LONG) {
756              return limit;
757            }
758            b *= b;
759          }
760      }
761    }
762  }
763
  /** floor(sqrt(Long.MAX_VALUE)): the largest long whose square does not overflow. */
  @VisibleForTesting static final long FLOOR_SQRT_MAX_LONG = 3037000499L;
765
766  /**
767   * Returns {@code n!}, that is, the product of the first {@code n} positive integers, {@code 1} if
768   * {@code n == 0}, or {@link Long#MAX_VALUE} if the result does not fit in a {@code long}.
769   *
770   * @throws IllegalArgumentException if {@code n < 0}
771   */
772  @GwtIncompatible // TODO
773  public static long factorial(int n) {
774    checkNonNegative("n", n);
775    return (n < factorials.length) ? factorials[n] : Long.MAX_VALUE;
776  }
777
  // factorials[n] == n! for 0 <= n <= 20; 20! is the largest factorial that fits in a long.
  static final long[] factorials = {
    1L,
    1L,
    1L * 2,
    1L * 2 * 3,
    1L * 2 * 3 * 4,
    1L * 2 * 3 * 4 * 5,
    1L * 2 * 3 * 4 * 5 * 6,
    1L * 2 * 3 * 4 * 5 * 6 * 7,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19,
    1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19 * 20
  };
801
802  /**
803   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
804   * {@code k}, or {@link Long#MAX_VALUE} if the result does not fit in a {@code long}.
805   *
806   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
807   */
808  public static long binomial(int n, int k) {
809    checkNonNegative("n", n);
810    checkNonNegative("k", k);
811    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
812    if (k > (n >> 1)) {
813      k = n - k;
814    }
815    switch (k) {
816      case 0:
817        return 1;
818      case 1:
819        return n;
820      default:
821        if (n < factorials.length) {
822          return factorials[n] / (factorials[k] * factorials[n - k]);
823        } else if (k >= biggestBinomials.length || n > biggestBinomials[k]) {
824          return Long.MAX_VALUE;
825        } else if (k < biggestSimpleBinomials.length && n <= biggestSimpleBinomials[k]) {
826          // guaranteed not to overflow
827          long result = n--;
828          for (int i = 2; i <= k; n--, i++) {
829            result *= n;
830            result /= i;
831          }
832          return result;
833        } else {
834          int nBits = LongMath.log2(n, RoundingMode.CEILING);
835
836          long result = 1;
837          long numerator = n--;
838          long denominator = 1;
839
840          int numeratorBits = nBits;
841          // This is an upper bound on log2(numerator, ceiling).
842
843          /*
844           * We want to do this in long math for speed, but want to avoid overflow. We adapt the
845           * technique previously used by BigIntegerMath: maintain separate numerator and
846           * denominator accumulators, multiplying the fraction into result when near overflow.
847           */
848          for (int i = 2; i <= k; i++, n--) {
849            if (numeratorBits + nBits < Long.SIZE - 1) {
850              // It's definitely safe to multiply into numerator and denominator.
851              numerator *= n;
852              denominator *= i;
853              numeratorBits += nBits;
854            } else {
855              // It might not be safe to multiply into numerator and denominator,
856              // so multiply (numerator / denominator) into result.
857              result = multiplyFraction(result, numerator, denominator);
858              numerator = n;
859              denominator = i;
860              numeratorBits = nBits;
861            }
862          }
863          return multiplyFraction(result, numerator, denominator);
864        }
865    }
866  }
867
868  /**
869   * Returns (x * numerator / denominator), which is assumed to come out to an integral value.
870   */
871  static long multiplyFraction(long x, long numerator, long denominator) {
872    if (x == 1) {
873      return numerator / denominator;
874    }
875    long commonDivisor = gcd(x, denominator);
876    x /= commonDivisor;
877    denominator /= commonDivisor;
878    // We know gcd(x, denominator) = 1, and x * numerator / denominator is exact,
879    // so denominator must be a divisor of numerator.
880    return x * (numerator / denominator);
881  }
882
883  /*
884   * binomial(biggestBinomials[k], k) fits in a long, but not binomial(biggestBinomials[k] + 1, k).
885   */
886  static final int[] biggestBinomials = {
887    Integer.MAX_VALUE,
888    Integer.MAX_VALUE,
889    Integer.MAX_VALUE,
890    3810779,
891    121977,
892    16175,
893    4337,
894    1733,
895    887,
896    534,
897    361,
898    265,
899    206,
900    169,
901    143,
902    125,
903    111,
904    101,
905    94,
906    88,
907    83,
908    79,
909    76,
910    74,
911    72,
912    70,
913    69,
914    68,
915    67,
916    67,
917    66,
918    66,
919    66,
920    66
921  };
922
923  /*
924   * binomial(biggestSimpleBinomials[k], k) doesn't need to use the slower GCD-based impl, but
925   * binomial(biggestSimpleBinomials[k] + 1, k) does.
926   */
927  @VisibleForTesting
928  static final int[] biggestSimpleBinomials = {
929    Integer.MAX_VALUE,
930    Integer.MAX_VALUE,
931    Integer.MAX_VALUE,
932    2642246,
933    86251,
934    11724,
935    3218,
936    1313,
937    684,
938    419,
939    287,
940    214,
941    169,
942    139,
943    119,
944    105,
945    95,
946    87,
947    81,
948    76,
949    73,
950    70,
951    68,
952    66,
953    64,
954    63,
955    62,
956    62,
957    61,
958    61,
959    61
960  };
  // The biggestSimpleBinomials values above were generated by using checkedMultiply to find the
  // point at which the simple multiply/divide algorithm would first overflow.
963
964  static boolean fitsInInt(long x) {
965    return (int) x == x;
966  }
967
968  /**
969   * Returns the arithmetic mean of {@code x} and {@code y}, rounded toward negative infinity. This
970   * method is resilient to overflow.
971   *
972   * @since 14.0
973   */
974  public static long mean(long x, long y) {
975    // Efficient method for computing the arithmetic mean.
976    // The alternative (x + y) / 2 fails for large values.
977    // The alternative (x + y) >>> 1 fails for negative values.
978    return (x & y) + ((x ^ y) >> 1);
979  }
980
981  /*
982   * This bitmask is used as an optimization for cheaply testing for divisiblity by 2, 3, or 5.
983   * Each bit is set to 1 for all remainders that indicate divisibility by 2, 3, or 5, so
984   * 1, 7, 11, 13, 17, 19, 23, 29 are set to 0. 30 and up don't matter because they won't be hit.
985   */
986  private static final int SIEVE_30 =
987      ~((1 << 1) | (1 << 7) | (1 << 11) | (1 << 13)
988          | (1 << 17) | (1 << 19) | (1 << 23) | (1 << 29));
989
990  /**
991   * Returns {@code true} if {@code n} is a
992   * <a href="http://mathworld.wolfram.com/PrimeNumber.html">prime number</a>: an integer <i>greater
993   * than one</i> that cannot be factored into a product of <i>smaller</i> positive integers.
994   * Returns {@code false} if {@code n} is zero, one, or a composite number (one which <i>can</i>
995   * be factored into smaller positive integers).
996   *
997   * <p>To test larger numbers, use {@link BigInteger#isProbablePrime}.
998   *
999   * @throws IllegalArgumentException if {@code n} is negative
1000   * @since 20.0
1001   */
1002  @GwtIncompatible // TODO
1003  @Beta
1004  public static boolean isPrime(long n) {
1005    if (n < 2) {
1006      checkNonNegative("n", n);
1007      return false;
1008    }
1009    if (n == 2 || n == 3 || n == 5 || n == 7 || n == 11 || n == 13) {
1010      return true;
1011    }
1012
1013    if ((SIEVE_30 & (1 << (n % 30))) != 0) {
1014      return false;
1015    }
1016    if (n % 7 == 0 || n % 11 == 0 || n % 13 == 0) {
1017      return false;
1018    }
1019    if (n < 17 * 17) {
1020      return true;
1021    }
1022
1023    for (long[] baseSet : millerRabinBaseSets) {
1024      if (n <= baseSet[0]) {
1025        for (int i = 1; i < baseSet.length; i++) {
1026          if (!MillerRabinTester.test(baseSet[i], n)) {
1027            return false;
1028          }
1029        }
1030        return true;
1031      }
1032    }
1033    throw new AssertionError();
1034  }
1035
1036  /*
1037   * If n <= millerRabinBases[i][0], then testing n against bases millerRabinBases[i][1..] suffices
1038   * to prove its primality. Values from miller-rabin.appspot.com.
1039   *
1040   * NOTE: We could get slightly better bases that would be treated as unsigned, but benchmarks
1041   * showed negligible performance improvements.
1042   */
1043  private static final long[][] millerRabinBaseSets = {
1044    {291830, 126401071349994536L},
1045    {885594168, 725270293939359937L, 3569819667048198375L},
1046    {273919523040L, 15, 7363882082L, 992620450144556L},
1047    {47636622961200L, 2, 2570940, 211991001, 3749873356L},
1048    {
1049      7999252175582850L,
1050      2,
1051      4130806001517L,
1052      149795463772692060L,
1053      186635894390467037L,
1054      3967304179347715805L
1055    },
1056    {
1057      585226005592931976L,
1058      2,
1059      123635709730000L,
1060      9233062284813009L,
1061      43835965440333360L,
1062      761179012939631437L,
1063      1263739024124850375L
1064    },
1065    {Long.MAX_VALUE, 2, 325, 9375, 28178, 450775, 9780504, 1795265022}
1066  };
1067
  /**
   * Runs a single Miller-Rabin strong-probable-prime round. The two constants differ only in how
   * they compute modular products. {@link #SMALL} multiplies directly, which is only safe for a
   * small modulus. {@link #LARGE} splits operands into 32-bit halves so that intermediate values
   * stay within an unsigned 64-bit range. {@link #test} picks the right one for a given modulus.
   */
  private enum MillerRabinTester {
    /** Works for inputs ≤ FLOOR_SQRT_MAX_LONG. */
    SMALL {
      @Override
      long mulMod(long a, long b, long m) {
        /*
         * NOTE(lowasser, 2015-Feb-12): Benchmarks suggest that changing this to
         * UnsignedLongs.remainder and increasing the threshold to 2^32 doesn't pay for itself, and
         * adding another enum constant hurts performance further -- I suspect because bimorphic
         * implementation is a sweet spot for the JVM.
         */
        // With 0 <= a, b < m <= FLOOR_SQRT_MAX_LONG the product cannot overflow. This assumes
        // FLOOR_SQRT_MAX_LONG is floor(sqrt(Long.MAX_VALUE)), as its name suggests.
        return (a * b) % m;
      }

      @Override
      long squareMod(long a, long m) {
        // Same overflow argument as mulMod: a < m <= FLOOR_SQRT_MAX_LONG.
        return (a * a) % m;
      }
    },
    /**
     * Works for all nonnegative signed longs.
     */
    LARGE {
      /**
       * Returns (a + b) mod m. Precondition: {@code 0 <= a < m} and {@code 0 <= b < m}, with
       * {@code m < 2^63}. Comparing {@code a} with {@code m - b} (rather than testing
       * {@code a + b >= m}) avoids depending on a sum that may exceed {@code Long.MAX_VALUE}.
       */
      private long plusMod(long a, long b, long m) {
        return (a >= m - b) ? (a + b - m) : (a + b);
      }

      /**
       * Returns (a * 2^32) mod m. a may be any unsigned long.
       *
       * <p>Multiplies by 2 as many times at once as possible without shifting bits out of the
       * 64-bit word, then reduces. After the first reduction a < m (assumed below 2^63), so every
       * later iteration can shift by at least one bit and the loop terminates.
       */
      private long times2ToThe32Mod(long a, long m) {
        int remainingPowersOf2 = 32;
        do {
          int shift = Math.min(remainingPowersOf2, Long.numberOfLeadingZeros(a));
          // shift is either the number of powers of 2 left to multiply a by, or the biggest shift
          // possible while keeping a in an unsigned long.
          a = UnsignedLongs.remainder(a << shift, m);
          remainingPowersOf2 -= shift;
        } while (remainingPowersOf2 > 0);
        return a;
      }

      /*
       * Modular product for moduli up to 2^63 - 1. The inline bounds below assume
       * 0 <= a, b < m < 2^63.
       */
      @Override
      long mulMod(long a, long b, long m) {
        long aHi = a >>> 32; // < 2^31
        long bHi = b >>> 32; // < 2^31
        long aLo = a & 0xFFFFFFFFL; // < 2^32
        long bLo = b & 0xFFFFFFFFL; // < 2^32

        /*
         * a * b == aHi * bHi * 2^64 + (aHi * bLo + aLo * bHi) * 2^32 + aLo * bLo.
         *       == (aHi * bHi * 2^32 + aHi * bLo + aLo * bHi) * 2^32 + aLo * bLo
         *
         * We carry out this computation in modular arithmetic. Since times2ToThe32Mod accepts any
         * unsigned long, we don't have to do a mod on every operation, only when intermediate
         * results can exceed 2^63.
         */
        long result = times2ToThe32Mod(aHi * bHi /* < 2^62 */, m); // < m < 2^63
        result += aHi * bLo; // aHi * bLo < 2^63, result < 2^64
        // A negative value here means the unsigned sum reached 2^63; reduce it before adding more.
        if (result < 0) {
          result = UnsignedLongs.remainder(result, m);
        }
        // result < 2^63 again
        result += aLo * bHi; // aLo * bHi < 2^63, result < 2^64
        result = times2ToThe32Mod(result, m); // result < m < 2^63
        // Both operands of plusMod are now reduced below m, satisfying its precondition.
        return plusMod(
            result,
            UnsignedLongs.remainder(aLo * bLo /* < 2^64 */, m),
            m);
      }

      /*
       * Modular square for moduli up to 2^63 - 1. Assumes 0 <= a < m < 2^63. This is the
       * single-operand specialization of mulMod, where the two cross terms coincide.
       */
      @Override
      long squareMod(long a, long m) {
        long aHi = a >>> 32; // < 2^31
        long aLo = a & 0xFFFFFFFFL; // < 2^32

        /*
         * a^2 == aHi^2 * 2^64 + aHi * aLo * 2^33 + aLo^2
         *     == (aHi^2 * 2^32 + aHi * aLo * 2) * 2^32 + aLo^2
         * We carry out this computation in modular arithmetic.  Since times2ToThe32Mod accepts any
         * unsigned long, we don't have to do a mod on every operation, only when intermediate
         * results can exceed 2^63.
         */
        long result = times2ToThe32Mod(aHi * aHi /* < 2^62 */, m); // < m < 2^63
        // aHi * aLo < 2^63, so doubling it stays below 2^64 as an unsigned value.
        long hiLo = aHi * aLo * 2;
        if (hiLo < 0) {
          hiLo = UnsignedLongs.remainder(hiLo, m);
        }
        // hiLo < 2^63
        result += hiLo; // result < 2^64
        result = times2ToThe32Mod(result, m); // result < m < 2^63
        return plusMod(
            result,
            UnsignedLongs.remainder(aLo * aLo /* < 2^64 */, m),
            m);
      }
    };

    /**
     * Returns true if n is a strong probable prime to the given base. Uses the cheaper SMALL
     * arithmetic whenever n is small enough for it.
     */
    static boolean test(long base, long n) {
      // Since base will be considered % n, it's okay if base > FLOOR_SQRT_MAX_LONG,
      // so long as n <= FLOOR_SQRT_MAX_LONG.
      return ((n <= FLOOR_SQRT_MAX_LONG) ? SMALL : LARGE).testWitness(base, n);
    }

    /**
     * Returns a * b mod m. Callers in this enum pass operands already reduced below m.
     */
    abstract long mulMod(long a, long b, long m);

    /**
     * Returns a^2 mod m. Callers in this enum pass an operand already reduced below m.
     */
    abstract long squareMod(long a, long m);

    /**
     * Returns a^p mod m, by square-and-multiply over the bits of p (least significant first).
     * p must be non-negative: {@code p >>= 1} is an arithmetic shift and would never reach 0 for
     * a negative p.
     */
    private long powMod(long a, long p, long m) {
      long res = 1;
      for (; p != 0; p >>= 1) {
        if ((p & 1) != 0) {
          res = mulMod(res, a, m);
        }
        a = squareMod(a, m);
      }
      return res;
    }

    /**
     * Returns true if n is a strong probable prime relative to the specified base.
     */
    private boolean testWitness(long base, long n) {
      // Write n - 1 = d * 2^r with d odd.
      int r = Long.numberOfTrailingZeros(n - 1);
      long d = (n - 1) >> r;
      base %= n;
      // A base that is a multiple of n carries no information; treat it as a pass.
      if (base == 0) {
        return true;
      }
      // Calculate a := base^d mod n.
      long a = powMod(base, d, n);
      // n passes this test if
      //    base^d = 1 (mod n)
      // or base^(2^j * d) = -1 (mod n) for some 0 <= j < r.
      if (a == 1) {
        return true;
      }
      int j = 0;
      // Square repeatedly looking for -1 (i.e. n - 1); give up after r - 1 squarings.
      while (a != n - 1) {
        if (++j == r) {
          return false;
        }
        a = squareMod(a, n);
      }
      return true;
    }
  }
1225
1226  private LongMath() {}
1227}