001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.math;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.math.MathPreconditions.checkNonNegative;
022import static com.google.common.math.MathPreconditions.checkPositive;
023import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
024import static java.math.RoundingMode.CEILING;
025import static java.math.RoundingMode.FLOOR;
026import static java.math.RoundingMode.HALF_EVEN;
027
028import com.google.common.annotations.GwtCompatible;
029import com.google.common.annotations.GwtIncompatible;
030import com.google.common.annotations.VisibleForTesting;
031
032import java.math.BigDecimal;
033import java.math.BigInteger;
034import java.math.RoundingMode;
035import java.util.ArrayList;
036import java.util.List;
037
038/**
039 * A class for arithmetic on values of type {@code BigInteger}.
040 *
041 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
042 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
043 *
044 * <p>Similar functionality for {@code int} and for {@code long} can be found in
045 * {@link IntMath} and {@link LongMath} respectively.
046 *
047 * @author Louis Wasserman
048 * @since 11.0
049 */
050@GwtCompatible(emulated = true)
051public final class BigIntegerMath {
052  /**
053   * Returns {@code true} if {@code x} represents a power of two.
054   */
055  public static boolean isPowerOfTwo(BigInteger x) {
056    checkNotNull(x);
057    return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
058  }
059
060  /**
061   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
062   *
063   * @throws IllegalArgumentException if {@code x <= 0}
064   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
065   *         is not a power of two
066   */
067  @SuppressWarnings("fallthrough")
068  // TODO(kevinb): remove after this warning is disabled globally
069  public static int log2(BigInteger x, RoundingMode mode) {
070    checkPositive("x", checkNotNull(x));
071    int logFloor = x.bitLength() - 1;
072    switch (mode) {
073      case UNNECESSARY:
074        checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
075      case DOWN:
076      case FLOOR:
077        return logFloor;
078
079      case UP:
080      case CEILING:
081        return isPowerOfTwo(x) ? logFloor : logFloor + 1;
082
083      case HALF_DOWN:
084      case HALF_UP:
085      case HALF_EVEN:
086        if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
087          BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
088              SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
089          if (x.compareTo(halfPower) <= 0) {
090            return logFloor;
091          } else {
092            return logFloor + 1;
093          }
094        }
095        /*
096         * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
097         *
098         * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
099         * logFloor + 1).
100         */
101        BigInteger x2 = x.pow(2);
102        int logX2Floor = x2.bitLength() - 1;
103        return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
104
105      default:
106        throw new AssertionError();
107    }
108  }
109
110  /*
111   * The maximum number of bits in a square root for which we'll precompute an explicit half power
112   * of two. This can be any value, but higher values incur more class load time and linearly
113   * increasing memory consumption.
114   */
115  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
116
117  @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
118      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
119
120  /**
121   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
122   *
123   * @throws IllegalArgumentException if {@code x <= 0}
124   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
125   *         is not a power of ten
126   */
127  @GwtIncompatible("TODO")
128  @SuppressWarnings("fallthrough")
129  public static int log10(BigInteger x, RoundingMode mode) {
130    checkPositive("x", x);
131    if (fitsInLong(x)) {
132      return LongMath.log10(x.longValue(), mode);
133    }
134
135    int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
136    BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
137    int approxCmp = approxPow.compareTo(x);
138
139    /*
140     * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
141     * 10^floor(log10(x)).
142     */
143
144    if (approxCmp > 0) {
145      /*
146       * The code is written so that even completely incorrect approximations will still yield the
147       * correct answer eventually, but in practice this branch should almost never be entered,
148       * and even then the loop should not run more than once.
149       */
150      do {
151        approxLog10--;
152        approxPow = approxPow.divide(BigInteger.TEN);
153        approxCmp = approxPow.compareTo(x);
154      } while (approxCmp > 0);
155    } else {
156      BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
157      int nextCmp = nextPow.compareTo(x);
158      while (nextCmp <= 0) {
159        approxLog10++;
160        approxPow = nextPow;
161        approxCmp = nextCmp;
162        nextPow = BigInteger.TEN.multiply(approxPow);
163        nextCmp = nextPow.compareTo(x);
164      }
165    }
166
167    int floorLog = approxLog10;
168    BigInteger floorPow = approxPow;
169    int floorCmp = approxCmp;
170
171    switch (mode) {
172      case UNNECESSARY:
173        checkRoundingUnnecessary(floorCmp == 0);
174        // fall through
175      case FLOOR:
176      case DOWN:
177        return floorLog;
178
179      case CEILING:
180      case UP:
181        return floorPow.equals(x) ? floorLog : floorLog + 1;
182
183      case HALF_DOWN:
184      case HALF_UP:
185      case HALF_EVEN:
186        // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
187        BigInteger x2 = x.pow(2);
188        BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
189        return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
190      default:
191        throw new AssertionError();
192    }
193  }
194
195  private static final double LN_10 = Math.log(10);
196  private static final double LN_2 = Math.log(2);
197
198  /**
199   * Returns the square root of {@code x}, rounded with the specified rounding mode.
200   *
201   * @throws IllegalArgumentException if {@code x < 0}
202   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
203   *         {@code sqrt(x)} is not an integer
204   */
205  @GwtIncompatible("TODO")
206  @SuppressWarnings("fallthrough")
207  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
208    checkNonNegative("x", x);
209    if (fitsInLong(x)) {
210      return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
211    }
212    BigInteger sqrtFloor = sqrtFloor(x);
213    switch (mode) {
214      case UNNECESSARY:
215        checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
216      case FLOOR:
217      case DOWN:
218        return sqrtFloor;
219      case CEILING:
220      case UP:
221        int sqrtFloorInt = sqrtFloor.intValue();
222        boolean sqrtFloorIsExact =
223            (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
224            && sqrtFloor.pow(2).equals(x); // slow exact check
225        return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
226      case HALF_DOWN:
227      case HALF_UP:
228      case HALF_EVEN:
229        BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
230        /*
231         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
232         * x and halfSquare are integers, this is equivalent to testing whether or not x <=
233         * halfSquare.
234         */
235        return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
236      default:
237        throw new AssertionError();
238    }
239  }
240
241  @GwtIncompatible("TODO")
242  private static BigInteger sqrtFloor(BigInteger x) {
243    /*
244     * Adapted from Hacker's Delight, Figure 11-1.
245     *
246     * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
247     * then we can get a double approximation of the square root. Then, we iteratively improve this
248     * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
249     * This iteration has the following two properties:
250     *
251     * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
252     * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
253     * and the arithmetic mean is always higher than the geometric mean.
254     *
255     * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
256     * with each iteration, so this algorithm takes O(log(digits)) iterations.
257     *
258     * We start out with a double-precision approximation, which may be higher or lower than the
259     * true value. Therefore, we perform at least one Newton iteration to get a guess that's
260     * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
261     */
262    BigInteger sqrt0;
263    int log2 = log2(x, FLOOR);
264    if (log2 < Double.MAX_EXPONENT) {
265      sqrt0 = sqrtApproxWithDoubles(x);
266    } else {
267      int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
268      /*
269       * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
270       * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
271       */
272      sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
273    }
274    BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
275    if (sqrt0.equals(sqrt1)) {
276      return sqrt0;
277    }
278    do {
279      sqrt0 = sqrt1;
280      sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
281    } while (sqrt1.compareTo(sqrt0) < 0);
282    return sqrt0;
283  }
284
285  @GwtIncompatible("TODO")
286  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
287    return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
288  }
289
290  /**
291   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
292   * {@code RoundingMode}.
293   *
294   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
295   *         is not an integer multiple of {@code b}
296   */
297  @GwtIncompatible("TODO")
298  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
299    BigDecimal pDec = new BigDecimal(p);
300    BigDecimal qDec = new BigDecimal(q);
301    return pDec.divide(qDec, 0, mode).toBigIntegerExact();
302  }
303
304  /**
305   * Returns {@code n!}, that is, the product of the first {@code n} positive
306   * integers, or {@code 1} if {@code n == 0}.
307   *
308   * <p><b>Warning</b>: the result takes <i>O(n log n)</i> space, so use cautiously.
309   *
310   * <p>This uses an efficient binary recursive algorithm to compute the factorial
311   * with balanced multiplies.  It also removes all the 2s from the intermediate
312   * products (shifting them back in at the end).
313   *
314   * @throws IllegalArgumentException if {@code n < 0}
315   */
316  public static BigInteger factorial(int n) {
317    checkNonNegative("n", n);
318
319    // If the factorial is small enough, just use LongMath to do it.
320    if (n < LongMath.factorials.length) {
321      return BigInteger.valueOf(LongMath.factorials[n]);
322    }
323
324    // Pre-allocate space for our list of intermediate BigIntegers.
325    int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
326    ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
327
328    // Start from the pre-computed maximum long factorial.
329    int startingNumber = LongMath.factorials.length;
330    long product = LongMath.factorials[startingNumber - 1];
331    // Strip off 2s from this value.
332    int shift = Long.numberOfTrailingZeros(product);
333    product >>= shift;
334
335    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
336    int productBits = LongMath.log2(product, FLOOR) + 1;
337    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
338    // Check for the next power of two boundary, to save us a CLZ operation.
339    int nextPowerOfTwo = 1 << (bits - 1);
340
341    // Iteratively multiply the longs as big as they can go.
342    for (long num = startingNumber; num <= n; num++) {
343      // Check to see if the floor(log2(num)) + 1 has changed.
344      if ((num & nextPowerOfTwo) != 0) {
345        nextPowerOfTwo <<= 1;
346        bits++;
347      }
348      // Get rid of the 2s in num.
349      int tz = Long.numberOfTrailingZeros(num);
350      long normalizedNum = num >> tz;
351      shift += tz;
352      // Adjust floor(log2(num)) + 1.
353      int normalizedBits = bits - tz;
354      // If it won't fit in a long, then we store off the intermediate product.
355      if (normalizedBits + productBits >= Long.SIZE) {
356        bignums.add(BigInteger.valueOf(product));
357        product = 1;
358        productBits = 0;
359      }
360      product *= normalizedNum;
361      productBits = LongMath.log2(product, FLOOR) + 1;
362    }
363    // Check for leftovers.
364    if (product > 1) {
365      bignums.add(BigInteger.valueOf(product));
366    }
367    // Efficiently multiply all the intermediate products together.
368    return listProduct(bignums).shiftLeft(shift);
369  }
370
371  static BigInteger listProduct(List<BigInteger> nums) {
372    return listProduct(nums, 0, nums.size());
373  }
374
375  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
376    switch (end - start) {
377      case 0:
378        return BigInteger.ONE;
379      case 1:
380        return nums.get(start);
381      case 2:
382        return nums.get(start).multiply(nums.get(start + 1));
383      case 3:
384        return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
385      default:
386        // Otherwise, split the list in half and recursively do this.
387        int m = (end + start) >>> 1;
388        return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
389    }
390  }
391
392 /**
393   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
394   * {@code k}, that is, {@code n! / (k! (n - k)!)}.
395   *
396   * <p><b>Warning</b>: the result can take as much as <i>O(k log n)</i> space.
397   *
398   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
399   */
400  public static BigInteger binomial(int n, int k) {
401    checkNonNegative("n", n);
402    checkNonNegative("k", k);
403    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
404    if (k > (n >> 1)) {
405      k = n - k;
406    }
407    if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
408      return BigInteger.valueOf(LongMath.binomial(n, k));
409    }
410
411    BigInteger accum = BigInteger.ONE;
412
413    long numeratorAccum = n;
414    long denominatorAccum = 1;
415
416    int bits = LongMath.log2(n, RoundingMode.CEILING);
417
418    int numeratorBits = bits;
419
420    for (int i = 1; i < k; i++) {
421      int p = n - i;
422      int q = i + 1;
423
424      // log2(p) >= bits - 1, because p >= n/2
425
426      if (numeratorBits + bits >= Long.SIZE - 1) {
427        // The numerator is as big as it can get without risking overflow.
428        // Multiply numeratorAccum / denominatorAccum into accum.
429        accum = accum
430            .multiply(BigInteger.valueOf(numeratorAccum))
431            .divide(BigInteger.valueOf(denominatorAccum));
432        numeratorAccum = p;
433        denominatorAccum = q;
434        numeratorBits = bits;
435      } else {
436        // We can definitely multiply into the long accumulators without overflowing them.
437        numeratorAccum *= p;
438        denominatorAccum *= q;
439        numeratorBits += bits;
440      }
441    }
442    return accum
443        .multiply(BigInteger.valueOf(numeratorAccum))
444        .divide(BigInteger.valueOf(denominatorAccum));
445  }
446
447  // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
448  @GwtIncompatible("TODO")
449  static boolean fitsInLong(BigInteger x) {
450    return x.bitLength() <= Long.SIZE - 1;
451  }
452
453  private BigIntegerMath() {}
454}