001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
005 * in compliance with the License. You may obtain a copy of the License at
006 *
007 * http://www.apache.org/licenses/LICENSE-2.0
008 *
009 * Unless required by applicable law or agreed to in writing, software distributed under the License
010 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
011 * or implied. See the License for the specific language governing permissions and limitations under
012 * the License.
013 */
014
015package com.google.common.math;
016
017import static com.google.common.base.Preconditions.checkArgument;
018import static com.google.common.base.Preconditions.checkNotNull;
019import static com.google.common.math.MathPreconditions.checkNonNegative;
020import static com.google.common.math.MathPreconditions.checkPositive;
021import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
022import static java.math.RoundingMode.CEILING;
023import static java.math.RoundingMode.FLOOR;
024import static java.math.RoundingMode.HALF_DOWN;
025import static java.math.RoundingMode.HALF_EVEN;
026import static java.math.RoundingMode.UNNECESSARY;
027
028import com.google.common.annotations.GwtCompatible;
029import com.google.common.annotations.GwtIncompatible;
030import com.google.common.annotations.J2ktIncompatible;
031import com.google.common.annotations.VisibleForTesting;
032import java.math.BigDecimal;
033import java.math.BigInteger;
034import java.math.RoundingMode;
035import java.util.ArrayList;
036import java.util.List;
037
038/**
039 * A class for arithmetic on values of type {@code BigInteger}.
040 *
041 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
042 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
043 *
044 * <p>Similar functionality for {@code int} and for {@code long} can be found in {@link IntMath} and
045 * {@link LongMath} respectively.
046 *
047 * @author Louis Wasserman
048 * @since 11.0
049 */
050@GwtCompatible(emulated = true)
051@ElementTypesAreNonnullByDefault
052public final class BigIntegerMath {
053  /**
054   * Returns the smallest power of two greater than or equal to {@code x}. This is equivalent to
055   * {@code BigInteger.valueOf(2).pow(log2(x, CEILING))}.
056   *
057   * @throws IllegalArgumentException if {@code x <= 0}
058   * @since 20.0
059   */
060  public static BigInteger ceilingPowerOfTwo(BigInteger x) {
061    return BigInteger.ZERO.setBit(log2(x, CEILING));
062  }
063
064  /**
065   * Returns the largest power of two less than or equal to {@code x}. This is equivalent to {@code
066   * BigInteger.valueOf(2).pow(log2(x, FLOOR))}.
067   *
068   * @throws IllegalArgumentException if {@code x <= 0}
069   * @since 20.0
070   */
071  public static BigInteger floorPowerOfTwo(BigInteger x) {
072    return BigInteger.ZERO.setBit(log2(x, FLOOR));
073  }
074
075  /** Returns {@code true} if {@code x} represents a power of two. */
076  public static boolean isPowerOfTwo(BigInteger x) {
077    checkNotNull(x);
078    return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
079  }
080
081  /**
082   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
083   *
084   * @throws IllegalArgumentException if {@code x <= 0}
085   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
086   *     is not a power of two
087   */
088  @SuppressWarnings("fallthrough")
089  // TODO(kevinb): remove after this warning is disabled globally
090  public static int log2(BigInteger x, RoundingMode mode) {
091    checkPositive("x", checkNotNull(x));
092    int logFloor = x.bitLength() - 1;
093    switch (mode) {
094      case UNNECESSARY:
095        checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
096      case DOWN:
097      case FLOOR:
098        return logFloor;
099
100      case UP:
101      case CEILING:
102        return isPowerOfTwo(x) ? logFloor : logFloor + 1;
103
104      case HALF_DOWN:
105      case HALF_UP:
106      case HALF_EVEN:
107        if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
108          BigInteger halfPower =
109              SQRT2_PRECOMPUTED_BITS.shiftRight(SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
110          if (x.compareTo(halfPower) <= 0) {
111            return logFloor;
112          } else {
113            return logFloor + 1;
114          }
115        }
116        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
117        //
118        // To determine which side of logFloor.5 the logarithm is,
119        // we compare x^2 to 2^(2 * logFloor + 1).
120        BigInteger x2 = x.pow(2);
121        int logX2Floor = x2.bitLength() - 1;
122        return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
123
124      default:
125        throw new AssertionError();
126    }
127  }
128
129  /*
130   * The maximum number of bits in a square root for which we'll precompute an explicit half power
131   * of two. This can be any value, but higher values incur more class load time and linearly
132   * increasing memory consumption.
133   */
134  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
135
136  @VisibleForTesting
137  static final BigInteger SQRT2_PRECOMPUTED_BITS =
138      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
139
140  /**
141   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
142   *
143   * @throws IllegalArgumentException if {@code x <= 0}
144   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
145   *     is not a power of ten
146   */
147  @J2ktIncompatible
148  @GwtIncompatible // TODO
149  @SuppressWarnings("fallthrough")
150  public static int log10(BigInteger x, RoundingMode mode) {
151    checkPositive("x", x);
152    if (fitsInLong(x)) {
153      return LongMath.log10(x.longValue(), mode);
154    }
155
156    int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
157    BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
158    int approxCmp = approxPow.compareTo(x);
159
160    /*
161     * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
162     * 10^floor(log10(x)).
163     */
164
165    if (approxCmp > 0) {
166      /*
167       * The code is written so that even completely incorrect approximations will still yield the
168       * correct answer eventually, but in practice this branch should almost never be entered, and
169       * even then the loop should not run more than once.
170       */
171      do {
172        approxLog10--;
173        approxPow = approxPow.divide(BigInteger.TEN);
174        approxCmp = approxPow.compareTo(x);
175      } while (approxCmp > 0);
176    } else {
177      BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
178      int nextCmp = nextPow.compareTo(x);
179      while (nextCmp <= 0) {
180        approxLog10++;
181        approxPow = nextPow;
182        approxCmp = nextCmp;
183        nextPow = BigInteger.TEN.multiply(approxPow);
184        nextCmp = nextPow.compareTo(x);
185      }
186    }
187
188    int floorLog = approxLog10;
189    BigInteger floorPow = approxPow;
190    int floorCmp = approxCmp;
191
192    switch (mode) {
193      case UNNECESSARY:
194        checkRoundingUnnecessary(floorCmp == 0);
195        // fall through
196      case FLOOR:
197      case DOWN:
198        return floorLog;
199
200      case CEILING:
201      case UP:
202        return floorPow.equals(x) ? floorLog : floorLog + 1;
203
204      case HALF_DOWN:
205      case HALF_UP:
206      case HALF_EVEN:
207        // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
208        BigInteger x2 = x.pow(2);
209        BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
210        return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
211      default:
212        throw new AssertionError();
213    }
214  }
215
216  private static final double LN_10 = Math.log(10);
217  private static final double LN_2 = Math.log(2);
218
219  /**
220   * Returns the square root of {@code x}, rounded with the specified rounding mode.
221   *
222   * @throws IllegalArgumentException if {@code x < 0}
223   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code
224   *     sqrt(x)} is not an integer
225   */
226  @J2ktIncompatible
227  @GwtIncompatible // TODO
228  @SuppressWarnings("fallthrough")
229  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
230    checkNonNegative("x", x);
231    if (fitsInLong(x)) {
232      return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
233    }
234    BigInteger sqrtFloor = sqrtFloor(x);
235    switch (mode) {
236      case UNNECESSARY:
237        checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
238      case FLOOR:
239      case DOWN:
240        return sqrtFloor;
241      case CEILING:
242      case UP:
243        int sqrtFloorInt = sqrtFloor.intValue();
244        boolean sqrtFloorIsExact =
245            (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
246                && sqrtFloor.pow(2).equals(x); // slow exact check
247        return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
248      case HALF_DOWN:
249      case HALF_UP:
250      case HALF_EVEN:
251        BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
252        /*
253         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both x
254         * and halfSquare are integers, this is equivalent to testing whether or not x <=
255         * halfSquare.
256         */
257        return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
258      default:
259        throw new AssertionError();
260    }
261  }
262
263  @J2ktIncompatible
264  @GwtIncompatible // TODO
265  private static BigInteger sqrtFloor(BigInteger x) {
266    /*
267     * Adapted from Hacker's Delight, Figure 11-1.
268     *
269     * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
270     * then we can get a double approximation of the square root. Then, we iteratively improve this
271     * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
272     * This iteration has the following two properties:
273     *
274     * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
275     * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
276     * and the arithmetic mean is always higher than the geometric mean.
277     *
278     * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
279     * with each iteration, so this algorithm takes O(log(digits)) iterations.
280     *
281     * We start out with a double-precision approximation, which may be higher or lower than the
282     * true value. Therefore, we perform at least one Newton iteration to get a guess that's
283     * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
284     */
285    BigInteger sqrt0;
286    int log2 = log2(x, FLOOR);
287    if (log2 < Double.MAX_EXPONENT) {
288      sqrt0 = sqrtApproxWithDoubles(x);
289    } else {
290      int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
291      /*
292       * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
293       * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
294       */
295      sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
296    }
297    BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
298    if (sqrt0.equals(sqrt1)) {
299      return sqrt0;
300    }
301    do {
302      sqrt0 = sqrt1;
303      sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
304    } while (sqrt1.compareTo(sqrt0) < 0);
305    return sqrt0;
306  }
307
308  @J2ktIncompatible
309  @GwtIncompatible // TODO
310  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
311    return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
312  }
313
314  /**
315   * Returns {@code x}, rounded to a {@code double} with the specified rounding mode. If {@code x}
316   * is precisely representable as a {@code double}, its {@code double} value will be returned;
317   * otherwise, the rounding will choose between the two nearest representable values with {@code
318   * mode}.
319   *
320   * <p>For the case of {@link RoundingMode#HALF_DOWN}, {@code HALF_UP}, and {@code HALF_EVEN},
321   * infinite {@code double} values are considered infinitely far away. For example, 2^2000 is not
322   * representable as a double, but {@code roundToDouble(BigInteger.valueOf(2).pow(2000), HALF_UP)}
323   * will return {@code Double.MAX_VALUE}, not {@code Double.POSITIVE_INFINITY}.
324   *
325   * <p>For the case of {@link RoundingMode#HALF_EVEN}, this implementation uses the IEEE 754
326   * default rounding mode: if the two nearest representable values are equally near, the one with
327   * the least significant bit zero is chosen. (In such cases, both of the nearest representable
328   * values are even integers; this method returns the one that is a multiple of a greater power of
329   * two.)
330   *
331   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
332   *     is not precisely representable as a {@code double}
333   * @since 30.0
334   */
335  @J2ktIncompatible
336  @GwtIncompatible
337  public static double roundToDouble(BigInteger x, RoundingMode mode) {
338    return BigIntegerToDoubleRounder.INSTANCE.roundToDouble(x, mode);
339  }
340
341  @J2ktIncompatible
342  @GwtIncompatible
343  private static class BigIntegerToDoubleRounder extends ToDoubleRounder<BigInteger> {
344    static final BigIntegerToDoubleRounder INSTANCE = new BigIntegerToDoubleRounder();
345
346    private BigIntegerToDoubleRounder() {}
347
348    @Override
349    double roundToDoubleArbitrarily(BigInteger bigInteger) {
350      return DoubleUtils.bigToDouble(bigInteger);
351    }
352
353    @Override
354    int sign(BigInteger bigInteger) {
355      return bigInteger.signum();
356    }
357
358    @Override
359    BigInteger toX(double d, RoundingMode mode) {
360      return DoubleMath.roundToBigInteger(d, mode);
361    }
362
363    @Override
364    BigInteger minus(BigInteger a, BigInteger b) {
365      return a.subtract(b);
366    }
367  }
368
369  /**
370   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified {@code
371   * RoundingMode}.
372   *
373   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
374   *     is not an integer multiple of {@code b}
375   */
376  @J2ktIncompatible
377  @GwtIncompatible // TODO
378  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
379    BigDecimal pDec = new BigDecimal(p);
380    BigDecimal qDec = new BigDecimal(q);
381    return pDec.divide(qDec, 0, mode).toBigIntegerExact();
382  }
383
384  /**
385   * Returns {@code n!}, that is, the product of the first {@code n} positive integers, or {@code 1}
386   * if {@code n == 0}.
387   *
388   * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
389   *
390   * <p>This uses an efficient binary recursive algorithm to compute the factorial with balanced
391   * multiplies. It also removes all the 2s from the intermediate products (shifting them back in at
392   * the end).
393   *
394   * @throws IllegalArgumentException if {@code n < 0}
395   */
396  public static BigInteger factorial(int n) {
397    checkNonNegative("n", n);
398
399    // If the factorial is small enough, just use LongMath to do it.
400    if (n < LongMath.factorials.length) {
401      return BigInteger.valueOf(LongMath.factorials[n]);
402    }
403
404    // Pre-allocate space for our list of intermediate BigIntegers.
405    int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
406    ArrayList<BigInteger> bignums = new ArrayList<>(approxSize);
407
408    // Start from the pre-computed maximum long factorial.
409    int startingNumber = LongMath.factorials.length;
410    long product = LongMath.factorials[startingNumber - 1];
411    // Strip off 2s from this value.
412    int shift = Long.numberOfTrailingZeros(product);
413    product >>= shift;
414
415    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
416    int productBits = LongMath.log2(product, FLOOR) + 1;
417    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
418    // Check for the next power of two boundary, to save us a CLZ operation.
419    int nextPowerOfTwo = 1 << (bits - 1);
420
421    // Iteratively multiply the longs as big as they can go.
422    for (long num = startingNumber; num <= n; num++) {
423      // Check to see if the floor(log2(num)) + 1 has changed.
424      if ((num & nextPowerOfTwo) != 0) {
425        nextPowerOfTwo <<= 1;
426        bits++;
427      }
428      // Get rid of the 2s in num.
429      int tz = Long.numberOfTrailingZeros(num);
430      long normalizedNum = num >> tz;
431      shift += tz;
432      // Adjust floor(log2(num)) + 1.
433      int normalizedBits = bits - tz;
434      // If it won't fit in a long, then we store off the intermediate product.
435      if (normalizedBits + productBits >= Long.SIZE) {
436        bignums.add(BigInteger.valueOf(product));
437        product = 1;
438        productBits = 0;
439      }
440      product *= normalizedNum;
441      productBits = LongMath.log2(product, FLOOR) + 1;
442    }
443    // Check for leftovers.
444    if (product > 1) {
445      bignums.add(BigInteger.valueOf(product));
446    }
447    // Efficiently multiply all the intermediate products together.
448    return listProduct(bignums).shiftLeft(shift);
449  }
450
451  static BigInteger listProduct(List<BigInteger> nums) {
452    return listProduct(nums, 0, nums.size());
453  }
454
455  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
456    switch (end - start) {
457      case 0:
458        return BigInteger.ONE;
459      case 1:
460        return nums.get(start);
461      case 2:
462        return nums.get(start).multiply(nums.get(start + 1));
463      case 3:
464        return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
465      default:
466        // Otherwise, split the list in half and recursively do this.
467        int m = (end + start) >>> 1;
468        return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
469    }
470  }
471
472  /**
473   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
474   * {@code k}, that is, {@code n! / (k! (n - k)!)}.
475   *
476   * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
477   *
478   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
479   */
480  public static BigInteger binomial(int n, int k) {
481    checkNonNegative("n", n);
482    checkNonNegative("k", k);
483    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
484    if (k > (n >> 1)) {
485      k = n - k;
486    }
487    if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
488      return BigInteger.valueOf(LongMath.binomial(n, k));
489    }
490
491    BigInteger accum = BigInteger.ONE;
492
493    long numeratorAccum = n;
494    long denominatorAccum = 1;
495
496    int bits = LongMath.log2(n, CEILING);
497
498    int numeratorBits = bits;
499
500    for (int i = 1; i < k; i++) {
501      int p = n - i;
502      int q = i + 1;
503
504      // log2(p) >= bits - 1, because p >= n/2
505
506      if (numeratorBits + bits >= Long.SIZE - 1) {
507        // The numerator is as big as it can get without risking overflow.
508        // Multiply numeratorAccum / denominatorAccum into accum.
509        accum =
510            accum
511                .multiply(BigInteger.valueOf(numeratorAccum))
512                .divide(BigInteger.valueOf(denominatorAccum));
513        numeratorAccum = p;
514        denominatorAccum = q;
515        numeratorBits = bits;
516      } else {
517        // We can definitely multiply into the long accumulators without overflowing them.
518        numeratorAccum *= p;
519        denominatorAccum *= q;
520        numeratorBits += bits;
521      }
522    }
523    return accum
524        .multiply(BigInteger.valueOf(numeratorAccum))
525        .divide(BigInteger.valueOf(denominatorAccum));
526  }
527
528  // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
529  @J2ktIncompatible
530  @GwtIncompatible // TODO
531  static boolean fitsInLong(BigInteger x) {
532    return x.bitLength() <= Long.SIZE - 1;
533  }
534
535  private BigIntegerMath() {}
536}