001    /*
002     * Copyright (C) 2011 The Guava Authors
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package com.google.common.math;
018    
019    import static com.google.common.base.Preconditions.checkArgument;
020    import static com.google.common.base.Preconditions.checkNotNull;
021    import static com.google.common.math.MathPreconditions.checkNoOverflow;
022    import static com.google.common.math.MathPreconditions.checkNonNegative;
023    import static com.google.common.math.MathPreconditions.checkPositive;
024    import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
025    import static java.lang.Math.abs;
026    import static java.lang.Math.min;
027    import static java.math.RoundingMode.HALF_EVEN;
028    import static java.math.RoundingMode.HALF_UP;
029    
030    import com.google.common.annotations.Beta;
031    import com.google.common.annotations.VisibleForTesting;
032    
033    import java.math.BigInteger;
034    import java.math.RoundingMode;
035    
036    /**
037     * A class for arithmetic on values of type {@code long}. Where possible, methods are defined and
038     * named analogously to their {@code BigInteger} counterparts.
039     *
040     * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
041     * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
042     *
043     * <p>Similar functionality for {@code int} and for {@link BigInteger} can be found in
044     * {@link IntMath} and {@link BigIntegerMath} respectively.  For other common operations on
045     * {@code long} values, see {@link com.google.common.primitives.Longs}.
046     *
047     * @author Louis Wasserman
048     * @since 11.0
049     */
050    @Beta
051    public final class LongMath {
052      // NOTE: Whenever both tests are cheap and functional, it's faster to use &, | instead of &&, ||
053    
054      /**
055       * Returns {@code true} if {@code x} represents a power of two.
056       *
057       * <p>This differs from {@code Long.bitCount(x) == 1}, because
058       * {@code Long.bitCount(Long.MIN_VALUE) == 1}, but {@link Long#MIN_VALUE} is not a power of two.
059       */
060      public static boolean isPowerOfTwo(long x) {
061        return x > 0 & (x & (x - 1)) == 0;
062      }
063    
064      /**
065       * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
066       *
067       * @throws IllegalArgumentException if {@code x <= 0}
068       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
069       *         is not a power of two
070       */
071      @SuppressWarnings("fallthrough")
072      public static int log2(long x, RoundingMode mode) {
073        checkPositive("x", x);
074        switch (mode) {
075          case UNNECESSARY:
076            checkRoundingUnnecessary(isPowerOfTwo(x));
077            // fall through
078          case DOWN:
079          case FLOOR:
080            return (Long.SIZE - 1) - Long.numberOfLeadingZeros(x);
081    
082          case UP:
083          case CEILING:
084            return Long.SIZE - Long.numberOfLeadingZeros(x - 1);
085    
086          case HALF_DOWN:
087          case HALF_UP:
088          case HALF_EVEN:
089            // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
090            int leadingZeros = Long.numberOfLeadingZeros(x);
091            long cmp = MAX_POWER_OF_SQRT2_UNSIGNED >>> leadingZeros;
092            // floor(2^(logFloor + 0.5))
093            int logFloor = (Long.SIZE - 1) - leadingZeros;
094            return (x <= cmp) ? logFloor : logFloor + 1;
095    
096          default:
097            throw new AssertionError("impossible");
098        }
099      }
100    
101      /** The biggest half power of two that fits into an unsigned long */
102      @VisibleForTesting static final long MAX_POWER_OF_SQRT2_UNSIGNED = 0xB504F333F9DE6484L;
103    
104      /**
105       * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
106       *
107       * @throws IllegalArgumentException if {@code x <= 0}
108       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
109       *         is not a power of ten
110       */
111      @SuppressWarnings("fallthrough")
112      public static int log10(long x, RoundingMode mode) {
113        checkPositive("x", x);
114        if (fitsInInt(x)) {
115          return IntMath.log10((int) x, mode);
116        }
117        int logFloor = log10Floor(x);
118        long floorPow = POWERS_OF_10[logFloor];
119        switch (mode) {
120          case UNNECESSARY:
121            checkRoundingUnnecessary(x == floorPow);
122            // fall through
123          case FLOOR:
124          case DOWN:
125            return logFloor;
126          case CEILING:
127          case UP:
128            return (x == floorPow) ? logFloor : logFloor + 1;
129          case HALF_DOWN:
130          case HALF_UP:
131          case HALF_EVEN:
132            // sqrt(10) is irrational, so log10(x)-logFloor is never exactly 0.5
133            return (x <= HALF_POWERS_OF_10[logFloor]) ? logFloor : logFloor + 1;
134          default:
135            throw new AssertionError();
136        }
137      }
138    
139      static int log10Floor(long x) {
140        for (int i = 1; i < POWERS_OF_10.length; i++) {
141          if (x < POWERS_OF_10[i]) {
142            return i - 1;
143          }
144        }
145        return POWERS_OF_10.length - 1;
146      }
147    
148      @VisibleForTesting
149      static final long[] POWERS_OF_10 = {
150        1L,
151        10L,
152        100L,
153        1000L,
154        10000L,
155        100000L,
156        1000000L,
157        10000000L,
158        100000000L,
159        1000000000L,
160        10000000000L,
161        100000000000L,
162        1000000000000L,
163        10000000000000L,
164        100000000000000L,
165        1000000000000000L,
166        10000000000000000L,
167        100000000000000000L,
168        1000000000000000000L
169      };
170    
171      // HALF_POWERS_OF_10[i] = largest long less than 10^(i + 0.5)
172      @VisibleForTesting
173      static final long[] HALF_POWERS_OF_10 = {
174        3L,
175        31L,
176        316L,
177        3162L,
178        31622L,
179        316227L,
180        3162277L,
181        31622776L,
182        316227766L,
183        3162277660L,
184        31622776601L,
185        316227766016L,
186        3162277660168L,
187        31622776601683L,
188        316227766016837L,
189        3162277660168379L,
190        31622776601683793L,
191        316227766016837933L,
192        3162277660168379331L
193      };
194    
195      /**
196       * Returns {@code b} to the {@code k}th power. Even if the result overflows, it will be equal to
197       * {@code BigInteger.valueOf(b).pow(k).longValue()}. This implementation runs in {@code O(log k)}
198       * time.
199       *
200       * @throws IllegalArgumentException if {@code k < 0}
201       */
202      public static long pow(long b, int k) {
203        checkNonNegative("exponent", k);
204        if (-2 <= b && b <= 2) {
205          switch ((int) b) {
206            case 0:
207              return (k == 0) ? 1 : 0;
208            case 1:
209              return 1;
210            case (-1):
211              return ((k & 1) == 0) ? 1 : -1;
212            case 2:
213              return (k < Long.SIZE) ? 1L << k : 0;
214            case (-2):
215              if (k < Long.SIZE) {
216                return ((k & 1) == 0) ? 1L << k : -(1L << k);
217              } else {
218                return 0;
219              }
220          }
221        }
222        for (long accum = 1;; k >>= 1) {
223          switch (k) {
224            case 0:
225              return accum;
226            case 1:
227              return accum * b;
228            default:
229              accum *= ((k & 1) == 0) ? 1 : b;
230              b *= b;
231          }
232        }
233      }
234    
235      /**
236       * Returns the square root of {@code x}, rounded with the specified rounding mode.
237       *
238       * @throws IllegalArgumentException if {@code x < 0}
239       * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
240       *         {@code sqrt(x)} is not an integer
241       */
242      @SuppressWarnings("fallthrough")
243      public static long sqrt(long x, RoundingMode mode) {
244        checkNonNegative("x", x);
245        if (fitsInInt(x)) {
246          return IntMath.sqrt((int) x, mode);
247        }
248        long sqrtFloor = sqrtFloor(x);
249        switch (mode) {
250          case UNNECESSARY:
251            checkRoundingUnnecessary(sqrtFloor * sqrtFloor == x); // fall through
252          case FLOOR:
253          case DOWN:
254            return sqrtFloor;
255          case CEILING:
256          case UP:
257            return (sqrtFloor * sqrtFloor == x) ? sqrtFloor : sqrtFloor + 1;
258          case HALF_DOWN:
259          case HALF_UP:
260          case HALF_EVEN:
261            long halfSquare = sqrtFloor * sqrtFloor + sqrtFloor;
262            /*
263             * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
264             * x and halfSquare are integers, this is equivalent to testing whether or not x <=
265             * halfSquare. (We have to deal with overflow, though.)
266             */
267            return (halfSquare >= x | halfSquare < 0) ? sqrtFloor : sqrtFloor + 1;
268          default:
269            throw new AssertionError();
270        }
271      }
272    
273      private static long sqrtFloor(long x) {
274        // Hackers's Delight, Figure 11-1
275        long sqrt0 = (long) Math.sqrt(x);
276        // Precision can be lost in the cast to double, so we use this as a starting estimate.
277        long sqrt1 = (sqrt0 + (x / sqrt0)) >> 1;
278        if (sqrt1 == sqrt0) {
279          return sqrt0;
280        }
281        do {
282          sqrt0 = sqrt1;
283          sqrt1 = (sqrt0 + (x / sqrt0)) >> 1;
284        } while (sqrt1 < sqrt0);
285        return sqrt0;
286      }
287    
288      /**
289       * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
290       * {@code RoundingMode}.
291       *
292       * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
293       *         is not an integer multiple of {@code b}
294       */
295      @SuppressWarnings("fallthrough")
296      public static long divide(long p, long q, RoundingMode mode) {
297        checkNotNull(mode);
298        long div = p / q; // throws if q == 0
299        long rem = p - q * div; // equals p % q
300    
301        if (rem == 0) {
302          return div;
303        }
304    
305        /*
306         * Normal Java division rounds towards 0, consistently with RoundingMode.DOWN. We just have to
307         * deal with the cases where rounding towards 0 is wrong, which typically depends on the sign of
308         * p / q.
309         *
310         * signum is 1 if p and q are both nonnegative or both negative, and -1 otherwise.
311         */
312        int signum = 1 | (int) ((p ^ q) >> (Long.SIZE - 1));
313        boolean increment;
314        switch (mode) {
315          case UNNECESSARY:
316            checkRoundingUnnecessary(rem == 0);
317            // fall through
318          case DOWN:
319            increment = false;
320            break;
321          case UP:
322            increment = true;
323            break;
324          case CEILING:
325            increment = signum > 0;
326            break;
327          case FLOOR:
328            increment = signum < 0;
329            break;
330          case HALF_EVEN:
331          case HALF_DOWN:
332          case HALF_UP:
333            long absRem = abs(rem);
334            long cmpRemToHalfDivisor = absRem - (abs(q) - absRem);
335            // subtracting two nonnegative longs can't overflow
336            // cmpRemToHalfDivisor has the same sign as compare(abs(rem), abs(q) / 2).
337            if (cmpRemToHalfDivisor == 0) { // exactly on the half mark
338              increment = (mode == HALF_UP | (mode == HALF_EVEN & (div & 1) != 0));
339            } else {
340              increment = cmpRemToHalfDivisor > 0; // closer to the UP value
341            }
342            break;
343          default:
344            throw new AssertionError();
345        }
346        return increment ? div + signum : div;
347      }
348    
349      /**
350       * Returns {@code x mod m}. This differs from {@code x % m} in that it always returns a
351       * non-negative result.
352       *
353       * <p>For example:
354       *
355       * <pre> {@code
356       *
357       * mod(7, 4) == 3
358       * mod(-7, 4) == 1
359       * mod(-1, 4) == 3
360       * mod(-8, 4) == 0
361       * mod(8, 4) == 0}</pre>
362       *
363       * @throws ArithmeticException if {@code m <= 0}
364       */
365      public static int mod(long x, int m) {
366        // Cast is safe because the result is guaranteed in the range [0, m)
367        return (int) mod(x, (long) m);
368      }
369    
370      /**
371       * Returns {@code x mod m}. This differs from {@code x % m} in that it always returns a
372       * non-negative result.
373       *
374       * <p>For example:
375       *
376       * <pre> {@code
377       *
378       * mod(7, 4) == 3
379       * mod(-7, 4) == 1
380       * mod(-1, 4) == 3
381       * mod(-8, 4) == 0
382       * mod(8, 4) == 0}</pre>
383       *
384       * @throws ArithmeticException if {@code m <= 0}
385       */
386      public static long mod(long x, long m) {
387        if (m <= 0) {
388          throw new ArithmeticException("Modulus " + m + " must be > 0");
389        }
390        long result = x % m;
391        return (result >= 0) ? result : result + m;
392      }
393    
394      /**
395       * Returns the greatest common divisor of {@code a, b}. Returns {@code 0} if
396       * {@code a == 0 && b == 0}.
397       *
398       * @throws IllegalArgumentException if {@code a < 0} or {@code b < 0}
399       */
400      public static long gcd(long a, long b) {
401        /*
402         * The reason we require both arguments to be >= 0 is because otherwise, what do you return on
403         * gcd(0, Long.MIN_VALUE)? BigInteger.gcd would return positive 2^63, but positive 2^63 isn't
404         * an int.
405         */
406        checkNonNegative("a", a);
407        checkNonNegative("b", b);
408        if (a == 0 | b == 0) {
409          return a | b;
410        }
411        /*
412         * Uses the binary GCD algorithm; see http://en.wikipedia.org/wiki/Binary_GCD_algorithm.
413         * This is over 40% faster than the Euclidean algorithm in benchmarks.
414         */
415        int aTwos = Long.numberOfTrailingZeros(a);
416        a >>= aTwos; // divide out all 2s
417        int bTwos = Long.numberOfTrailingZeros(b);
418        b >>= bTwos; // divide out all 2s
419        while (a != b) { // both a, b are odd
420          if (a < b) { // swap a, b
421            long t = b;
422            b = a;
423            a = t;
424          }
425          a -= b; // a is now positive and even
426          a >>= Long.numberOfTrailingZeros(a); // divide out all 2s, since 2 doesn't divide b
427        }
428        return a << min(aTwos, bTwos);
429      }
430    
431      /**
432       * Returns the sum of {@code a} and {@code b}, provided it does not overflow.
433       *
434       * @throws ArithmeticException if {@code a + b} overflows in signed {@code long} arithmetic
435       */
436      public static long checkedAdd(long a, long b) {
437        long result = a + b;
438        checkNoOverflow((a ^ b) < 0 | (a ^ result) >= 0);
439        return result;
440      }
441    
442      /**
443       * Returns the difference of {@code a} and {@code b}, provided it does not overflow.
444       *
445       * @throws ArithmeticException if {@code a - b} overflows in signed {@code long} arithmetic
446       */
447      public static long checkedSubtract(long a, long b) {
448        long result = a - b;
449        checkNoOverflow((a ^ b) >= 0 | (a ^ result) >= 0);
450        return result;
451      }
452    
453      /**
454       * Returns the product of {@code a} and {@code b}, provided it does not overflow.
455       *
456       * @throws ArithmeticException if {@code a * b} overflows in signed {@code long} arithmetic
457       */
458      public static long checkedMultiply(long a, long b) {
459        // Hacker's Delight, Section 2-12
460        int leadingZeros = Long.numberOfLeadingZeros(a) + Long.numberOfLeadingZeros(~a)
461            + Long.numberOfLeadingZeros(b) + Long.numberOfLeadingZeros(~b);
462        /*
463         * If leadingZeros > Long.SIZE + 1 it's definitely fine, if it's < Long.SIZE it's definitely
464         * bad. We do the leadingZeros check to avoid the division below if at all possible.
465         *
466         * Otherwise, if b == Long.MIN_VALUE, then the only allowed values of a are 0 and 1. We take
467         * care of all a < 0 with their own check, because in particular, the case a == -1 will
468         * incorrectly pass the division check below.
469         *
470         * In all other cases, we check that either a is 0 or the result is consistent with division.
471         */
472        if (leadingZeros > Long.SIZE + 1) {
473          return a * b;
474        }
475        checkNoOverflow(leadingZeros >= Long.SIZE);
476        checkNoOverflow(a >= 0 | b != Long.MIN_VALUE);
477        long result = a * b;
478        checkNoOverflow(a == 0 || result / a == b);
479        return result;
480      }
481    
482      /**
483       * Returns the {@code b} to the {@code k}th power, provided it does not overflow.
484       *
485       * @throws ArithmeticException if {@code b} to the {@code k}th power overflows in signed
486       *         {@code long} arithmetic
487       */
488      public static long checkedPow(long b, int k) {
489        checkNonNegative("exponent", k);
490        if (b >= -2 & b <= 2) {
491          switch ((int) b) {
492            case 0:
493              return (k == 0) ? 1 : 0;
494            case 1:
495              return 1;
496            case (-1):
497              return ((k & 1) == 0) ? 1 : -1;
498            case 2:
499              checkNoOverflow(k < Long.SIZE - 1);
500              return 1L << k;
501            case (-2):
502              checkNoOverflow(k < Long.SIZE);
503              return ((k & 1) == 0) ? (1L << k) : (-1L << k);
504          }
505        }
506        long accum = 1;
507        while (true) {
508          switch (k) {
509            case 0:
510              return accum;
511            case 1:
512              return checkedMultiply(accum, b);
513            default:
514              if ((k & 1) != 0) {
515                accum = checkedMultiply(accum, b);
516              }
517              k >>= 1;
518              if (k > 0) {
519                checkNoOverflow(b <= FLOOR_SQRT_MAX_LONG);
520                b *= b;
521              }
522          }
523        }
524      }
525    
526      @VisibleForTesting static final long FLOOR_SQRT_MAX_LONG = 3037000499L;
527    
528      /**
529       * Returns {@code n!}, that is, the product of the first {@code n} positive
530       * integers, {@code 1} if {@code n == 0}, or {@link Long#MAX_VALUE} if the
531       * result does not fit in a {@code long}.
532       *
533       * @throws IllegalArgumentException if {@code n < 0}
534       */
535      public static long factorial(int n) {
536        checkNonNegative("n", n);
537        return (n < FACTORIALS.length) ? FACTORIALS[n] : Long.MAX_VALUE;
538      }
539    
540      static final long[] FACTORIALS = {
541          1L,
542          1L,
543          1L * 2,
544          1L * 2 * 3,
545          1L * 2 * 3 * 4,
546          1L * 2 * 3 * 4 * 5,
547          1L * 2 * 3 * 4 * 5 * 6,
548          1L * 2 * 3 * 4 * 5 * 6 * 7,
549          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8,
550          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9,
551          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10,
552          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11,
553          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12,
554          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13,
555          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14,
556          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15,
557          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16,
558          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17,
559          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18,
560          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19,
561          1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19 * 20
562      };
563    
564      /**
565       * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
566       * {@code k}, or {@link Long#MAX_VALUE} if the result does not fit in a {@code long}.
567       *
568       * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
569       */
570      public static long binomial(int n, int k) {
571        checkNonNegative("n", n);
572        checkNonNegative("k", k);
573        checkArgument(k <= n, "k (%s) > n (%s)", k, n);
574        if (k > (n >> 1)) {
575          k = n - k;
576        }
577        if (k >= BIGGEST_BINOMIALS.length || n > BIGGEST_BINOMIALS[k]) {
578          return Long.MAX_VALUE;
579        }
580        long result = 1;
581        if (k < BIGGEST_SIMPLE_BINOMIALS.length && n <= BIGGEST_SIMPLE_BINOMIALS[k]) {
582          // guaranteed not to overflow
583          for (int i = 0; i < k; i++) {
584            result *= n - i;
585            result /= i + 1;
586          }
587        } else {
588          // We want to do this in long math for speed, but want to avoid overflow.
589          // Dividing by the GCD suffices to avoid overflow in all the remaining cases.
590          for (int i = 1; i <= k; i++, n--) {
591            int d = IntMath.gcd(n, i);
592            result /= i / d; // (i/d) is guaranteed to divide result
593            result *= n / d;
594          }
595        }
596        return result;
597      }
598    
599      /*
600       * binomial(BIGGEST_BINOMIALS[k], k) fits in a long, but not
601       * binomial(BIGGEST_BINOMIALS[k] + 1, k).
602       */
603      static final int[] BIGGEST_BINOMIALS =
604          {Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 3810779, 121977, 16175, 4337, 1733,
605              887, 534, 361, 265, 206, 169, 143, 125, 111, 101, 94, 88, 83, 79, 76, 74, 72, 70, 69, 68,
606              67, 67, 66, 66, 66, 66};
607    
608      /*
609       * binomial(BIGGEST_SIMPLE_BINOMIALS[k], k) doesn't need to use the slower GCD-based impl,
610       * but binomial(BIGGEST_SIMPLE_BINOMIALS[k] + 1, k) does.
611       */
612      @VisibleForTesting static final int[] BIGGEST_SIMPLE_BINOMIALS =
613          {Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 2642246, 86251, 11724, 3218, 1313,
614              684, 419, 287, 214, 169, 139, 119, 105, 95, 87, 81, 76, 73, 70, 68, 66, 64, 63, 62, 62,
615              61, 61, 61};
616      // These values were generated by using checkedMultiply to see when the simple multiply/divide
617      // algorithm would lead to an overflow.
618    
619      static boolean fitsInInt(long x) {
620        return (int) x == x;
621      }
622    
623      private LongMath() {}
624    }