001    /*
002     * Copyright (C) 2011 The Guava Authors
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package com.google.common.math;
018    
019    import static com.google.common.base.Preconditions.checkArgument;
020    import static com.google.common.base.Preconditions.checkNotNull;
021    import static com.google.common.math.MathPreconditions.checkNonNegative;
022    import static com.google.common.math.MathPreconditions.checkPositive;
023    import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
024    import static java.math.RoundingMode.CEILING;
025    import static java.math.RoundingMode.FLOOR;
026    import static java.math.RoundingMode.HALF_EVEN;
027    
028    import com.google.common.annotations.Beta;
029    import com.google.common.annotations.VisibleForTesting;
030    
031    import java.math.BigDecimal;
032    import java.math.BigInteger;
033    import java.math.RoundingMode;
034    import java.util.ArrayList;
035    import java.util.List;
036    
037    /**
038     * A class for arithmetic on values of type {@code BigInteger}.
039     *
040     * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
041     * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
042     *
043     * <p>Similar functionality for {@code int} and for {@code long} can be found in
044     * {@link IntMath} and {@link LongMath} respectively.
045     *
046     * @author Louis Wasserman
047     * @since 11.0
048     */
049    @Beta
050    public final class BigIntegerMath {
051      /**
052       * Returns {@code true} if {@code x} represents a power of two.
053       */
054      public static boolean isPowerOfTwo(BigInteger x) {
055        checkNotNull(x);
056        return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
057      }
058    
059      /**
060       * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
061       *
062       * @throws IllegalArgumentException if {@code x <= 0}
063       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
064       *         is not a power of two
065       */
066      @SuppressWarnings("fallthrough")
067      public static int log2(BigInteger x, RoundingMode mode) {
068        checkPositive("x", checkNotNull(x));
069        int logFloor = x.bitLength() - 1;
070        switch (mode) {
071          case UNNECESSARY:
072            checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
073          case DOWN:
074          case FLOOR:
075            return logFloor;
076    
077          case UP:
078          case CEILING:
079            return isPowerOfTwo(x) ? logFloor : logFloor + 1;
080    
081          case HALF_DOWN:
082          case HALF_UP:
083          case HALF_EVEN:
084            if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
085              BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
086                  SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
087              if (x.compareTo(halfPower) <= 0) {
088                return logFloor;
089              } else {
090                return logFloor + 1;
091              }
092            }
093            /*
094             * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
095             *
096             * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
097             * logFloor + 1).
098             */
099            BigInteger x2 = x.pow(2);
100            int logX2Floor = x2.bitLength() - 1;
101            return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
102    
103          default:
104            throw new AssertionError();
105        }
106      }
107    
108      /*
109       * The maximum number of bits in a square root for which we'll precompute an explicit half power
110       * of two. This can be any value, but higher values incur more class load time and linearly
111       * increasing memory consumption.
112       */
113      @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
114    
115      @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
116          new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
117    
118      /**
119       * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
120       *
121       * @throws IllegalArgumentException if {@code x <= 0}
122       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
123       *         is not a power of ten
124       */
125      @SuppressWarnings("fallthrough")
126      public static int log10(BigInteger x, RoundingMode mode) {
127        checkPositive("x", x);
128        if (fitsInLong(x)) {
129          return LongMath.log10(x.longValue(), mode);
130        }
131    
132        // capacity of 10 suffices for all x <= 10^(2^10).
133        List<BigInteger> powersOf10 = new ArrayList<BigInteger>(10);
134        BigInteger powerOf10 = BigInteger.TEN;
135        while (x.compareTo(powerOf10) >= 0) {
136          powersOf10.add(powerOf10);
137          powerOf10 = powerOf10.pow(2);
138        }
139        BigInteger floorPow = BigInteger.ONE;
140        int floorLog = 0;
141        for (int i = powersOf10.size() - 1; i >= 0; i--) {
142          BigInteger powOf10 = powersOf10.get(i);
143          floorLog *= 2;
144          BigInteger tenPow = powOf10.multiply(floorPow);
145          if (x.compareTo(tenPow) >= 0) {
146            floorPow = tenPow;
147            floorLog++;
148          }
149        }
150        switch (mode) {
151          case UNNECESSARY:
152            checkRoundingUnnecessary(floorPow.equals(x));
153            // fall through
154          case FLOOR:
155          case DOWN:
156            return floorLog;
157    
158          case CEILING:
159          case UP:
160            return floorPow.equals(x) ? floorLog : floorLog + 1;
161    
162          case HALF_DOWN:
163          case HALF_UP:
164          case HALF_EVEN:
165            // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
166            BigInteger x2 = x.pow(2);
167            BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
168            return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
169          default:
170            throw new AssertionError();
171        }
172      }
173    
174      /**
175       * Returns the square root of {@code x}, rounded with the specified rounding mode.
176       *
177       * @throws IllegalArgumentException if {@code x < 0}
178       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
179       *         {@code sqrt(x)} is not an integer
180       */
181      @SuppressWarnings("fallthrough")
182      public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
183        checkNonNegative("x", x);
184        if (fitsInLong(x)) {
185          return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
186        }
187        BigInteger sqrtFloor = sqrtFloor(x);
188        switch (mode) {
189          case UNNECESSARY:
190            checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
191          case FLOOR:
192          case DOWN:
193            return sqrtFloor;
194          case CEILING:
195          case UP:
196            return sqrtFloor.pow(2).equals(x) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
197          case HALF_DOWN:
198          case HALF_UP:
199          case HALF_EVEN:
200            BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
201            /*
202             * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
203             * x and halfSquare are integers, this is equivalent to testing whether or not x <=
204             * halfSquare.
205             */
206            return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
207          default:
208            throw new AssertionError();
209        }
210      }
211    
212      private static BigInteger sqrtFloor(BigInteger x) {
213        /*
214         * Adapted from Hacker's Delight, Figure 11-1.
215         *
216         * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
217         * then we can get a double approximation of the square root. Then, we iteratively improve this
218         * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
219         * This iteration has the following two properties:
220         *
221         * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
222         * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
223         * and the arithmetic mean is always higher than the geometric mean.
224         *
225         * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
226         * with each iteration, so this algorithm takes O(log(digits)) iterations.
227         *
228         * We start out with a double-precision approximation, which may be higher or lower than the
229         * true value. Therefore, we perform at least one Newton iteration to get a guess that's
230         * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
231         */
232        BigInteger sqrt0;
233        int log2 = log2(x, FLOOR);
234        if(log2 < DoubleUtils.MAX_DOUBLE_EXPONENT) {
235          sqrt0 = sqrtApproxWithDoubles(x);
236        } else {
237          int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
238          /*
239           * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
240           * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
241           */
242          sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
243        }
244        BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
245        if (sqrt0.equals(sqrt1)) {
246          return sqrt0;
247        }
248        do {
249          sqrt0 = sqrt1;
250          sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
251        } while (sqrt1.compareTo(sqrt0) < 0);
252        return sqrt0;
253      }
254    
255      private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
256        return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
257      }
258    
259      /**
260       * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
261       * {@code RoundingMode}.
262       *
263       * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
264       *         is not an integer multiple of {@code b}
265       */
266      public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode){
267        BigDecimal pDec = new BigDecimal(p);
268        BigDecimal qDec = new BigDecimal(q);
269        return pDec.divide(qDec, 0, mode).toBigIntegerExact();
270      }
271    
272      /**
273       * Returns {@code n!}, that is, the product of the first {@code n} positive
274       * integers, or {@code 1} if {@code n == 0}.
275       *
276       * <p><b>Warning</b>: the result takes <i>O(n log n)</i> space, so use cautiously.
277       *
278       * <p>This uses an efficient binary recursive algorithm to compute the factorial
279       * with balanced multiplies.  It also removes all the 2s from the intermediate
280       * products (shifting them back in at the end).
281       *
282       * @throws IllegalArgumentException if {@code n < 0}
283       */
284      public static BigInteger factorial(int n) {
285        checkNonNegative("n", n);
286    
287        // If the factorial is small enough, just use LongMath to do it.
288        if (n < LongMath.FACTORIALS.length) {
289          return BigInteger.valueOf(LongMath.FACTORIALS[n]);
290        }
291    
292        // Pre-allocate space for our list of intermediate BigIntegers.
293        int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
294        ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
295    
296        // Start from the pre-computed maximum long factorial.
297        int startingNumber = LongMath.FACTORIALS.length;
298        long product = LongMath.FACTORIALS[startingNumber - 1];
299        // Strip off 2s from this value.
300        int shift = Long.numberOfTrailingZeros(product);
301        product >>= shift;
302    
303        // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
304        int productBits = LongMath.log2(product, FLOOR) + 1;
305        int bits = LongMath.log2(startingNumber, FLOOR) + 1;
306        // Check for the next power of two boundary, to save us a CLZ operation.
307        int nextPowerOfTwo = 1 << (bits - 1);
308    
309        // Iteratively multiply the longs as big as they can go.
310        for (long num = startingNumber; num <= n; num++) {
311          // Check to see if the floor(log2(num)) + 1 has changed.
312          if ((num & nextPowerOfTwo) != 0) {
313            nextPowerOfTwo <<= 1;
314            bits++;
315          }
316          // Get rid of the 2s in num.
317          int tz = Long.numberOfTrailingZeros(num);
318          long normalizedNum = num >> tz;
319          shift += tz;
320          // Adjust floor(log2(num)) + 1.
321          int normalizedBits = bits - tz;
322          // If it won't fit in a long, then we store off the intermediate product.
323          if (normalizedBits + productBits >= Long.SIZE) {
324            bignums.add(BigInteger.valueOf(product));
325            product = 1;
326            productBits = 0;
327          }
328          product *= normalizedNum;
329          productBits = LongMath.log2(product, FLOOR) + 1;
330        }
331        // Check for leftovers.
332        if (product > 1) {
333          bignums.add(BigInteger.valueOf(product));
334        }
335        // Efficiently multiply all the intermediate products together.
336        return listProduct(bignums).shiftLeft(shift);
337      }
338    
339      static BigInteger listProduct(List<BigInteger> nums) {
340        return listProduct(nums, 0, nums.size());
341      }
342    
343      static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
344        switch (end - start) {
345          case 0:
346            return BigInteger.ONE;
347          case 1:
348            return nums.get(start);
349          case 2:
350            return nums.get(start).multiply(nums.get(start + 1));
351          case 3:
352            return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
353          default:
354            // Otherwise, split the list in half and recursively do this.
355            int m = (end + start) >>> 1;
356            return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
357        }
358      }
359    
360     /**
361       * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
362       * {@code k}, that is, {@code n! / (k! (n - k)!)}.
363       *
364       * <p><b>Warning</b>: the result can take as much as <i>O(k log n)</i> space.
365       *
366       * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
367       */
368      public static BigInteger binomial(int n, int k) {
369        checkNonNegative("n", n);
370        checkNonNegative("k", k);
371        checkArgument(k <= n, "k (%s) > n (%s)", k, n);
372        if (k > (n >> 1)) {
373          k = n - k;
374        }
375        if (k < LongMath.BIGGEST_BINOMIALS.length && n <= LongMath.BIGGEST_BINOMIALS[k]) {
376          return BigInteger.valueOf(LongMath.binomial(n, k));
377        }
378        BigInteger result = BigInteger.ONE;
379        for (int i = 0; i < k; i++) {
380          result = result.multiply(BigInteger.valueOf(n - i));
381          result = result.divide(BigInteger.valueOf(i + 1));
382        }
383        return result;
384      }
385    
386      // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
387      static boolean fitsInLong(BigInteger x) {
388        return x.bitLength() <= Long.SIZE - 1;
389      }
390    
391      private BigIntegerMath() {}
392    }