001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
005 * in compliance with the License. You may obtain a copy of the License at
006 *
007 * http://www.apache.org/licenses/LICENSE-2.0
008 *
009 * Unless required by applicable law or agreed to in writing, software distributed under the License
010 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
011 * or implied. See the License for the specific language governing permissions and limitations under
012 * the License.
013 */
014
015package com.google.common.math;
016
017import static com.google.common.base.Preconditions.checkArgument;
018import static com.google.common.base.Preconditions.checkNotNull;
019import static com.google.common.math.MathPreconditions.checkNonNegative;
020import static com.google.common.math.MathPreconditions.checkPositive;
021import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
022import static java.math.RoundingMode.CEILING;
023import static java.math.RoundingMode.FLOOR;
024import static java.math.RoundingMode.HALF_EVEN;
025
026import com.google.common.annotations.GwtCompatible;
027import com.google.common.annotations.GwtIncompatible;
028import com.google.common.annotations.VisibleForTesting;
029import java.math.BigDecimal;
030import java.math.BigInteger;
031import java.math.RoundingMode;
032import java.util.ArrayList;
033import java.util.List;
034
035/**
036 * A class for arithmetic on values of type {@code BigInteger}.
037 *
038 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
039 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
040 *
041 * <p>Similar functionality for {@code int} and for {@code long} can be found in {@link IntMath} and
042 * {@link LongMath} respectively.
043 *
044 * @author Louis Wasserman
045 * @since 11.0
046 */
047@GwtCompatible(emulated = true)
048@ElementTypesAreNonnullByDefault
049public final class BigIntegerMath {
050  /**
051   * Returns the smallest power of two greater than or equal to {@code x}. This is equivalent to
052   * {@code BigInteger.valueOf(2).pow(log2(x, CEILING))}.
053   *
054   * @throws IllegalArgumentException if {@code x <= 0}
055   * @since 20.0
056   */
057  public static BigInteger ceilingPowerOfTwo(BigInteger x) {
058    return BigInteger.ZERO.setBit(log2(x, CEILING));
059  }
060
061  /**
062   * Returns the largest power of two less than or equal to {@code x}. This is equivalent to {@code
063   * BigInteger.valueOf(2).pow(log2(x, FLOOR))}.
064   *
065   * @throws IllegalArgumentException if {@code x <= 0}
066   * @since 20.0
067   */
068  public static BigInteger floorPowerOfTwo(BigInteger x) {
069    return BigInteger.ZERO.setBit(log2(x, FLOOR));
070  }
071
072  /** Returns {@code true} if {@code x} represents a power of two. */
073  public static boolean isPowerOfTwo(BigInteger x) {
074    checkNotNull(x);
075    return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
076  }
077
078  /**
079   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
080   *
081   * @throws IllegalArgumentException if {@code x <= 0}
082   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
083   *     is not a power of two
084   */
085  @SuppressWarnings("fallthrough")
086  // TODO(kevinb): remove after this warning is disabled globally
087  public static int log2(BigInteger x, RoundingMode mode) {
088    checkPositive("x", checkNotNull(x));
089    int logFloor = x.bitLength() - 1;
090    switch (mode) {
091      case UNNECESSARY:
092        checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
093      case DOWN:
094      case FLOOR:
095        return logFloor;
096
097      case UP:
098      case CEILING:
099        return isPowerOfTwo(x) ? logFloor : logFloor + 1;
100
101      case HALF_DOWN:
102      case HALF_UP:
103      case HALF_EVEN:
104        if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
105          BigInteger halfPower =
106              SQRT2_PRECOMPUTED_BITS.shiftRight(SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
107          if (x.compareTo(halfPower) <= 0) {
108            return logFloor;
109          } else {
110            return logFloor + 1;
111          }
112        }
113        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
114        //
115        // To determine which side of logFloor.5 the logarithm is,
116        // we compare x^2 to 2^(2 * logFloor + 1).
117        BigInteger x2 = x.pow(2);
118        int logX2Floor = x2.bitLength() - 1;
119        return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
120    }
121    throw new AssertionError();
122  }
123
124  /*
125   * The maximum number of bits in a square root for which we'll precompute an explicit half power
126   * of two. This can be any value, but higher values incur more class load time and linearly
127   * increasing memory consumption.
128   */
129  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
130
131  @VisibleForTesting
132  static final BigInteger SQRT2_PRECOMPUTED_BITS =
133      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
134
135  /**
136   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
137   *
138   * @throws IllegalArgumentException if {@code x <= 0}
139   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
140   *     is not a power of ten
141   */
142  @GwtIncompatible // TODO
143  @SuppressWarnings("fallthrough")
144  public static int log10(BigInteger x, RoundingMode mode) {
145    checkPositive("x", x);
146    if (fitsInLong(x)) {
147      return LongMath.log10(x.longValue(), mode);
148    }
149
150    int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
151    BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
152    int approxCmp = approxPow.compareTo(x);
153
154    /*
155     * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
156     * 10^floor(log10(x)).
157     */
158
159    if (approxCmp > 0) {
160      /*
161       * The code is written so that even completely incorrect approximations will still yield the
162       * correct answer eventually, but in practice this branch should almost never be entered, and
163       * even then the loop should not run more than once.
164       */
165      do {
166        approxLog10--;
167        approxPow = approxPow.divide(BigInteger.TEN);
168        approxCmp = approxPow.compareTo(x);
169      } while (approxCmp > 0);
170    } else {
171      BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
172      int nextCmp = nextPow.compareTo(x);
173      while (nextCmp <= 0) {
174        approxLog10++;
175        approxPow = nextPow;
176        approxCmp = nextCmp;
177        nextPow = BigInteger.TEN.multiply(approxPow);
178        nextCmp = nextPow.compareTo(x);
179      }
180    }
181
182    int floorLog = approxLog10;
183    BigInteger floorPow = approxPow;
184    int floorCmp = approxCmp;
185
186    switch (mode) {
187      case UNNECESSARY:
188        checkRoundingUnnecessary(floorCmp == 0);
189        // fall through
190      case FLOOR:
191      case DOWN:
192        return floorLog;
193
194      case CEILING:
195      case UP:
196        return floorPow.equals(x) ? floorLog : floorLog + 1;
197
198      case HALF_DOWN:
199      case HALF_UP:
200      case HALF_EVEN:
201        // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
202        BigInteger x2 = x.pow(2);
203        BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
204        return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
205    }
206    throw new AssertionError();
207  }
208
209  private static final double LN_10 = Math.log(10);
210  private static final double LN_2 = Math.log(2);
211
212  /**
213   * Returns the square root of {@code x}, rounded with the specified rounding mode.
214   *
215   * @throws IllegalArgumentException if {@code x < 0}
216   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code
217   *     sqrt(x)} is not an integer
218   */
219  @GwtIncompatible // TODO
220  @SuppressWarnings("fallthrough")
221  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
222    checkNonNegative("x", x);
223    if (fitsInLong(x)) {
224      return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
225    }
226    BigInteger sqrtFloor = sqrtFloor(x);
227    switch (mode) {
228      case UNNECESSARY:
229        checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
230      case FLOOR:
231      case DOWN:
232        return sqrtFloor;
233      case CEILING:
234      case UP:
235        int sqrtFloorInt = sqrtFloor.intValue();
236        boolean sqrtFloorIsExact =
237            (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
238                && sqrtFloor.pow(2).equals(x); // slow exact check
239        return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
240      case HALF_DOWN:
241      case HALF_UP:
242      case HALF_EVEN:
243        BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
244        /*
245         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both x
246         * and halfSquare are integers, this is equivalent to testing whether or not x <=
247         * halfSquare.
248         */
249        return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
250    }
251    throw new AssertionError();
252  }
253
254  @GwtIncompatible // TODO
255  private static BigInteger sqrtFloor(BigInteger x) {
256    /*
257     * Adapted from Hacker's Delight, Figure 11-1.
258     *
259     * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
260     * then we can get a double approximation of the square root. Then, we iteratively improve this
261     * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
262     * This iteration has the following two properties:
263     *
264     * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
265     * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
266     * and the arithmetic mean is always higher than the geometric mean.
267     *
268     * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
269     * with each iteration, so this algorithm takes O(log(digits)) iterations.
270     *
271     * We start out with a double-precision approximation, which may be higher or lower than the
272     * true value. Therefore, we perform at least one Newton iteration to get a guess that's
273     * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
274     */
275    BigInteger sqrt0;
276    int log2 = log2(x, FLOOR);
277    if (log2 < Double.MAX_EXPONENT) {
278      sqrt0 = sqrtApproxWithDoubles(x);
279    } else {
280      int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
281      /*
282       * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
283       * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
284       */
285      sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
286    }
287    BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
288    if (sqrt0.equals(sqrt1)) {
289      return sqrt0;
290    }
291    do {
292      sqrt0 = sqrt1;
293      sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
294    } while (sqrt1.compareTo(sqrt0) < 0);
295    return sqrt0;
296  }
297
298  @GwtIncompatible // TODO
299  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
300    return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
301  }
302
303  /**
304   * Returns {@code x}, rounded to a {@code double} with the specified rounding mode. If {@code x}
305   * is precisely representable as a {@code double}, its {@code double} value will be returned;
306   * otherwise, the rounding will choose between the two nearest representable values with {@code
307   * mode}.
308   *
309   * <p>For the case of {@link RoundingMode#HALF_DOWN}, {@code HALF_UP}, and {@code HALF_EVEN},
310   * infinite {@code double} values are considered infinitely far away. For example, 2^2000 is not
311   * representable as a double, but {@code roundToDouble(BigInteger.valueOf(2).pow(2000), HALF_UP)}
312   * will return {@code Double.MAX_VALUE}, not {@code Double.POSITIVE_INFINITY}.
313   *
314   * <p>For the case of {@link RoundingMode#HALF_EVEN}, this implementation uses the IEEE 754
315   * default rounding mode: if the two nearest representable values are equally near, the one with
316   * the least significant bit zero is chosen. (In such cases, both of the nearest representable
317   * values are even integers; this method returns the one that is a multiple of a greater power of
318   * two.)
319   *
320   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
321   *     is not precisely representable as a {@code double}
322   * @since 30.0
323   */
324  @GwtIncompatible
325  public static double roundToDouble(BigInteger x, RoundingMode mode) {
326    return BigIntegerToDoubleRounder.INSTANCE.roundToDouble(x, mode);
327  }
328
329  @GwtIncompatible
330  private static class BigIntegerToDoubleRounder extends ToDoubleRounder<BigInteger> {
331    static final BigIntegerToDoubleRounder INSTANCE = new BigIntegerToDoubleRounder();
332
333    private BigIntegerToDoubleRounder() {}
334
335    @Override
336    double roundToDoubleArbitrarily(BigInteger bigInteger) {
337      return DoubleUtils.bigToDouble(bigInteger);
338    }
339
340    @Override
341    int sign(BigInteger bigInteger) {
342      return bigInteger.signum();
343    }
344
345    @Override
346    BigInteger toX(double d, RoundingMode mode) {
347      return DoubleMath.roundToBigInteger(d, mode);
348    }
349
350    @Override
351    BigInteger minus(BigInteger a, BigInteger b) {
352      return a.subtract(b);
353    }
354  }
355
356  /**
357   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified {@code
358   * RoundingMode}.
359   *
360   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
361   *     is not an integer multiple of {@code b}
362   */
363  @GwtIncompatible // TODO
364  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
365    BigDecimal pDec = new BigDecimal(p);
366    BigDecimal qDec = new BigDecimal(q);
367    return pDec.divide(qDec, 0, mode).toBigIntegerExact();
368  }
369
370  /**
371   * Returns {@code n!}, that is, the product of the first {@code n} positive integers, or {@code 1}
372   * if {@code n == 0}.
373   *
374   * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
375   *
376   * <p>This uses an efficient binary recursive algorithm to compute the factorial with balanced
377   * multiplies. It also removes all the 2s from the intermediate products (shifting them back in at
378   * the end).
379   *
380   * @throws IllegalArgumentException if {@code n < 0}
381   */
382  public static BigInteger factorial(int n) {
383    checkNonNegative("n", n);
384
385    // If the factorial is small enough, just use LongMath to do it.
386    if (n < LongMath.factorials.length) {
387      return BigInteger.valueOf(LongMath.factorials[n]);
388    }
389
390    // Pre-allocate space for our list of intermediate BigIntegers.
391    int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
392    ArrayList<BigInteger> bignums = new ArrayList<>(approxSize);
393
394    // Start from the pre-computed maximum long factorial.
395    int startingNumber = LongMath.factorials.length;
396    long product = LongMath.factorials[startingNumber - 1];
397    // Strip off 2s from this value.
398    int shift = Long.numberOfTrailingZeros(product);
399    product >>= shift;
400
401    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
402    int productBits = LongMath.log2(product, FLOOR) + 1;
403    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
404    // Check for the next power of two boundary, to save us a CLZ operation.
405    int nextPowerOfTwo = 1 << (bits - 1);
406
407    // Iteratively multiply the longs as big as they can go.
408    for (long num = startingNumber; num <= n; num++) {
409      // Check to see if the floor(log2(num)) + 1 has changed.
410      if ((num & nextPowerOfTwo) != 0) {
411        nextPowerOfTwo <<= 1;
412        bits++;
413      }
414      // Get rid of the 2s in num.
415      int tz = Long.numberOfTrailingZeros(num);
416      long normalizedNum = num >> tz;
417      shift += tz;
418      // Adjust floor(log2(num)) + 1.
419      int normalizedBits = bits - tz;
420      // If it won't fit in a long, then we store off the intermediate product.
421      if (normalizedBits + productBits >= Long.SIZE) {
422        bignums.add(BigInteger.valueOf(product));
423        product = 1;
424        productBits = 0;
425      }
426      product *= normalizedNum;
427      productBits = LongMath.log2(product, FLOOR) + 1;
428    }
429    // Check for leftovers.
430    if (product > 1) {
431      bignums.add(BigInteger.valueOf(product));
432    }
433    // Efficiently multiply all the intermediate products together.
434    return listProduct(bignums).shiftLeft(shift);
435  }
436
437  static BigInteger listProduct(List<BigInteger> nums) {
438    return listProduct(nums, 0, nums.size());
439  }
440
441  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
442    switch (end - start) {
443      case 0:
444        return BigInteger.ONE;
445      case 1:
446        return nums.get(start);
447      case 2:
448        return nums.get(start).multiply(nums.get(start + 1));
449      case 3:
450        return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
451      default:
452        // Otherwise, split the list in half and recursively do this.
453        int m = (end + start) >>> 1;
454        return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
455    }
456  }
457
458  /**
459   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
460   * {@code k}, that is, {@code n! / (k! (n - k)!)}.
461   *
462   * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
463   *
464   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
465   */
466  public static BigInteger binomial(int n, int k) {
467    checkNonNegative("n", n);
468    checkNonNegative("k", k);
469    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
470    if (k > (n >> 1)) {
471      k = n - k;
472    }
473    if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
474      return BigInteger.valueOf(LongMath.binomial(n, k));
475    }
476
477    BigInteger accum = BigInteger.ONE;
478
479    long numeratorAccum = n;
480    long denominatorAccum = 1;
481
482    int bits = LongMath.log2(n, CEILING);
483
484    int numeratorBits = bits;
485
486    for (int i = 1; i < k; i++) {
487      int p = n - i;
488      int q = i + 1;
489
490      // log2(p) >= bits - 1, because p >= n/2
491
492      if (numeratorBits + bits >= Long.SIZE - 1) {
493        // The numerator is as big as it can get without risking overflow.
494        // Multiply numeratorAccum / denominatorAccum into accum.
495        accum =
496            accum
497                .multiply(BigInteger.valueOf(numeratorAccum))
498                .divide(BigInteger.valueOf(denominatorAccum));
499        numeratorAccum = p;
500        denominatorAccum = q;
501        numeratorBits = bits;
502      } else {
503        // We can definitely multiply into the long accumulators without overflowing them.
504        numeratorAccum *= p;
505        denominatorAccum *= q;
506        numeratorBits += bits;
507      }
508    }
509    return accum
510        .multiply(BigInteger.valueOf(numeratorAccum))
511        .divide(BigInteger.valueOf(denominatorAccum));
512  }
513
514  // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
515  @GwtIncompatible // TODO
516  static boolean fitsInLong(BigInteger x) {
517    return x.bitLength() <= Long.SIZE - 1;
518  }
519
520  private BigIntegerMath() {}
521}