001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.math;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.math.MathPreconditions.checkNonNegative;
022import static com.google.common.math.MathPreconditions.checkPositive;
023import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
024import static java.math.RoundingMode.CEILING;
025import static java.math.RoundingMode.FLOOR;
026import static java.math.RoundingMode.HALF_EVEN;
027
028import com.google.common.annotations.GwtCompatible;
029import com.google.common.annotations.GwtIncompatible;
030import com.google.common.annotations.VisibleForTesting;
031
032import java.math.BigDecimal;
033import java.math.BigInteger;
034import java.math.RoundingMode;
035import java.util.ArrayList;
036import java.util.List;
037
038/**
039 * A class for arithmetic on values of type {@code BigInteger}.
040 *
041 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
042 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
043 *
044 * <p>Similar functionality for {@code int} and for {@code long} can be found in
045 * {@link IntMath} and {@link LongMath} respectively.
046 *
047 * @author Louis Wasserman
048 * @since 11.0
049 */
050@GwtCompatible(emulated = true)
051public final class BigIntegerMath {
052  /**
053   * Returns {@code true} if {@code x} represents a power of two.
054   */
055  public static boolean isPowerOfTwo(BigInteger x) {
056    checkNotNull(x);
057    return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
058  }
059
060  /**
061   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
062   *
063   * @throws IllegalArgumentException if {@code x <= 0}
064   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
065   *         is not a power of two
066   */
067  @SuppressWarnings("fallthrough")
068  // TODO(kevinb): remove after this warning is disabled globally
069  public static int log2(BigInteger x, RoundingMode mode) {
070    checkPositive("x", checkNotNull(x));
071    int logFloor = x.bitLength() - 1;
072    switch (mode) {
073      case UNNECESSARY:
074        checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
075      case DOWN:
076      case FLOOR:
077        return logFloor;
078
079      case UP:
080      case CEILING:
081        return isPowerOfTwo(x) ? logFloor : logFloor + 1;
082
083      case HALF_DOWN:
084      case HALF_UP:
085      case HALF_EVEN:
086        if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
087          BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
088              SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
089          if (x.compareTo(halfPower) <= 0) {
090            return logFloor;
091          } else {
092            return logFloor + 1;
093          }
094        }
095        /*
096         * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
097         *
098         * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
099         * logFloor + 1).
100         */
101        BigInteger x2 = x.pow(2);
102        int logX2Floor = x2.bitLength() - 1;
103        return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
104
105      default:
106        throw new AssertionError();
107    }
108  }
109
110  /*
111   * The maximum number of bits in a square root for which we'll precompute an explicit half power
112   * of two. This can be any value, but higher values incur more class load time and linearly
113   * increasing memory consumption.
114   */
115  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
116
117  @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
118      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
119
120  /**
121   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
122   *
123   * @throws IllegalArgumentException if {@code x <= 0}
124   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
125   *         is not a power of ten
126   */
127  @GwtIncompatible("TODO")
128  @SuppressWarnings("fallthrough")
129  public static int log10(BigInteger x, RoundingMode mode) {
130    checkPositive("x", x);
131    if (fitsInLong(x)) {
132      return LongMath.log10(x.longValue(), mode);
133    }
134
135    int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
136    BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
137    int approxCmp = approxPow.compareTo(x);
138
139    /*
140     * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
141     * 10^floor(log10(x)).
142     */
143
144    if (approxCmp > 0) {
145      /*
146       * The code is written so that even completely incorrect approximations will still yield the
147       * correct answer eventually, but in practice this branch should almost never be entered,
148       * and even then the loop should not run more than once.
149       */
150      do {
151        approxLog10--;
152        approxPow = approxPow.divide(BigInteger.TEN);
153        approxCmp = approxPow.compareTo(x);
154      } while (approxCmp > 0);
155    } else {
156      BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
157      int nextCmp = nextPow.compareTo(x);
158      while (nextCmp <= 0) {
159        approxLog10++;
160        approxPow = nextPow;
161        approxCmp = nextCmp;
162        nextPow = BigInteger.TEN.multiply(approxPow);
163        nextCmp = nextPow.compareTo(x);
164      }
165    }
166
167    int floorLog = approxLog10;
168    BigInteger floorPow = approxPow;
169    int floorCmp = approxCmp;
170
171    switch (mode) {
172      case UNNECESSARY:
173        checkRoundingUnnecessary(floorCmp == 0);
174        // fall through
175      case FLOOR:
176      case DOWN:
177        return floorLog;
178
179      case CEILING:
180      case UP:
181        return floorPow.equals(x) ? floorLog : floorLog + 1;
182
183      case HALF_DOWN:
184      case HALF_UP:
185      case HALF_EVEN:
186        // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
187        BigInteger x2 = x.pow(2);
188        BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
189        return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
190      default:
191        throw new AssertionError();
192    }
193  }
194
195  private static final double LN_10 = Math.log(10);
196  private static final double LN_2 = Math.log(2);
197
198  /**
199   * Returns the square root of {@code x}, rounded with the specified rounding mode.
200   *
201   * @throws IllegalArgumentException if {@code x < 0}
202   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
203   *         {@code sqrt(x)} is not an integer
204   */
205  @GwtIncompatible("TODO")
206  @SuppressWarnings("fallthrough")
207  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
208    checkNonNegative("x", x);
209    if (fitsInLong(x)) {
210      return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
211    }
212    BigInteger sqrtFloor = sqrtFloor(x);
213    switch (mode) {
214      case UNNECESSARY:
215        checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
216      case FLOOR:
217      case DOWN:
218        return sqrtFloor;
219      case CEILING:
220      case UP:
221        return sqrtFloor.pow(2).equals(x) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
222      case HALF_DOWN:
223      case HALF_UP:
224      case HALF_EVEN:
225        BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
226        /*
227         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
228         * x and halfSquare are integers, this is equivalent to testing whether or not x <=
229         * halfSquare.
230         */
231        return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
232      default:
233        throw new AssertionError();
234    }
235  }
236
237  @GwtIncompatible("TODO")
238  private static BigInteger sqrtFloor(BigInteger x) {
239    /*
240     * Adapted from Hacker's Delight, Figure 11-1.
241     *
242     * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
243     * then we can get a double approximation of the square root. Then, we iteratively improve this
244     * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
245     * This iteration has the following two properties:
246     *
247     * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
248     * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
249     * and the arithmetic mean is always higher than the geometric mean.
250     *
251     * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
252     * with each iteration, so this algorithm takes O(log(digits)) iterations.
253     *
254     * We start out with a double-precision approximation, which may be higher or lower than the
255     * true value. Therefore, we perform at least one Newton iteration to get a guess that's
256     * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
257     */
258    BigInteger sqrt0;
259    int log2 = log2(x, FLOOR);
260    if (log2 < Double.MAX_EXPONENT) {
261      sqrt0 = sqrtApproxWithDoubles(x);
262    } else {
263      int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
264      /*
265       * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
266       * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
267       */
268      sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
269    }
270    BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
271    if (sqrt0.equals(sqrt1)) {
272      return sqrt0;
273    }
274    do {
275      sqrt0 = sqrt1;
276      sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
277    } while (sqrt1.compareTo(sqrt0) < 0);
278    return sqrt0;
279  }
280
281  @GwtIncompatible("TODO")
282  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
283    return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
284  }
285
286  /**
287   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
288   * {@code RoundingMode}.
289   *
290   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
291   *         is not an integer multiple of {@code b}
292   */
293  @GwtIncompatible("TODO")
294  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
295    BigDecimal pDec = new BigDecimal(p);
296    BigDecimal qDec = new BigDecimal(q);
297    return pDec.divide(qDec, 0, mode).toBigIntegerExact();
298  }
299
300  /**
301   * Returns {@code n!}, that is, the product of the first {@code n} positive
302   * integers, or {@code 1} if {@code n == 0}.
303   *
304   * <p><b>Warning</b>: the result takes <i>O(n log n)</i> space, so use cautiously.
305   *
306   * <p>This uses an efficient binary recursive algorithm to compute the factorial
307   * with balanced multiplies.  It also removes all the 2s from the intermediate
308   * products (shifting them back in at the end).
309   *
310   * @throws IllegalArgumentException if {@code n < 0}
311   */
312  public static BigInteger factorial(int n) {
313    checkNonNegative("n", n);
314
315    // If the factorial is small enough, just use LongMath to do it.
316    if (n < LongMath.factorials.length) {
317      return BigInteger.valueOf(LongMath.factorials[n]);
318    }
319
320    // Pre-allocate space for our list of intermediate BigIntegers.
321    int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
322    ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
323
324    // Start from the pre-computed maximum long factorial.
325    int startingNumber = LongMath.factorials.length;
326    long product = LongMath.factorials[startingNumber - 1];
327    // Strip off 2s from this value.
328    int shift = Long.numberOfTrailingZeros(product);
329    product >>= shift;
330
331    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
332    int productBits = LongMath.log2(product, FLOOR) + 1;
333    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
334    // Check for the next power of two boundary, to save us a CLZ operation.
335    int nextPowerOfTwo = 1 << (bits - 1);
336
337    // Iteratively multiply the longs as big as they can go.
338    for (long num = startingNumber; num <= n; num++) {
339      // Check to see if the floor(log2(num)) + 1 has changed.
340      if ((num & nextPowerOfTwo) != 0) {
341        nextPowerOfTwo <<= 1;
342        bits++;
343      }
344      // Get rid of the 2s in num.
345      int tz = Long.numberOfTrailingZeros(num);
346      long normalizedNum = num >> tz;
347      shift += tz;
348      // Adjust floor(log2(num)) + 1.
349      int normalizedBits = bits - tz;
350      // If it won't fit in a long, then we store off the intermediate product.
351      if (normalizedBits + productBits >= Long.SIZE) {
352        bignums.add(BigInteger.valueOf(product));
353        product = 1;
354        productBits = 0;
355      }
356      product *= normalizedNum;
357      productBits = LongMath.log2(product, FLOOR) + 1;
358    }
359    // Check for leftovers.
360    if (product > 1) {
361      bignums.add(BigInteger.valueOf(product));
362    }
363    // Efficiently multiply all the intermediate products together.
364    return listProduct(bignums).shiftLeft(shift);
365  }
366
367  static BigInteger listProduct(List<BigInteger> nums) {
368    return listProduct(nums, 0, nums.size());
369  }
370
371  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
372    switch (end - start) {
373      case 0:
374        return BigInteger.ONE;
375      case 1:
376        return nums.get(start);
377      case 2:
378        return nums.get(start).multiply(nums.get(start + 1));
379      case 3:
380        return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
381      default:
382        // Otherwise, split the list in half and recursively do this.
383        int m = (end + start) >>> 1;
384        return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
385    }
386  }
387
388 /**
389   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
390   * {@code k}, that is, {@code n! / (k! (n - k)!)}.
391   *
392   * <p><b>Warning</b>: the result can take as much as <i>O(k log n)</i> space.
393   *
394   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
395   */
396  public static BigInteger binomial(int n, int k) {
397    checkNonNegative("n", n);
398    checkNonNegative("k", k);
399    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
400    if (k > (n >> 1)) {
401      k = n - k;
402    }
403    if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
404      return BigInteger.valueOf(LongMath.binomial(n, k));
405    }
406
407    BigInteger accum = BigInteger.ONE;
408
409    long numeratorAccum = n;
410    long denominatorAccum = 1;
411
412    int bits = LongMath.log2(n, RoundingMode.CEILING);
413
414    int numeratorBits = bits;
415
416    for (int i = 1; i < k; i++) {
417      int p = n - i;
418      int q = i + 1;
419
420      // log2(p) >= bits - 1, because p >= n/2
421
422      if (numeratorBits + bits >= Long.SIZE - 1) {
423        // The numerator is as big as it can get without risking overflow.
424        // Multiply numeratorAccum / denominatorAccum into accum.
425        accum = accum
426            .multiply(BigInteger.valueOf(numeratorAccum))
427            .divide(BigInteger.valueOf(denominatorAccum));
428        numeratorAccum = p;
429        denominatorAccum = q;
430        numeratorBits = bits;
431      } else {
432        // We can definitely multiply into the long accumulators without overflowing them.
433        numeratorAccum *= p;
434        denominatorAccum *= q;
435        numeratorBits += bits;
436      }
437    }
438    return accum
439        .multiply(BigInteger.valueOf(numeratorAccum))
440        .divide(BigInteger.valueOf(denominatorAccum));
441  }
442
443  // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
444  @GwtIncompatible("TODO")
445  static boolean fitsInLong(BigInteger x) {
446    return x.bitLength() <= Long.SIZE - 1;
447  }
448
449  private BigIntegerMath() {}
450}