001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
005 * in compliance with the License. You may obtain a copy of the License at
006 *
007 * http://www.apache.org/licenses/LICENSE-2.0
008 *
009 * Unless required by applicable law or agreed to in writing, software distributed under the License
010 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
011 * or implied. See the License for the specific language governing permissions and limitations under
012 * the License.
013 */
014
015package com.google.common.math;
016
017import static com.google.common.base.Preconditions.checkArgument;
018import static com.google.common.base.Preconditions.checkNotNull;
019import static com.google.common.math.MathPreconditions.checkNonNegative;
020import static com.google.common.math.MathPreconditions.checkPositive;
021import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
022import static java.math.RoundingMode.CEILING;
023import static java.math.RoundingMode.FLOOR;
024import static java.math.RoundingMode.HALF_EVEN;
025
026import com.google.common.annotations.Beta;
027import com.google.common.annotations.GwtCompatible;
028import com.google.common.annotations.GwtIncompatible;
029import com.google.common.annotations.VisibleForTesting;
030import java.math.BigDecimal;
031import java.math.BigInteger;
032import java.math.RoundingMode;
033import java.util.ArrayList;
034import java.util.List;
035
036/**
037 * A class for arithmetic on values of type {@code BigInteger}.
038 *
039 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
040 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
041 *
042 * <p>Similar functionality for {@code int} and for {@code long} can be found in {@link IntMath} and
043 * {@link LongMath} respectively.
044 *
045 * @author Louis Wasserman
046 * @since 11.0
047 */
048@GwtCompatible(emulated = true)
049public final class BigIntegerMath {
050  /**
051   * Returns the smallest power of two greater than or equal to {@code x}. This is equivalent to
052   * {@code BigInteger.valueOf(2).pow(log2(x, CEILING))}.
053   *
054   * @throws IllegalArgumentException if {@code x <= 0}
055   * @since 20.0
056   */
057  @Beta
058  public static BigInteger ceilingPowerOfTwo(BigInteger x) {
059    return BigInteger.ZERO.setBit(log2(x, RoundingMode.CEILING));
060  }
061
062  /**
063   * Returns the largest power of two less than or equal to {@code x}. This is equivalent to {@code
064   * BigInteger.valueOf(2).pow(log2(x, FLOOR))}.
065   *
066   * @throws IllegalArgumentException if {@code x <= 0}
067   * @since 20.0
068   */
069  @Beta
070  public static BigInteger floorPowerOfTwo(BigInteger x) {
071    return BigInteger.ZERO.setBit(log2(x, RoundingMode.FLOOR));
072  }
073
074  /** Returns {@code true} if {@code x} represents a power of two. */
075  public static boolean isPowerOfTwo(BigInteger x) {
076    checkNotNull(x);
077    return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
078  }
079
080  /**
081   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
082   *
083   * @throws IllegalArgumentException if {@code x <= 0}
084   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
085   *     is not a power of two
086   */
087  @SuppressWarnings("fallthrough")
088  // TODO(kevinb): remove after this warning is disabled globally
089  public static int log2(BigInteger x, RoundingMode mode) {
090    checkPositive("x", checkNotNull(x));
091    int logFloor = x.bitLength() - 1;
092    switch (mode) {
093      case UNNECESSARY:
094        checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
095      case DOWN:
096      case FLOOR:
097        return logFloor;
098
099      case UP:
100      case CEILING:
101        return isPowerOfTwo(x) ? logFloor : logFloor + 1;
102
103      case HALF_DOWN:
104      case HALF_UP:
105      case HALF_EVEN:
106        if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
107          BigInteger halfPower =
108              SQRT2_PRECOMPUTED_BITS.shiftRight(SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
109          if (x.compareTo(halfPower) <= 0) {
110            return logFloor;
111          } else {
112            return logFloor + 1;
113          }
114        }
115        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
116        //
117        // To determine which side of logFloor.5 the logarithm is,
118        // we compare x^2 to 2^(2 * logFloor + 1).
119        BigInteger x2 = x.pow(2);
120        int logX2Floor = x2.bitLength() - 1;
121        return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
122
123      default:
124        throw new AssertionError();
125    }
126  }
127
128  /*
129   * The maximum number of bits in a square root for which we'll precompute an explicit half power
130   * of two. This can be any value, but higher values incur more class load time and linearly
131   * increasing memory consumption.
132   */
133  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
134
135  @VisibleForTesting
136  static final BigInteger SQRT2_PRECOMPUTED_BITS =
137      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
138
139  /**
140   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
141   *
142   * @throws IllegalArgumentException if {@code x <= 0}
143   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
144   *     is not a power of ten
145   */
146  @GwtIncompatible // TODO
147  @SuppressWarnings("fallthrough")
148  public static int log10(BigInteger x, RoundingMode mode) {
149    checkPositive("x", x);
150    if (fitsInLong(x)) {
151      return LongMath.log10(x.longValue(), mode);
152    }
153
154    int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
155    BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
156    int approxCmp = approxPow.compareTo(x);
157
158    /*
159     * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
160     * 10^floor(log10(x)).
161     */
162
163    if (approxCmp > 0) {
164      /*
165       * The code is written so that even completely incorrect approximations will still yield the
166       * correct answer eventually, but in practice this branch should almost never be entered, and
167       * even then the loop should not run more than once.
168       */
169      do {
170        approxLog10--;
171        approxPow = approxPow.divide(BigInteger.TEN);
172        approxCmp = approxPow.compareTo(x);
173      } while (approxCmp > 0);
174    } else {
175      BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
176      int nextCmp = nextPow.compareTo(x);
177      while (nextCmp <= 0) {
178        approxLog10++;
179        approxPow = nextPow;
180        approxCmp = nextCmp;
181        nextPow = BigInteger.TEN.multiply(approxPow);
182        nextCmp = nextPow.compareTo(x);
183      }
184    }
185
186    int floorLog = approxLog10;
187    BigInteger floorPow = approxPow;
188    int floorCmp = approxCmp;
189
190    switch (mode) {
191      case UNNECESSARY:
192        checkRoundingUnnecessary(floorCmp == 0);
193        // fall through
194      case FLOOR:
195      case DOWN:
196        return floorLog;
197
198      case CEILING:
199      case UP:
200        return floorPow.equals(x) ? floorLog : floorLog + 1;
201
202      case HALF_DOWN:
203      case HALF_UP:
204      case HALF_EVEN:
205        // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
206        BigInteger x2 = x.pow(2);
207        BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
208        return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
209      default:
210        throw new AssertionError();
211    }
212  }
213
214  private static final double LN_10 = Math.log(10);
215  private static final double LN_2 = Math.log(2);
216
217  /**
218   * Returns the square root of {@code x}, rounded with the specified rounding mode.
219   *
220   * @throws IllegalArgumentException if {@code x < 0}
221   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code
222   *     sqrt(x)} is not an integer
223   */
224  @GwtIncompatible // TODO
225  @SuppressWarnings("fallthrough")
226  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
227    checkNonNegative("x", x);
228    if (fitsInLong(x)) {
229      return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
230    }
231    BigInteger sqrtFloor = sqrtFloor(x);
232    switch (mode) {
233      case UNNECESSARY:
234        checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
235      case FLOOR:
236      case DOWN:
237        return sqrtFloor;
238      case CEILING:
239      case UP:
240        int sqrtFloorInt = sqrtFloor.intValue();
241        boolean sqrtFloorIsExact =
242            (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
243                && sqrtFloor.pow(2).equals(x); // slow exact check
244        return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
245      case HALF_DOWN:
246      case HALF_UP:
247      case HALF_EVEN:
248        BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
249        /*
250         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both x
251         * and halfSquare are integers, this is equivalent to testing whether or not x <=
252         * halfSquare.
253         */
254        return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
255      default:
256        throw new AssertionError();
257    }
258  }
259
260  @GwtIncompatible // TODO
261  private static BigInteger sqrtFloor(BigInteger x) {
262    /*
263     * Adapted from Hacker's Delight, Figure 11-1.
264     *
265     * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
266     * then we can get a double approximation of the square root. Then, we iteratively improve this
267     * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
268     * This iteration has the following two properties:
269     *
270     * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
271     * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
272     * and the arithmetic mean is always higher than the geometric mean.
273     *
274     * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
275     * with each iteration, so this algorithm takes O(log(digits)) iterations.
276     *
277     * We start out with a double-precision approximation, which may be higher or lower than the
278     * true value. Therefore, we perform at least one Newton iteration to get a guess that's
279     * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
280     */
281    BigInteger sqrt0;
282    int log2 = log2(x, FLOOR);
283    if (log2 < Double.MAX_EXPONENT) {
284      sqrt0 = sqrtApproxWithDoubles(x);
285    } else {
286      int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
287      /*
288       * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
289       * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
290       */
291      sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
292    }
293    BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
294    if (sqrt0.equals(sqrt1)) {
295      return sqrt0;
296    }
297    do {
298      sqrt0 = sqrt1;
299      sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
300    } while (sqrt1.compareTo(sqrt0) < 0);
301    return sqrt0;
302  }
303
304  @GwtIncompatible // TODO
305  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
306    return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
307  }
308
309  /**
310   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified {@code
311   * RoundingMode}.
312   *
313   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
314   *     is not an integer multiple of {@code b}
315   */
316  @GwtIncompatible // TODO
317  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
318    BigDecimal pDec = new BigDecimal(p);
319    BigDecimal qDec = new BigDecimal(q);
320    return pDec.divide(qDec, 0, mode).toBigIntegerExact();
321  }
322
323  /**
324   * Returns {@code n!}, that is, the product of the first {@code n} positive integers, or {@code 1}
325   * if {@code n == 0}.
326   *
327   * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
328   *
329   * <p>This uses an efficient binary recursive algorithm to compute the factorial with balanced
330   * multiplies. It also removes all the 2s from the intermediate products (shifting them back in at
331   * the end).
332   *
333   * @throws IllegalArgumentException if {@code n < 0}
334   */
335  public static BigInteger factorial(int n) {
336    checkNonNegative("n", n);
337
338    // If the factorial is small enough, just use LongMath to do it.
339    if (n < LongMath.factorials.length) {
340      return BigInteger.valueOf(LongMath.factorials[n]);
341    }
342
343    // Pre-allocate space for our list of intermediate BigIntegers.
344    int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
345    ArrayList<BigInteger> bignums = new ArrayList<>(approxSize);
346
347    // Start from the pre-computed maximum long factorial.
348    int startingNumber = LongMath.factorials.length;
349    long product = LongMath.factorials[startingNumber - 1];
350    // Strip off 2s from this value.
351    int shift = Long.numberOfTrailingZeros(product);
352    product >>= shift;
353
354    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
355    int productBits = LongMath.log2(product, FLOOR) + 1;
356    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
357    // Check for the next power of two boundary, to save us a CLZ operation.
358    int nextPowerOfTwo = 1 << (bits - 1);
359
360    // Iteratively multiply the longs as big as they can go.
361    for (long num = startingNumber; num <= n; num++) {
362      // Check to see if the floor(log2(num)) + 1 has changed.
363      if ((num & nextPowerOfTwo) != 0) {
364        nextPowerOfTwo <<= 1;
365        bits++;
366      }
367      // Get rid of the 2s in num.
368      int tz = Long.numberOfTrailingZeros(num);
369      long normalizedNum = num >> tz;
370      shift += tz;
371      // Adjust floor(log2(num)) + 1.
372      int normalizedBits = bits - tz;
373      // If it won't fit in a long, then we store off the intermediate product.
374      if (normalizedBits + productBits >= Long.SIZE) {
375        bignums.add(BigInteger.valueOf(product));
376        product = 1;
377        productBits = 0;
378      }
379      product *= normalizedNum;
380      productBits = LongMath.log2(product, FLOOR) + 1;
381    }
382    // Check for leftovers.
383    if (product > 1) {
384      bignums.add(BigInteger.valueOf(product));
385    }
386    // Efficiently multiply all the intermediate products together.
387    return listProduct(bignums).shiftLeft(shift);
388  }
389
390  static BigInteger listProduct(List<BigInteger> nums) {
391    return listProduct(nums, 0, nums.size());
392  }
393
394  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
395    switch (end - start) {
396      case 0:
397        return BigInteger.ONE;
398      case 1:
399        return nums.get(start);
400      case 2:
401        return nums.get(start).multiply(nums.get(start + 1));
402      case 3:
403        return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
404      default:
405        // Otherwise, split the list in half and recursively do this.
406        int m = (end + start) >>> 1;
407        return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
408    }
409  }
410
411  /**
412   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
413   * {@code k}, that is, {@code n! / (k! (n - k)!)}.
414   *
415   * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
416   *
417   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
418   */
419  public static BigInteger binomial(int n, int k) {
420    checkNonNegative("n", n);
421    checkNonNegative("k", k);
422    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
423    if (k > (n >> 1)) {
424      k = n - k;
425    }
426    if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
427      return BigInteger.valueOf(LongMath.binomial(n, k));
428    }
429
430    BigInteger accum = BigInteger.ONE;
431
432    long numeratorAccum = n;
433    long denominatorAccum = 1;
434
435    int bits = LongMath.log2(n, RoundingMode.CEILING);
436
437    int numeratorBits = bits;
438
439    for (int i = 1; i < k; i++) {
440      int p = n - i;
441      int q = i + 1;
442
443      // log2(p) >= bits - 1, because p >= n/2
444
445      if (numeratorBits + bits >= Long.SIZE - 1) {
446        // The numerator is as big as it can get without risking overflow.
447        // Multiply numeratorAccum / denominatorAccum into accum.
448        accum =
449            accum
450                .multiply(BigInteger.valueOf(numeratorAccum))
451                .divide(BigInteger.valueOf(denominatorAccum));
452        numeratorAccum = p;
453        denominatorAccum = q;
454        numeratorBits = bits;
455      } else {
456        // We can definitely multiply into the long accumulators without overflowing them.
457        numeratorAccum *= p;
458        denominatorAccum *= q;
459        numeratorBits += bits;
460      }
461    }
462    return accum
463        .multiply(BigInteger.valueOf(numeratorAccum))
464        .divide(BigInteger.valueOf(denominatorAccum));
465  }
466
467  // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
468  @GwtIncompatible // TODO
469  static boolean fitsInLong(BigInteger x) {
470    return x.bitLength() <= Long.SIZE - 1;
471  }
472
473  private BigIntegerMath() {}
474}