001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
023import static java.util.Objects.requireNonNull;
024
025import com.google.common.annotations.GwtCompatible;
026import com.google.common.annotations.GwtIncompatible;
027import com.google.common.base.MoreObjects;
028import com.google.common.primitives.Ints;
029import com.google.errorprone.annotations.CanIgnoreReturnValue;
030import java.io.IOException;
031import java.io.ObjectInputStream;
032import java.io.ObjectOutputStream;
033import java.io.Serializable;
034import java.util.Comparator;
035import java.util.ConcurrentModificationException;
036import java.util.Iterator;
037import java.util.NoSuchElementException;
038import javax.annotation.CheckForNull;
039import org.checkerframework.checker.nullness.qual.Nullable;
040
041/**
042 * A multiset which maintains the ordering of its elements, according to either their natural order
043 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
044 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
045 * equivalence of instances.
046 *
047 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
048 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
049 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
050 *
051 * <p>See the Guava User Guide article on <a href=
052 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset"> {@code
053 * Multiset}</a>.
054 *
055 * @author Louis Wasserman
056 * @author Jared Levy
057 * @since 2.0
058 */
059@GwtCompatible(emulated = true)
060@ElementTypesAreNonnullByDefault
061public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
062    implements Serializable {
063
064  /**
065   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
066   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
067   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
068   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
069   * user attempts to add an element to the multiset that violates this constraint (for example, the
070   * user attempts to add a string element to a set whose elements are integers), the {@code
071   * add(Object)} call will throw a {@code ClassCastException}.
072   *
073   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
074   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
075   */
076  public static <E extends Comparable> TreeMultiset<E> create() {
077    return new TreeMultiset<E>(Ordering.natural());
078  }
079
080  /**
081   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
082   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
083   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
084   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
085   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
086   * ClassCastException}.
087   *
088   * @param comparator the comparator that will be used to sort this multiset. A null value
089   *     indicates that the elements' <i>natural ordering</i> should be used.
090   */
091  @SuppressWarnings("unchecked")
092  public static <E extends @Nullable Object> TreeMultiset<E> create(
093      @CheckForNull Comparator<? super E> comparator) {
094    return (comparator == null)
095        ? new TreeMultiset<E>((Comparator) Ordering.natural())
096        : new TreeMultiset<E>(comparator);
097  }
098
099  /**
100   * Creates an empty multiset containing the given initial elements, sorted according to the
101   * elements' natural order.
102   *
103   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
104   *
105   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
106   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
107   */
108  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
109    TreeMultiset<E> multiset = create();
110    Iterables.addAll(multiset, elements);
111    return multiset;
112  }
113
  // Holder for the root of the AVL tree. headMultiset/tailMultiset pass this same holder to the
  // views they create, so every view reads and writes one shared tree.
  private final transient Reference<AvlNode<E>> rootReference;
  // The elements visible through this multiset; views narrow it with GeneralRange.intersect.
  private final transient GeneralRange<E> range;
  // Sentinel node holding no element. header.succ() is the first node and header.pred() the last,
  // linking the nodes into a circular list; an empty tree is the header linked to itself.
  private final transient AvlNode<E> header;
117
118  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
119    super(range.comparator());
120    this.rootReference = rootReference;
121    this.range = range;
122    this.header = endLink;
123  }
124
125  TreeMultiset(Comparator<? super E> comparator) {
126    super(comparator);
127    this.range = GeneralRange.all(comparator);
128    this.header = new AvlNode<>();
129    successor(header, header);
130    this.rootReference = new Reference<>();
131  }
132
133  /** A function which can be summed across a subtree. */
134  private enum Aggregate {
135    SIZE {
136      @Override
137      int nodeAggregate(AvlNode<?> node) {
138        return node.elemCount;
139      }
140
141      @Override
142      long treeAggregate(@CheckForNull AvlNode<?> root) {
143        return (root == null) ? 0 : root.totalCount;
144      }
145    },
146    DISTINCT {
147      @Override
148      int nodeAggregate(AvlNode<?> node) {
149        return 1;
150      }
151
152      @Override
153      long treeAggregate(@CheckForNull AvlNode<?> root) {
154        return (root == null) ? 0 : root.distinctElements;
155      }
156    };
157
158    abstract int nodeAggregate(AvlNode<?> node);
159
160    abstract long treeAggregate(@CheckForNull AvlNode<?> root);
161  }
162
163  private long aggregateForEntries(Aggregate aggr) {
164    AvlNode<E> root = rootReference.get();
165    long total = aggr.treeAggregate(root);
166    if (range.hasLowerBound()) {
167      total -= aggregateBelowRange(aggr, root);
168    }
169    if (range.hasUpperBound()) {
170      total -= aggregateAboveRange(aggr, root);
171    }
172    return total;
173  }
174
175  private long aggregateBelowRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
176    if (node == null) {
177      return 0;
178    }
179    // The cast is safe because we call this method only if hasLowerBound().
180    int cmp =
181        comparator()
182            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
183    if (cmp < 0) {
184      return aggregateBelowRange(aggr, node.left);
185    } else if (cmp == 0) {
186      switch (range.getLowerBoundType()) {
187        case OPEN:
188          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
189        case CLOSED:
190          return aggr.treeAggregate(node.left);
191        default:
192          throw new AssertionError();
193      }
194    } else {
195      return aggr.treeAggregate(node.left)
196          + aggr.nodeAggregate(node)
197          + aggregateBelowRange(aggr, node.right);
198    }
199  }
200
201  private long aggregateAboveRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
202    if (node == null) {
203      return 0;
204    }
205    // The cast is safe because we call this method only if hasUpperBound().
206    int cmp =
207        comparator()
208            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
209    if (cmp > 0) {
210      return aggregateAboveRange(aggr, node.right);
211    } else if (cmp == 0) {
212      switch (range.getUpperBoundType()) {
213        case OPEN:
214          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
215        case CLOSED:
216          return aggr.treeAggregate(node.right);
217        default:
218          throw new AssertionError();
219      }
220    } else {
221      return aggr.treeAggregate(node.right)
222          + aggr.nodeAggregate(node)
223          + aggregateAboveRange(aggr, node.left);
224    }
225  }
226
227  @Override
228  public int size() {
229    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
230  }
231
232  @Override
233  int distinctElements() {
234    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
235  }
236
237  static int distinctElements(@CheckForNull AvlNode<?> node) {
238    return (node == null) ? 0 : node.distinctElements;
239  }
240
241  @Override
242  public int count(@CheckForNull Object element) {
243    try {
244      @SuppressWarnings("unchecked")
245      E e = (E) element;
246      AvlNode<E> root = rootReference.get();
247      if (!range.contains(e) || root == null) {
248        return 0;
249      }
250      return root.count(comparator(), e);
251    } catch (ClassCastException | NullPointerException e) {
252      return 0;
253    }
254  }
255
256  @CanIgnoreReturnValue
257  @Override
258  public int add(@ParametricNullness E element, int occurrences) {
259    checkNonnegative(occurrences, "occurrences");
260    if (occurrences == 0) {
261      return count(element);
262    }
263    checkArgument(range.contains(element));
264    AvlNode<E> root = rootReference.get();
265    if (root == null) {
266      comparator().compare(element, element);
267      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
268      successor(header, newRoot, header);
269      rootReference.checkAndSet(root, newRoot);
270      return 0;
271    }
272    int[] result = new int[1]; // used as a mutable int reference to hold result
273    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
274    rootReference.checkAndSet(root, newRoot);
275    return result[0];
276  }
277
278  @CanIgnoreReturnValue
279  @Override
280  public int remove(@CheckForNull Object element, int occurrences) {
281    checkNonnegative(occurrences, "occurrences");
282    if (occurrences == 0) {
283      return count(element);
284    }
285    AvlNode<E> root = rootReference.get();
286    int[] result = new int[1]; // used as a mutable int reference to hold result
287    AvlNode<E> newRoot;
288    try {
289      @SuppressWarnings("unchecked")
290      E e = (E) element;
291      if (!range.contains(e) || root == null) {
292        return 0;
293      }
294      newRoot = root.remove(comparator(), e, occurrences, result);
295    } catch (ClassCastException | NullPointerException e) {
296      return 0;
297    }
298    rootReference.checkAndSet(root, newRoot);
299    return result[0];
300  }
301
302  @CanIgnoreReturnValue
303  @Override
304  public int setCount(@ParametricNullness E element, int count) {
305    checkNonnegative(count, "count");
306    if (!range.contains(element)) {
307      checkArgument(count == 0);
308      return 0;
309    }
310
311    AvlNode<E> root = rootReference.get();
312    if (root == null) {
313      if (count > 0) {
314        add(element, count);
315      }
316      return 0;
317    }
318    int[] result = new int[1]; // used as a mutable int reference to hold result
319    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
320    rootReference.checkAndSet(root, newRoot);
321    return result[0];
322  }
323
324  @CanIgnoreReturnValue
325  @Override
326  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
327    checkNonnegative(newCount, "newCount");
328    checkNonnegative(oldCount, "oldCount");
329    checkArgument(range.contains(element));
330
331    AvlNode<E> root = rootReference.get();
332    if (root == null) {
333      if (oldCount == 0) {
334        if (newCount > 0) {
335          add(element, newCount);
336        }
337        return true;
338      } else {
339        return false;
340      }
341    }
342    int[] result = new int[1]; // used as a mutable int reference to hold result
343    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
344    rootReference.checkAndSet(root, newRoot);
345    return result[0] == oldCount;
346  }
347
348  @Override
349  public void clear() {
350    if (!range.hasLowerBound() && !range.hasUpperBound()) {
351      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
352      for (AvlNode<E> current = header.succ(); current != header; ) {
353        AvlNode<E> next = current.succ();
354
355        current.elemCount = 0;
356        // Also clear these fields so that one deleted Entry doesn't retain all elements.
357        current.left = null;
358        current.right = null;
359        current.pred = null;
360        current.succ = null;
361
362        current = next;
363      }
364      successor(header, header);
365      rootReference.clear();
366    } else {
367      // TODO(cpovirk): Perhaps we can optimize in this case, too?
368      Iterators.clear(entryIterator());
369    }
370  }
371
372  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
373    return new Multisets.AbstractEntry<E>() {
374      @Override
375      @ParametricNullness
376      public E getElement() {
377        return baseEntry.getElement();
378      }
379
380      @Override
381      public int getCount() {
382        int result = baseEntry.getCount();
383        if (result == 0) {
384          return count(getElement());
385        } else {
386          return result;
387        }
388      }
389    };
390  }
391
392  /** Returns the first node in the tree that is in range. */
393  @CheckForNull
394  private AvlNode<E> firstNode() {
395    AvlNode<E> root = rootReference.get();
396    if (root == null) {
397      return null;
398    }
399    AvlNode<E> node;
400    if (range.hasLowerBound()) {
401      // The cast is safe because of the hasLowerBound check.
402      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
403      node = root.ceiling(comparator(), endpoint);
404      if (node == null) {
405        return null;
406      }
407      if (range.getLowerBoundType() == BoundType.OPEN
408          && comparator().compare(endpoint, node.getElement()) == 0) {
409        node = node.succ();
410      }
411    } else {
412      node = header.succ();
413    }
414    return (node == header || !range.contains(node.getElement())) ? null : node;
415  }
416
417  @CheckForNull
418  private AvlNode<E> lastNode() {
419    AvlNode<E> root = rootReference.get();
420    if (root == null) {
421      return null;
422    }
423    AvlNode<E> node;
424    if (range.hasUpperBound()) {
425      // The cast is safe because of the hasUpperBound check.
426      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
427      node = root.floor(comparator(), endpoint);
428      if (node == null) {
429        return null;
430      }
431      if (range.getUpperBoundType() == BoundType.OPEN
432          && comparator().compare(endpoint, node.getElement()) == 0) {
433        node = node.pred();
434      }
435    } else {
436      node = header.pred();
437    }
438    return (node == header || !range.contains(node.getElement())) ? null : node;
439  }
440
441  @Override
442  Iterator<E> elementIterator() {
443    return Multisets.elementIterator(entryIterator());
444  }
445
446  @Override
447  Iterator<Entry<E>> entryIterator() {
448    return new Iterator<Entry<E>>() {
449      @CheckForNull AvlNode<E> current = firstNode();
450      @CheckForNull Entry<E> prevEntry;
451
452      @Override
453      public boolean hasNext() {
454        if (current == null) {
455          return false;
456        } else if (range.tooHigh(current.getElement())) {
457          current = null;
458          return false;
459        } else {
460          return true;
461        }
462      }
463
464      @Override
465      public Entry<E> next() {
466        if (!hasNext()) {
467          throw new NoSuchElementException();
468        }
469        // requireNonNull is safe because current is only nulled out after iteration is complete.
470        Entry<E> result = wrapEntry(requireNonNull(current));
471        prevEntry = result;
472        if (current.succ() == header) {
473          current = null;
474        } else {
475          current = current.succ();
476        }
477        return result;
478      }
479
480      @Override
481      public void remove() {
482        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
483        setCount(prevEntry.getElement(), 0);
484        prevEntry = null;
485      }
486    };
487  }
488
489  @Override
490  Iterator<Entry<E>> descendingEntryIterator() {
491    return new Iterator<Entry<E>>() {
492      @CheckForNull AvlNode<E> current = lastNode();
493      @CheckForNull Entry<E> prevEntry = null;
494
495      @Override
496      public boolean hasNext() {
497        if (current == null) {
498          return false;
499        } else if (range.tooLow(current.getElement())) {
500          current = null;
501          return false;
502        } else {
503          return true;
504        }
505      }
506
507      @Override
508      public Entry<E> next() {
509        if (!hasNext()) {
510          throw new NoSuchElementException();
511        }
512        // requireNonNull is safe because current is only nulled out after iteration is complete.
513        requireNonNull(current);
514        Entry<E> result = wrapEntry(current);
515        prevEntry = result;
516        if (current.pred() == header) {
517          current = null;
518        } else {
519          current = current.pred();
520        }
521        return result;
522      }
523
524      @Override
525      public void remove() {
526        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
527        setCount(prevEntry.getElement(), 0);
528        prevEntry = null;
529      }
530    };
531  }
532
533  @Override
534  public Iterator<E> iterator() {
535    return Multisets.iteratorImpl(this);
536  }
537
538  @Override
539  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
540    return new TreeMultiset<E>(
541        rootReference,
542        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
543        header);
544  }
545
546  @Override
547  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
548    return new TreeMultiset<E>(
549        rootReference,
550        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
551        header);
552  }
553
554  private static final class Reference<T> {
555    @CheckForNull private T value;
556
557    @CheckForNull
558    public T get() {
559      return value;
560    }
561
562    public void checkAndSet(@CheckForNull T expected, @CheckForNull T newValue) {
563      if (value != expected) {
564        throw new ConcurrentModificationException();
565      }
566      value = newValue;
567    }
568
569    void clear() {
570      value = null;
571    }
572  }
573
574  private static final class AvlNode<E extends @Nullable Object> {
575    /*
576     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
577     * type that can include null, as in a TreeMultiset<@Nullable String>).
578     *
579     * For the header node, though, this field contains `null`, regardless of the type of the
580     * multiset.
581     *
582     * Most code that operates on an AvlNode never operates on the header node. Such code can access
583     * the elem field without a null check by calling getElement().
584     */
585    @CheckForNull private final E elem;
586
587    // elemCount is 0 iff this node has been deleted.
588    private int elemCount;
589
590    private int distinctElements;
591    private long totalCount;
592    private int height;
593    @CheckForNull private AvlNode<E> left;
594    @CheckForNull private AvlNode<E> right;
595    /*
596     * pred and succ are nullable after construction, but we always call successor() to initialize
597     * them immediately thereafter.
598     *
599     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
600     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
601     * only under concurrent modification).
602     *
603     * To access these fields when you know that they are not null, call the pred() and succ()
604     * methods, which perform null checks before returning the fields.
605     */
606    @CheckForNull private AvlNode<E> pred;
607    @CheckForNull private AvlNode<E> succ;
608
609    AvlNode(@ParametricNullness E elem, int elemCount) {
610      checkArgument(elemCount > 0);
611      this.elem = elem;
612      this.elemCount = elemCount;
613      this.totalCount = elemCount;
614      this.distinctElements = 1;
615      this.height = 1;
616      this.left = null;
617      this.right = null;
618    }
619
620    /** Constructor for the header node. */
621    AvlNode() {
622      this.elem = null;
623      this.elemCount = 1;
624    }
625
626    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
627
628    private AvlNode<E> pred() {
629      return requireNonNull(pred);
630    }
631
632    private AvlNode<E> succ() {
633      return requireNonNull(succ);
634    }
635
636    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
637      int cmp = comparator.compare(e, getElement());
638      if (cmp < 0) {
639        return (left == null) ? 0 : left.count(comparator, e);
640      } else if (cmp > 0) {
641        return (right == null) ? 0 : right.count(comparator, e);
642      } else {
643        return elemCount;
644      }
645    }
646
647    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
648      right = new AvlNode<E>(e, count);
649      successor(this, right, succ());
650      height = Math.max(2, height);
651      distinctElements++;
652      totalCount += count;
653      return this;
654    }
655
656    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
657      left = new AvlNode<E>(e, count);
658      successor(pred(), left, this);
659      height = Math.max(2, height);
660      distinctElements++;
661      totalCount += count;
662      return this;
663    }
664
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node, creating a
     * new leaf if {@code e} is not yet present.
     *
     * <p>On return, {@code result[0]} holds the count {@code e} had before the call (0 if a new
     * node was created).
     *
     * @return the root of this subtree after the insertion (rebalanced if a child's height changed)
     * @throws IllegalArgumentException if the resulting count of {@code e} would exceed {@code
     *     Integer.MAX_VALUE}
     */
    AvlNode<E> add(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // A previous count of 0 means a new node (distinct element) was created somewhere below.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Updated only after the recursive call succeeded, so an overflow leaves this node intact.
        this.totalCount += count;
        // Only a change in the left subtree's height can unbalance this node.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        // A previous count of 0 means a new node (distinct element) was created somewhere below.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Updated only after the recursive call succeeded, so an overflow leaves this node intact.
        this.totalCount += count;
        // Only a change in the right subtree's height can unbalance this node.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      long resultCount = (long) elemCount + count;
      // Checked in long arithmetic before mutating anything, so overflow is rejected atomically.
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
710
711    @CheckForNull
712    AvlNode<E> remove(
713        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
714      int cmp = comparator.compare(e, getElement());
715      if (cmp < 0) {
716        AvlNode<E> initLeft = left;
717        if (initLeft == null) {
718          result[0] = 0;
719          return this;
720        }
721
722        left = initLeft.remove(comparator, e, count, result);
723
724        if (result[0] > 0) {
725          if (count >= result[0]) {
726            this.distinctElements--;
727            this.totalCount -= result[0];
728          } else {
729            this.totalCount -= count;
730          }
731        }
732        return (result[0] == 0) ? this : rebalance();
733      } else if (cmp > 0) {
734        AvlNode<E> initRight = right;
735        if (initRight == null) {
736          result[0] = 0;
737          return this;
738        }
739
740        right = initRight.remove(comparator, e, count, result);
741
742        if (result[0] > 0) {
743          if (count >= result[0]) {
744            this.distinctElements--;
745            this.totalCount -= result[0];
746          } else {
747            this.totalCount -= count;
748          }
749        }
750        return rebalance();
751      }
752
753      // removing count from me!
754      result[0] = elemCount;
755      if (count >= elemCount) {
756        return deleteMe();
757      } else {
758        this.elemCount -= count;
759        this.totalCount -= count;
760        return this;
761      }
762    }
763
    /**
     * Sets the count of {@code e} in the subtree rooted at this node to {@code count}. A new leaf is
     * added if {@code e} is absent and {@code count > 0}; the node is deleted if {@code count == 0}.
     *
     * <p>On return, {@code result[0]} holds the count {@code e} had before the call (0 if absent).
     *
     * @return the root of this subtree afterwards, or null if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // A node disappears when a present element is set to 0, and appears when an absent
        // element is given a positive count.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // The subtree total shifts by exactly the change in e's count.
        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // A node disappears when a present element is set to 0, and appears when an absent
        // element is given a positive count.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // The subtree total shifts by exactly the change in e's count.
        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
813
    /**
     * Conditionally sets the count of {@code e} in the subtree rooted at this node: it becomes
     * {@code newCount} only if it currently equals {@code expectedCount}.
     *
     * <p>On return, {@code result[0]} holds the count {@code e} had before the call. The caller
     * compares it with {@code expectedCount} to learn whether the update took place.
     *
     * @return the root of this subtree afterwards, or null if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          // e is absent (count 0): insert only if 0 was expected and a positive count is wanted.
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the conditional update actually happened below.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          // e is absent (count 0): insert only if 0 was expected and a positive count is wanted.
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the conditional update actually happened below.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
877
    /**
     * Removes this node from the tree and from the in-order pred/succ list, and marks it deleted
     * by setting elemCount to 0.
     *
     * @return the root of the subtree that replaces this node (null if this node had no children)
     */
    @CheckForNull
    private AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Link this node's neighbours directly to each other. The code below still calls pred() and
      // succ() on this node afterwards and relies on them returning those same neighbours.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        // Two children: take the replacement from the taller (or equally tall) left side.
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // Right side is taller: the replacement is the in-order successor, the minimum node in my
        // right subtree.
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
904
905    // Removes the minimum node from this subtree to be reused elsewhere
906    @CheckForNull
907    private AvlNode<E> removeMin(AvlNode<E> node) {
908      if (left == null) {
909        return right;
910      } else {
911        left = left.removeMin(node);
912        distinctElements--;
913        totalCount -= node.elemCount;
914        return rebalance();
915      }
916    }
917
918    // Removes the maximum node from this subtree to be reused elsewhere
919    @CheckForNull
920    private AvlNode<E> removeMax(AvlNode<E> node) {
921      if (right == null) {
922        return left;
923      } else {
924        right = right.removeMax(node);
925        distinctElements--;
926        totalCount -= node.elemCount;
927        return rebalance();
928      }
929    }
930
931    private void recomputeMultiset() {
932      this.distinctElements =
933          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
934      this.totalCount = elemCount + totalCount(left) + totalCount(right);
935    }
936
937    private void recomputeHeight() {
938      this.height = 1 + Math.max(height(left), height(right));
939    }
940
941    private void recompute() {
942      recomputeMultiset();
943      recomputeHeight();
944    }
945
946    private AvlNode<E> rebalance() {
947      switch (balanceFactor()) {
948        case -2:
949          // requireNonNull is safe because right must exist in order to get a negative factor.
950          requireNonNull(right);
951          if (right.balanceFactor() > 0) {
952            right = right.rotateRight();
953          }
954          return rotateLeft();
955        case 2:
956          // requireNonNull is safe because left must exist in order to get a positive factor.
957          requireNonNull(left);
958          if (left.balanceFactor() < 0) {
959            left = left.rotateLeft();
960          }
961          return rotateRight();
962        default:
963          recomputeHeight();
964          return this;
965      }
966    }
967
968    private int balanceFactor() {
969      return height(left) - height(right);
970    }
971
    /**
     * Performs a single left rotation: the right child becomes the root of this subtree and this
     * node becomes its left child. Returns the new subtree root.
     *
     * @throws IllegalStateException if this node has no right child
     */
    private AvlNode<E> rotateLeft() {
      checkState(right != null);
      AvlNode<E> newTop = right;
      // newTop's former left subtree moves under this node as its new right subtree.
      this.right = newTop.left;
      newTop.left = this;
      // The rotated subtree holds the same nodes, so newTop inherits this node's old aggregates.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // This node's children changed; recompute it first because newTop's height depends on it.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
983
    /**
     * Performs a single right rotation: the left child becomes the root of this subtree and this
     * node becomes its right child. Returns the new subtree root.
     *
     * @throws IllegalStateException if this node has no left child
     */
    private AvlNode<E> rotateRight() {
      checkState(left != null);
      AvlNode<E> newTop = left;
      // newTop's former right subtree moves under this node as its new left subtree.
      this.left = newTop.right;
      newTop.right = this;
      // The rotated subtree holds the same nodes, so newTop inherits this node's old aggregates.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // This node's children changed; recompute it first because newTop's height depends on it.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
995
996    private static long totalCount(@CheckForNull AvlNode<?> node) {
997      return (node == null) ? 0 : node.totalCount;
998    }
999
1000    private static int height(@CheckForNull AvlNode<?> node) {
1001      return (node == null) ? 0 : node.height;
1002    }
1003
1004    @CheckForNull
1005    private AvlNode<E> ceiling(Comparator<? super E> comparator, @ParametricNullness E e) {
1006      int cmp = comparator.compare(e, getElement());
1007      if (cmp < 0) {
1008        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1009      } else if (cmp == 0) {
1010        return this;
1011      } else {
1012        return (right == null) ? null : right.ceiling(comparator, e);
1013      }
1014    }
1015
1016    @CheckForNull
1017    private AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1018      int cmp = comparator.compare(e, getElement());
1019      if (cmp > 0) {
1020        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1021      } else if (cmp == 0) {
1022        return this;
1023      } else {
1024        return (left == null) ? null : left.floor(comparator, e);
1025      }
1026    }
1027
1028    @ParametricNullness
1029    E getElement() {
1030      // For discussion of this cast, see the comment on the elem field.
1031      return uncheckedCastNullableTToT(elem);
1032    }
1033
1034    int getCount() {
1035      return elemCount;
1036    }
1037
1038    @Override
1039    public String toString() {
1040      return Multisets.immutableEntry(getElement(), getCount()).toString();
1041    }
1042  }
1043
1044  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1045    a.succ = b;
1046    b.pred = a;
1047  }
1048
1049  private static <T extends @Nullable Object> void successor(
1050      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1051    successor(a, b);
1052    successor(b, c);
1053  }
1054
1055  /*
1056   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1057   * calls the comparator to compare the two keys. If that change is made,
1058   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1059   */
1060
1061  /**
1062   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1063   *     second element, its count, and so on
1064   */
1065  @GwtIncompatible // java.io.ObjectOutputStream
1066  private void writeObject(ObjectOutputStream stream) throws IOException {
1067    stream.defaultWriteObject();
1068    stream.writeObject(elementSet().comparator());
1069    Serialization.writeMultiset(this, stream);
1070  }
1071
1072  @GwtIncompatible // java.io.ObjectInputStream
1073  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
1074    stream.defaultReadObject();
1075    @SuppressWarnings("unchecked")
1076    // reading data stored by writeObject
1077    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
1078    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
1079    Serialization.getFieldSetter(TreeMultiset.class, "range")
1080        .set(this, GeneralRange.all(comparator));
1081    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
1082        .set(this, new Reference<AvlNode<E>>());
1083    AvlNode<E> header = new AvlNode<>();
1084    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
1085    successor(header, header);
1086    Serialization.populateMultiset(this, stream);
1087  }
1088
  // Version of the serialized form produced by writeObject and consumed by readObject.
  @GwtIncompatible // not needed in emulated source
  private static final long serialVersionUID = 1;
1091}