001    /*
002     * Copyright (C) 2011 The Guava Authors
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package com.google.common.math;
018    
019    import static com.google.common.base.Preconditions.checkArgument;
020    import static com.google.common.base.Preconditions.checkNotNull;
021    import static com.google.common.math.MathPreconditions.checkNonNegative;
022    import static com.google.common.math.MathPreconditions.checkPositive;
023    import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
024    import static java.math.RoundingMode.CEILING;
025    import static java.math.RoundingMode.FLOOR;
026    import static java.math.RoundingMode.HALF_EVEN;
027    
028    import com.google.common.annotations.Beta;
029    import com.google.common.annotations.GwtCompatible;
030    import com.google.common.annotations.GwtIncompatible;
031    import com.google.common.annotations.VisibleForTesting;
032    
033    import java.math.BigDecimal;
034    import java.math.BigInteger;
035    import java.math.RoundingMode;
036    import java.util.ArrayList;
037    import java.util.List;
038    
039    /**
040     * A class for arithmetic on values of type {@code BigInteger}.
041     *
042     * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
043     * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
044     *
045     * <p>Similar functionality for {@code int} and for {@code long} can be found in
046     * {@link IntMath} and {@link LongMath} respectively.
047     *
048     * @author Louis Wasserman
049     * @since 11.0
050     */
051    @Beta
052    @GwtCompatible(emulated = true)
053    public final class BigIntegerMath {
054      /**
055       * Returns {@code true} if {@code x} represents a power of two.
056       */
057      public static boolean isPowerOfTwo(BigInteger x) {
058        checkNotNull(x);
059        return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
060      }
061    
062      /**
063       * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
064       *
065       * @throws IllegalArgumentException if {@code x <= 0}
066       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
067       *         is not a power of two
068       */
069      @SuppressWarnings("fallthrough")
070      public static int log2(BigInteger x, RoundingMode mode) {
071        checkPositive("x", checkNotNull(x));
072        int logFloor = x.bitLength() - 1;
073        switch (mode) {
074          case UNNECESSARY:
075            checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
076          case DOWN:
077          case FLOOR:
078            return logFloor;
079    
080          case UP:
081          case CEILING:
082            return isPowerOfTwo(x) ? logFloor : logFloor + 1;
083    
084          case HALF_DOWN:
085          case HALF_UP:
086          case HALF_EVEN:
087            if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
088              BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
089                  SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
090              if (x.compareTo(halfPower) <= 0) {
091                return logFloor;
092              } else {
093                return logFloor + 1;
094              }
095            }
096            /*
097             * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
098             *
099             * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
100             * logFloor + 1).
101             */
102            BigInteger x2 = x.pow(2);
103            int logX2Floor = x2.bitLength() - 1;
104            return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
105    
106          default:
107            throw new AssertionError();
108        }
109      }
110    
111      /*
112       * The maximum number of bits in a square root for which we'll precompute an explicit half power
113       * of two. This can be any value, but higher values incur more class load time and linearly
114       * increasing memory consumption.
115       */
116      @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
117    
118      @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
119          new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
120    
121      /**
122       * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
123       *
124       * @throws IllegalArgumentException if {@code x <= 0}
125       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
126       *         is not a power of ten
127       */
128      @GwtIncompatible("TODO")
129      @SuppressWarnings("fallthrough")
130      public static int log10(BigInteger x, RoundingMode mode) {
131        checkPositive("x", x);
132        if (fitsInLong(x)) {
133          return LongMath.log10(x.longValue(), mode);
134        }
135    
136        int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
137        BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
138        int approxCmp = approxPow.compareTo(x);
139    
140        /*
141         * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
142         * 10^floor(log10(x)).
143         */
144    
145        if (approxCmp > 0) {
146          /*
147           * The code is written so that even completely incorrect approximations will still yield the
148           * correct answer eventually, but in practice this branch should almost never be entered,
149           * and even then the loop should not run more than once.
150           */
151          do {
152            approxLog10--;
153            approxPow = approxPow.divide(BigInteger.TEN);
154            approxCmp = approxPow.compareTo(x);
155          } while (approxCmp > 0);
156        } else {
157          BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
158          int nextCmp = nextPow.compareTo(x);
159          while (nextCmp <= 0) {
160            approxLog10++;
161            approxPow = nextPow;
162            approxCmp = nextCmp;
163            nextPow = BigInteger.TEN.multiply(approxPow);
164            nextCmp = nextPow.compareTo(x);
165          }
166        }
167    
168        int floorLog = approxLog10;
169        BigInteger floorPow = approxPow;
170        int floorCmp = approxCmp;
171    
172        switch (mode) {
173          case UNNECESSARY:
174            checkRoundingUnnecessary(floorCmp == 0);
175            // fall through
176          case FLOOR:
177          case DOWN:
178            return floorLog;
179    
180          case CEILING:
181          case UP:
182            return floorPow.equals(x) ? floorLog : floorLog + 1;
183    
184          case HALF_DOWN:
185          case HALF_UP:
186          case HALF_EVEN:
187            // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
188            BigInteger x2 = x.pow(2);
189            BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
190            return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
191          default:
192            throw new AssertionError();
193        }
194      }
195    
196      private static final double LN_10 = Math.log(10);
197      private static final double LN_2 = Math.log(2);
198    
199      /**
200       * Returns the square root of {@code x}, rounded with the specified rounding mode.
201       *
202       * @throws IllegalArgumentException if {@code x < 0}
203       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
204       *         {@code sqrt(x)} is not an integer
205       */
206      @GwtIncompatible("TODO")
207      @SuppressWarnings("fallthrough")
208      public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
209        checkNonNegative("x", x);
210        if (fitsInLong(x)) {
211          return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
212        }
213        BigInteger sqrtFloor = sqrtFloor(x);
214        switch (mode) {
215          case UNNECESSARY:
216            checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
217          case FLOOR:
218          case DOWN:
219            return sqrtFloor;
220          case CEILING:
221          case UP:
222            return sqrtFloor.pow(2).equals(x) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
223          case HALF_DOWN:
224          case HALF_UP:
225          case HALF_EVEN:
226            BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
227            /*
228             * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
229             * x and halfSquare are integers, this is equivalent to testing whether or not x <=
230             * halfSquare.
231             */
232            return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
233          default:
234            throw new AssertionError();
235        }
236      }
237    
238      @GwtIncompatible("TODO")
239      private static BigInteger sqrtFloor(BigInteger x) {
240        /*
241         * Adapted from Hacker's Delight, Figure 11-1.
242         *
243         * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
244         * then we can get a double approximation of the square root. Then, we iteratively improve this
245         * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
246         * This iteration has the following two properties:
247         *
248         * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
249         * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
250         * and the arithmetic mean is always higher than the geometric mean.
251         *
252         * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
253         * with each iteration, so this algorithm takes O(log(digits)) iterations.
254         *
255         * We start out with a double-precision approximation, which may be higher or lower than the
256         * true value. Therefore, we perform at least one Newton iteration to get a guess that's
257         * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
258         */
259        BigInteger sqrt0;
260        int log2 = log2(x, FLOOR);
261        if(log2 < Double.MAX_EXPONENT) {
262          sqrt0 = sqrtApproxWithDoubles(x);
263        } else {
264          int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
265          /*
266           * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
267           * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
268           */
269          sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
270        }
271        BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
272        if (sqrt0.equals(sqrt1)) {
273          return sqrt0;
274        }
275        do {
276          sqrt0 = sqrt1;
277          sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
278        } while (sqrt1.compareTo(sqrt0) < 0);
279        return sqrt0;
280      }
281    
282      @GwtIncompatible("TODO")
283      private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
284        return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
285      }
286    
287      /**
288       * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
289       * {@code RoundingMode}.
290       *
291       * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
292       *         is not an integer multiple of {@code b}
293       */
294      @GwtIncompatible("TODO")
295      public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode){
296        BigDecimal pDec = new BigDecimal(p);
297        BigDecimal qDec = new BigDecimal(q);
298        return pDec.divide(qDec, 0, mode).toBigIntegerExact();
299      }
300    
301      /**
302       * Returns {@code n!}, that is, the product of the first {@code n} positive
303       * integers, or {@code 1} if {@code n == 0}.
304       *
305       * <p><b>Warning</b>: the result takes <i>O(n log n)</i> space, so use cautiously.
306       *
307       * <p>This uses an efficient binary recursive algorithm to compute the factorial
308       * with balanced multiplies.  It also removes all the 2s from the intermediate
309       * products (shifting them back in at the end).
310       *
311       * @throws IllegalArgumentException if {@code n < 0}
312       */
313      public static BigInteger factorial(int n) {
314        checkNonNegative("n", n);
315    
316        // If the factorial is small enough, just use LongMath to do it.
317        if (n < LongMath.FACTORIALS.length) {
318          return BigInteger.valueOf(LongMath.FACTORIALS[n]);
319        }
320    
321        // Pre-allocate space for our list of intermediate BigIntegers.
322        int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
323        ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
324    
325        // Start from the pre-computed maximum long factorial.
326        int startingNumber = LongMath.FACTORIALS.length;
327        long product = LongMath.FACTORIALS[startingNumber - 1];
328        // Strip off 2s from this value.
329        int shift = Long.numberOfTrailingZeros(product);
330        product >>= shift;
331    
332        // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
333        int productBits = LongMath.log2(product, FLOOR) + 1;
334        int bits = LongMath.log2(startingNumber, FLOOR) + 1;
335        // Check for the next power of two boundary, to save us a CLZ operation.
336        int nextPowerOfTwo = 1 << (bits - 1);
337    
338        // Iteratively multiply the longs as big as they can go.
339        for (long num = startingNumber; num <= n; num++) {
340          // Check to see if the floor(log2(num)) + 1 has changed.
341          if ((num & nextPowerOfTwo) != 0) {
342            nextPowerOfTwo <<= 1;
343            bits++;
344          }
345          // Get rid of the 2s in num.
346          int tz = Long.numberOfTrailingZeros(num);
347          long normalizedNum = num >> tz;
348          shift += tz;
349          // Adjust floor(log2(num)) + 1.
350          int normalizedBits = bits - tz;
351          // If it won't fit in a long, then we store off the intermediate product.
352          if (normalizedBits + productBits >= Long.SIZE) {
353            bignums.add(BigInteger.valueOf(product));
354            product = 1;
355            productBits = 0;
356          }
357          product *= normalizedNum;
358          productBits = LongMath.log2(product, FLOOR) + 1;
359        }
360        // Check for leftovers.
361        if (product > 1) {
362          bignums.add(BigInteger.valueOf(product));
363        }
364        // Efficiently multiply all the intermediate products together.
365        return listProduct(bignums).shiftLeft(shift);
366      }
367    
368      static BigInteger listProduct(List<BigInteger> nums) {
369        return listProduct(nums, 0, nums.size());
370      }
371    
372      static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
373        switch (end - start) {
374          case 0:
375            return BigInteger.ONE;
376          case 1:
377            return nums.get(start);
378          case 2:
379            return nums.get(start).multiply(nums.get(start + 1));
380          case 3:
381            return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
382          default:
383            // Otherwise, split the list in half and recursively do this.
384            int m = (end + start) >>> 1;
385            return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
386        }
387      }
388    
389     /**
390       * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
391       * {@code k}, that is, {@code n! / (k! (n - k)!)}.
392       *
393       * <p><b>Warning</b>: the result can take as much as <i>O(k log n)</i> space.
394       *
395       * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
396       */
397      public static BigInteger binomial(int n, int k) {
398        checkNonNegative("n", n);
399        checkNonNegative("k", k);
400        checkArgument(k <= n, "k (%s) > n (%s)", k, n);
401        if (k > (n >> 1)) {
402          k = n - k;
403        }
404        if (k < LongMath.BIGGEST_BINOMIALS.length && n <= LongMath.BIGGEST_BINOMIALS[k]) {
405          return BigInteger.valueOf(LongMath.binomial(n, k));
406        }
407    
408        BigInteger accum = BigInteger.ONE;
409    
410        long numeratorAccum = n;
411        long denominatorAccum = 1;
412    
413        int bits = LongMath.log2(n, RoundingMode.CEILING);
414    
415        int numeratorBits = bits;
416    
417        for (int i = 1; i < k; i++) {
418          int p = n - i;
419          int q = i + 1;
420    
421          // log2(p) >= bits - 1, because p >= n/2
422    
423          if (numeratorBits + bits >= Long.SIZE - 1) {
424            // The numerator is as big as it can get without risking overflow.
425            // Multiply numeratorAccum / denominatorAccum into accum.
426            accum = accum
427                .multiply(BigInteger.valueOf(numeratorAccum))
428                .divide(BigInteger.valueOf(denominatorAccum));
429            numeratorAccum = p;
430            denominatorAccum = q;
431            numeratorBits = bits;
432          } else {
433            // We can definitely multiply into the long accumulators without overflowing them.
434            numeratorAccum *= p;
435            denominatorAccum *= q;
436            numeratorBits += bits;
437          }
438        }
439        return accum
440            .multiply(BigInteger.valueOf(numeratorAccum))
441            .divide(BigInteger.valueOf(denominatorAccum));
442      }
443    
444      // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
445      @GwtIncompatible("TODO")
446      static boolean fitsInLong(BigInteger x) {
447        return x.bitLength() <= Long.SIZE - 1;
448      }
449    
450      private BigIntegerMath() {}
451    }