001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
023import static java.util.Objects.requireNonNull;
024
025import com.google.common.annotations.GwtCompatible;
026import com.google.common.annotations.GwtIncompatible;
027import com.google.common.base.MoreObjects;
028import com.google.common.primitives.Ints;
029import com.google.errorprone.annotations.CanIgnoreReturnValue;
030import java.io.IOException;
031import java.io.ObjectInputStream;
032import java.io.ObjectOutputStream;
033import java.io.Serializable;
034import java.util.Comparator;
035import java.util.ConcurrentModificationException;
036import java.util.Iterator;
037import java.util.NoSuchElementException;
038import javax.annotation.CheckForNull;
039import org.checkerframework.checker.nullness.qual.Nullable;
040
041/**
042 * A multiset which maintains the ordering of its elements, according to either their natural order
043 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
044 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
045 * equivalence of instances.
046 *
047 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
048 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
049 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
050 *
051 * <p>See the Guava User Guide article on <a href=
052 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
053 *
054 * @author Louis Wasserman
055 * @author Jared Levy
056 * @since 2.0
057 */
058@GwtCompatible(emulated = true)
059@ElementTypesAreNonnullByDefault
060public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
061    implements Serializable {
062
063  /**
064   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
065   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
066   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
067   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
068   * user attempts to add an element to the multiset that violates this constraint (for example, the
069   * user attempts to add a string element to a set whose elements are integers), the {@code
070   * add(Object)} call will throw a {@code ClassCastException}.
071   *
072   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
073   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
074   */
075  public static <E extends Comparable> TreeMultiset<E> create() {
076    return new TreeMultiset<E>(Ordering.natural());
077  }
078
079  /**
080   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
081   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
082   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
083   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
084   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
085   * ClassCastException}.
086   *
087   * @param comparator the comparator that will be used to sort this multiset. A null value
088   *     indicates that the elements' <i>natural ordering</i> should be used.
089   */
090  @SuppressWarnings("unchecked")
091  public static <E extends @Nullable Object> TreeMultiset<E> create(
092      @CheckForNull Comparator<? super E> comparator) {
093    return (comparator == null)
094        ? new TreeMultiset<E>((Comparator) Ordering.natural())
095        : new TreeMultiset<E>(comparator);
096  }
097
098  /**
099   * Creates an empty multiset containing the given initial elements, sorted according to the
100   * elements' natural order.
101   *
102   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
103   *
104   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
105   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
106   */
107  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
108    TreeMultiset<E> multiset = create();
109    Iterables.addAll(multiset, elements);
110    return multiset;
111  }
112
  // Mutable holder for the root of the AVL tree. The same holder instance is handed to the views
  // created by headMultiset/tailMultiset, so all of them observe the same root.
  private final transient Reference<AvlNode<E>> rootReference;
  // The range of elements this instance exposes; GeneralRange.all(...) for a top-level multiset,
  // a narrowed intersection for head/tail views.
  private final transient GeneralRange<E> range;
  // Sentinel node of the pred/succ chain: header.succ() is the smallest node, header.pred() the
  // largest, and iteration stops when it reaches header again.
  private final transient AvlNode<E> header;
116
117  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
118    super(range.comparator());
119    this.rootReference = rootReference;
120    this.range = range;
121    this.header = endLink;
122  }
123
124  TreeMultiset(Comparator<? super E> comparator) {
125    super(comparator);
126    this.range = GeneralRange.all(comparator);
127    this.header = new AvlNode<>();
128    successor(header, header);
129    this.rootReference = new Reference<>();
130  }
131
132  /** A function which can be summed across a subtree. */
133  private enum Aggregate {
134    SIZE {
135      @Override
136      int nodeAggregate(AvlNode<?> node) {
137        return node.elemCount;
138      }
139
140      @Override
141      long treeAggregate(@CheckForNull AvlNode<?> root) {
142        return (root == null) ? 0 : root.totalCount;
143      }
144    },
145    DISTINCT {
146      @Override
147      int nodeAggregate(AvlNode<?> node) {
148        return 1;
149      }
150
151      @Override
152      long treeAggregate(@CheckForNull AvlNode<?> root) {
153        return (root == null) ? 0 : root.distinctElements;
154      }
155    };
156
157    abstract int nodeAggregate(AvlNode<?> node);
158
159    abstract long treeAggregate(@CheckForNull AvlNode<?> root);
160  }
161
162  private long aggregateForEntries(Aggregate aggr) {
163    AvlNode<E> root = rootReference.get();
164    long total = aggr.treeAggregate(root);
165    if (range.hasLowerBound()) {
166      total -= aggregateBelowRange(aggr, root);
167    }
168    if (range.hasUpperBound()) {
169      total -= aggregateAboveRange(aggr, root);
170    }
171    return total;
172  }
173
174  private long aggregateBelowRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
175    if (node == null) {
176      return 0;
177    }
178    // The cast is safe because we call this method only if hasLowerBound().
179    int cmp =
180        comparator()
181            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
182    if (cmp < 0) {
183      return aggregateBelowRange(aggr, node.left);
184    } else if (cmp == 0) {
185      switch (range.getLowerBoundType()) {
186        case OPEN:
187          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
188        case CLOSED:
189          return aggr.treeAggregate(node.left);
190        default:
191          throw new AssertionError();
192      }
193    } else {
194      return aggr.treeAggregate(node.left)
195          + aggr.nodeAggregate(node)
196          + aggregateBelowRange(aggr, node.right);
197    }
198  }
199
200  private long aggregateAboveRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
201    if (node == null) {
202      return 0;
203    }
204    // The cast is safe because we call this method only if hasUpperBound().
205    int cmp =
206        comparator()
207            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
208    if (cmp > 0) {
209      return aggregateAboveRange(aggr, node.right);
210    } else if (cmp == 0) {
211      switch (range.getUpperBoundType()) {
212        case OPEN:
213          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
214        case CLOSED:
215          return aggr.treeAggregate(node.right);
216        default:
217          throw new AssertionError();
218      }
219    } else {
220      return aggr.treeAggregate(node.right)
221          + aggr.nodeAggregate(node)
222          + aggregateAboveRange(aggr, node.left);
223    }
224  }
225
226  @Override
227  public int size() {
228    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
229  }
230
231  @Override
232  int distinctElements() {
233    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
234  }
235
236  static int distinctElements(@CheckForNull AvlNode<?> node) {
237    return (node == null) ? 0 : node.distinctElements;
238  }
239
240  @Override
241  public int count(@CheckForNull Object element) {
242    try {
243      @SuppressWarnings("unchecked")
244      E e = (E) element;
245      AvlNode<E> root = rootReference.get();
246      if (!range.contains(e) || root == null) {
247        return 0;
248      }
249      return root.count(comparator(), e);
250    } catch (ClassCastException | NullPointerException e) {
251      return 0;
252    }
253  }
254
255  @CanIgnoreReturnValue
256  @Override
257  public int add(@ParametricNullness E element, int occurrences) {
258    checkNonnegative(occurrences, "occurrences");
259    if (occurrences == 0) {
260      return count(element);
261    }
262    checkArgument(range.contains(element));
263    AvlNode<E> root = rootReference.get();
264    if (root == null) {
265      comparator().compare(element, element);
266      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
267      successor(header, newRoot, header);
268      rootReference.checkAndSet(root, newRoot);
269      return 0;
270    }
271    int[] result = new int[1]; // used as a mutable int reference to hold result
272    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
273    rootReference.checkAndSet(root, newRoot);
274    return result[0];
275  }
276
277  @CanIgnoreReturnValue
278  @Override
279  public int remove(@CheckForNull Object element, int occurrences) {
280    checkNonnegative(occurrences, "occurrences");
281    if (occurrences == 0) {
282      return count(element);
283    }
284    AvlNode<E> root = rootReference.get();
285    int[] result = new int[1]; // used as a mutable int reference to hold result
286    AvlNode<E> newRoot;
287    try {
288      @SuppressWarnings("unchecked")
289      E e = (E) element;
290      if (!range.contains(e) || root == null) {
291        return 0;
292      }
293      newRoot = root.remove(comparator(), e, occurrences, result);
294    } catch (ClassCastException | NullPointerException e) {
295      return 0;
296    }
297    rootReference.checkAndSet(root, newRoot);
298    return result[0];
299  }
300
301  @CanIgnoreReturnValue
302  @Override
303  public int setCount(@ParametricNullness E element, int count) {
304    checkNonnegative(count, "count");
305    if (!range.contains(element)) {
306      checkArgument(count == 0);
307      return 0;
308    }
309
310    AvlNode<E> root = rootReference.get();
311    if (root == null) {
312      if (count > 0) {
313        add(element, count);
314      }
315      return 0;
316    }
317    int[] result = new int[1]; // used as a mutable int reference to hold result
318    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
319    rootReference.checkAndSet(root, newRoot);
320    return result[0];
321  }
322
323  @CanIgnoreReturnValue
324  @Override
325  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
326    checkNonnegative(newCount, "newCount");
327    checkNonnegative(oldCount, "oldCount");
328    checkArgument(range.contains(element));
329
330    AvlNode<E> root = rootReference.get();
331    if (root == null) {
332      if (oldCount == 0) {
333        if (newCount > 0) {
334          add(element, newCount);
335        }
336        return true;
337      } else {
338        return false;
339      }
340    }
341    int[] result = new int[1]; // used as a mutable int reference to hold result
342    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
343    rootReference.checkAndSet(root, newRoot);
344    return result[0] == oldCount;
345  }
346
347  @Override
348  public void clear() {
349    if (!range.hasLowerBound() && !range.hasUpperBound()) {
350      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
351      for (AvlNode<E> current = header.succ(); current != header; ) {
352        AvlNode<E> next = current.succ();
353
354        current.elemCount = 0;
355        // Also clear these fields so that one deleted Entry doesn't retain all elements.
356        current.left = null;
357        current.right = null;
358        current.pred = null;
359        current.succ = null;
360
361        current = next;
362      }
363      successor(header, header);
364      rootReference.clear();
365    } else {
366      // TODO(cpovirk): Perhaps we can optimize in this case, too?
367      Iterators.clear(entryIterator());
368    }
369  }
370
371  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
372    return new Multisets.AbstractEntry<E>() {
373      @Override
374      @ParametricNullness
375      public E getElement() {
376        return baseEntry.getElement();
377      }
378
379      @Override
380      public int getCount() {
381        int result = baseEntry.getCount();
382        if (result == 0) {
383          return count(getElement());
384        } else {
385          return result;
386        }
387      }
388    };
389  }
390
391  /** Returns the first node in the tree that is in range. */
392  @CheckForNull
393  private AvlNode<E> firstNode() {
394    AvlNode<E> root = rootReference.get();
395    if (root == null) {
396      return null;
397    }
398    AvlNode<E> node;
399    if (range.hasLowerBound()) {
400      // The cast is safe because of the hasLowerBound check.
401      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
402      node = root.ceiling(comparator(), endpoint);
403      if (node == null) {
404        return null;
405      }
406      if (range.getLowerBoundType() == BoundType.OPEN
407          && comparator().compare(endpoint, node.getElement()) == 0) {
408        node = node.succ();
409      }
410    } else {
411      node = header.succ();
412    }
413    return (node == header || !range.contains(node.getElement())) ? null : node;
414  }
415
416  @CheckForNull
417  private AvlNode<E> lastNode() {
418    AvlNode<E> root = rootReference.get();
419    if (root == null) {
420      return null;
421    }
422    AvlNode<E> node;
423    if (range.hasUpperBound()) {
424      // The cast is safe because of the hasUpperBound check.
425      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
426      node = root.floor(comparator(), endpoint);
427      if (node == null) {
428        return null;
429      }
430      if (range.getUpperBoundType() == BoundType.OPEN
431          && comparator().compare(endpoint, node.getElement()) == 0) {
432        node = node.pred();
433      }
434    } else {
435      node = header.pred();
436    }
437    return (node == header || !range.contains(node.getElement())) ? null : node;
438  }
439
440  @Override
441  Iterator<E> elementIterator() {
442    return Multisets.elementIterator(entryIterator());
443  }
444
445  @Override
446  Iterator<Entry<E>> entryIterator() {
447    return new Iterator<Entry<E>>() {
448      @CheckForNull AvlNode<E> current = firstNode();
449      @CheckForNull Entry<E> prevEntry;
450
451      @Override
452      public boolean hasNext() {
453        if (current == null) {
454          return false;
455        } else if (range.tooHigh(current.getElement())) {
456          current = null;
457          return false;
458        } else {
459          return true;
460        }
461      }
462
463      @Override
464      public Entry<E> next() {
465        if (!hasNext()) {
466          throw new NoSuchElementException();
467        }
468        // requireNonNull is safe because current is only nulled out after iteration is complete.
469        Entry<E> result = wrapEntry(requireNonNull(current));
470        prevEntry = result;
471        if (current.succ() == header) {
472          current = null;
473        } else {
474          current = current.succ();
475        }
476        return result;
477      }
478
479      @Override
480      public void remove() {
481        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
482        setCount(prevEntry.getElement(), 0);
483        prevEntry = null;
484      }
485    };
486  }
487
488  @Override
489  Iterator<Entry<E>> descendingEntryIterator() {
490    return new Iterator<Entry<E>>() {
491      @CheckForNull AvlNode<E> current = lastNode();
492      @CheckForNull Entry<E> prevEntry = null;
493
494      @Override
495      public boolean hasNext() {
496        if (current == null) {
497          return false;
498        } else if (range.tooLow(current.getElement())) {
499          current = null;
500          return false;
501        } else {
502          return true;
503        }
504      }
505
506      @Override
507      public Entry<E> next() {
508        if (!hasNext()) {
509          throw new NoSuchElementException();
510        }
511        // requireNonNull is safe because current is only nulled out after iteration is complete.
512        requireNonNull(current);
513        Entry<E> result = wrapEntry(current);
514        prevEntry = result;
515        if (current.pred() == header) {
516          current = null;
517        } else {
518          current = current.pred();
519        }
520        return result;
521      }
522
523      @Override
524      public void remove() {
525        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
526        setCount(prevEntry.getElement(), 0);
527        prevEntry = null;
528      }
529    };
530  }
531
532  @Override
533  public Iterator<E> iterator() {
534    return Multisets.iteratorImpl(this);
535  }
536
537  @Override
538  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
539    return new TreeMultiset<E>(
540        rootReference,
541        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
542        header);
543  }
544
545  @Override
546  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
547    return new TreeMultiset<E>(
548        rootReference,
549        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
550        header);
551  }
552
553  private static final class Reference<T> {
554    @CheckForNull private T value;
555
556    @CheckForNull
557    public T get() {
558      return value;
559    }
560
561    public void checkAndSet(@CheckForNull T expected, @CheckForNull T newValue) {
562      if (value != expected) {
563        throw new ConcurrentModificationException();
564      }
565      value = newValue;
566    }
567
568    void clear() {
569      value = null;
570    }
571  }
572
573  private static final class AvlNode<E extends @Nullable Object> {
    /*
     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
     * type that can include null, as in a TreeMultiset<@Nullable String>).
     *
     * For the header node, though, this field contains `null`, regardless of the type of the
     * multiset.
     *
     * Most code that operates on an AvlNode never operates on the header node. Such code can access
     * the elem field without a null check by calling getElement().
     */
    @CheckForNull private final E elem;

    // elemCount is 0 iff this node has been deleted.
    private int elemCount;

    // Number of nodes in the subtree rooted here, this node included
    // (1 + distinctElements of left + distinctElements of right).
    private int distinctElements;
    // Sum of elemCount over the subtree rooted here, this node included.
    private long totalCount;
    // Height of the subtree rooted here: 1 + the larger of the two child heights.
    private int height;
    // Children in the AVL tree; null when the corresponding subtree is empty.
    @CheckForNull private AvlNode<E> left;
    @CheckForNull private AvlNode<E> right;
    /*
     * pred and succ are nullable after construction, but we always call successor() to initialize
     * them immediately thereafter.
     *
     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
     * only under concurrent modification).
     *
     * To access these fields when you know that they are not null, call the pred() and succ()
     * methods, which perform null checks before returning the fields.
     */
    @CheckForNull private AvlNode<E> pred;
    @CheckForNull private AvlNode<E> succ;
607
608    AvlNode(@ParametricNullness E elem, int elemCount) {
609      checkArgument(elemCount > 0);
610      this.elem = elem;
611      this.elemCount = elemCount;
612      this.totalCount = elemCount;
613      this.distinctElements = 1;
614      this.height = 1;
615      this.left = null;
616      this.right = null;
617    }
618
619    /** Constructor for the header node. */
620    AvlNode() {
621      this.elem = null;
622      this.elemCount = 1;
623    }
624
625    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
626
627    private AvlNode<E> pred() {
628      return requireNonNull(pred);
629    }
630
631    private AvlNode<E> succ() {
632      return requireNonNull(succ);
633    }
634
635    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
636      int cmp = comparator.compare(e, getElement());
637      if (cmp < 0) {
638        return (left == null) ? 0 : left.count(comparator, e);
639      } else if (cmp > 0) {
640        return (right == null) ? 0 : right.count(comparator, e);
641      } else {
642        return elemCount;
643      }
644    }
645
646    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
647      right = new AvlNode<E>(e, count);
648      successor(this, right, succ());
649      height = Math.max(2, height);
650      distinctElements++;
651      totalCount += count;
652      return this;
653    }
654
655    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
656      left = new AvlNode<E>(e, count);
657      successor(pred(), left, this);
658      height = Math.max(2, height);
659      distinctElements++;
660      totalCount += count;
661      return this;
662    }
663
664    AvlNode<E> add(
665        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
666      /*
667       * It speeds things up considerably to unconditionally add count to totalCount here,
668       * but that destroys failure atomicity in the case of count overflow. =(
669       */
670      int cmp = comparator.compare(e, getElement());
671      if (cmp < 0) {
672        AvlNode<E> initLeft = left;
673        if (initLeft == null) {
674          result[0] = 0;
675          return addLeftChild(e, count);
676        }
677        int initHeight = initLeft.height;
678
679        left = initLeft.add(comparator, e, count, result);
680        if (result[0] == 0) {
681          distinctElements++;
682        }
683        this.totalCount += count;
684        return (left.height == initHeight) ? this : rebalance();
685      } else if (cmp > 0) {
686        AvlNode<E> initRight = right;
687        if (initRight == null) {
688          result[0] = 0;
689          return addRightChild(e, count);
690        }
691        int initHeight = initRight.height;
692
693        right = initRight.add(comparator, e, count, result);
694        if (result[0] == 0) {
695          distinctElements++;
696        }
697        this.totalCount += count;
698        return (right.height == initHeight) ? this : rebalance();
699      }
700
701      // adding count to me!  No rebalance possible.
702      result[0] = elemCount;
703      long resultCount = (long) elemCount + count;
704      checkArgument(resultCount <= Integer.MAX_VALUE);
705      this.elemCount += count;
706      this.totalCount += count;
707      return this;
708    }
709
710    @CheckForNull
711    AvlNode<E> remove(
712        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
713      int cmp = comparator.compare(e, getElement());
714      if (cmp < 0) {
715        AvlNode<E> initLeft = left;
716        if (initLeft == null) {
717          result[0] = 0;
718          return this;
719        }
720
721        left = initLeft.remove(comparator, e, count, result);
722
723        if (result[0] > 0) {
724          if (count >= result[0]) {
725            this.distinctElements--;
726            this.totalCount -= result[0];
727          } else {
728            this.totalCount -= count;
729          }
730        }
731        return (result[0] == 0) ? this : rebalance();
732      } else if (cmp > 0) {
733        AvlNode<E> initRight = right;
734        if (initRight == null) {
735          result[0] = 0;
736          return this;
737        }
738
739        right = initRight.remove(comparator, e, count, result);
740
741        if (result[0] > 0) {
742          if (count >= result[0]) {
743            this.distinctElements--;
744            this.totalCount -= result[0];
745          } else {
746            this.totalCount -= count;
747          }
748        }
749        return rebalance();
750      }
751
752      // removing count from me!
753      result[0] = elemCount;
754      if (count >= elemCount) {
755        return deleteMe();
756      } else {
757        this.elemCount -= count;
758        this.totalCount -= count;
759        return this;
760      }
761    }
762
763    @CheckForNull
764    AvlNode<E> setCount(
765        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
766      int cmp = comparator.compare(e, getElement());
767      if (cmp < 0) {
768        AvlNode<E> initLeft = left;
769        if (initLeft == null) {
770          result[0] = 0;
771          return (count > 0) ? addLeftChild(e, count) : this;
772        }
773
774        left = initLeft.setCount(comparator, e, count, result);
775
776        if (count == 0 && result[0] != 0) {
777          this.distinctElements--;
778        } else if (count > 0 && result[0] == 0) {
779          this.distinctElements++;
780        }
781
782        this.totalCount += count - result[0];
783        return rebalance();
784      } else if (cmp > 0) {
785        AvlNode<E> initRight = right;
786        if (initRight == null) {
787          result[0] = 0;
788          return (count > 0) ? addRightChild(e, count) : this;
789        }
790
791        right = initRight.setCount(comparator, e, count, result);
792
793        if (count == 0 && result[0] != 0) {
794          this.distinctElements--;
795        } else if (count > 0 && result[0] == 0) {
796          this.distinctElements++;
797        }
798
799        this.totalCount += count - result[0];
800        return rebalance();
801      }
802
803      // setting my count
804      result[0] = elemCount;
805      if (count == 0) {
806        return deleteMe();
807      }
808      this.totalCount += count - elemCount;
809      this.elemCount = count;
810      return this;
811    }
812
    /**
     * Conditionally sets the count of {@code e} within the subtree rooted at this node: the count
     * becomes {@code newCount} only if it currently equals {@code expectedCount}. An absent
     * element is treated as having count 0.
     *
     * @param result single-element out-parameter; receives the count {@code e} had before the
     *     call, which the caller compares with {@code expectedCount} to learn whether the update
     *     was applied
     * @return the root of this subtree after the update and any rebalancing, or null if the
     *     subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent, so its current count is 0; insert only if the caller expected that.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust this node's aggregates only when the expectation matched and the change applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent, so its current count is 0; insert only if the caller expected that.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust this node's aggregates only when the expectation matched and the change applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      // On a mismatch the node is left untouched; result[0] tells the caller what was found.
      return this;
    }
876
    /**
     * Removes this node from the tree and from the pred/succ chain, and marks it deleted by
     * zeroing its count.
     *
     * @return the node that takes this node's place as the root of the subtree, or null if this
     *     node had no children
     */
    @CheckForNull
    private AvlNode<E> deleteMe() {
      // Remember the count before zeroing it; it is needed below to fix up totalCount.
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Take this node out of the chain by joining its neighbours to each other.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        // Two children, left side at least as tall: promote the in-order predecessor.
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        // newTop now roots everything this node rooted, minus this node itself.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // Two children, right side taller: promote the in-order successor, which is expected to
        // be the minimum node in my right subtree.
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        // newTop now roots everything this node rooted, minus this node itself.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
903
904    // Removes the minimum node from this subtree to be reused elsewhere
905    @CheckForNull
906    private AvlNode<E> removeMin(AvlNode<E> node) {
907      if (left == null) {
908        return right;
909      } else {
910        left = left.removeMin(node);
911        distinctElements--;
912        totalCount -= node.elemCount;
913        return rebalance();
914      }
915    }
916
917    // Removes the maximum node from this subtree to be reused elsewhere
918    @CheckForNull
919    private AvlNode<E> removeMax(AvlNode<E> node) {
920      if (right == null) {
921        return left;
922      } else {
923        right = right.removeMax(node);
924        distinctElements--;
925        totalCount -= node.elemCount;
926        return rebalance();
927      }
928    }
929
930    private void recomputeMultiset() {
931      this.distinctElements =
932          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
933      this.totalCount = elemCount + totalCount(left) + totalCount(right);
934    }
935
936    private void recomputeHeight() {
937      this.height = 1 + Math.max(height(left), height(right));
938    }
939
940    private void recompute() {
941      recomputeMultiset();
942      recomputeHeight();
943    }
944
945    private AvlNode<E> rebalance() {
946      switch (balanceFactor()) {
947        case -2:
948          // requireNonNull is safe because right must exist in order to get a negative factor.
949          requireNonNull(right);
950          if (right.balanceFactor() > 0) {
951            right = right.rotateRight();
952          }
953          return rotateLeft();
954        case 2:
955          // requireNonNull is safe because left must exist in order to get a positive factor.
956          requireNonNull(left);
957          if (left.balanceFactor() < 0) {
958            left = left.rotateLeft();
959          }
960          return rotateRight();
961        default:
962          recomputeHeight();
963          return this;
964      }
965    }
966
967    private int balanceFactor() {
968      return height(left) - height(right);
969    }
970
971    private AvlNode<E> rotateLeft() {
972      checkState(right != null);
973      AvlNode<E> newTop = right;
974      this.right = newTop.left;
975      newTop.left = this;
976      newTop.totalCount = this.totalCount;
977      newTop.distinctElements = this.distinctElements;
978      this.recompute();
979      newTop.recomputeHeight();
980      return newTop;
981    }
982
983    private AvlNode<E> rotateRight() {
984      checkState(left != null);
985      AvlNode<E> newTop = left;
986      this.left = newTop.right;
987      newTop.right = this;
988      newTop.totalCount = this.totalCount;
989      newTop.distinctElements = this.distinctElements;
990      this.recompute();
991      newTop.recomputeHeight();
992      return newTop;
993    }
994
995    private static long totalCount(@CheckForNull AvlNode<?> node) {
996      return (node == null) ? 0 : node.totalCount;
997    }
998
999    private static int height(@CheckForNull AvlNode<?> node) {
1000      return (node == null) ? 0 : node.height;
1001    }
1002
1003    @CheckForNull
1004    private AvlNode<E> ceiling(Comparator<? super E> comparator, @ParametricNullness E e) {
1005      int cmp = comparator.compare(e, getElement());
1006      if (cmp < 0) {
1007        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1008      } else if (cmp == 0) {
1009        return this;
1010      } else {
1011        return (right == null) ? null : right.ceiling(comparator, e);
1012      }
1013    }
1014
1015    @CheckForNull
1016    private AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1017      int cmp = comparator.compare(e, getElement());
1018      if (cmp > 0) {
1019        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1020      } else if (cmp == 0) {
1021        return this;
1022      } else {
1023        return (left == null) ? null : left.floor(comparator, e);
1024      }
1025    }
1026
1027    @ParametricNullness
1028    E getElement() {
1029      // For discussion of this cast, see the comment on the elem field.
1030      return uncheckedCastNullableTToT(elem);
1031    }
1032
1033    int getCount() {
1034      return elemCount;
1035    }
1036
1037    @Override
1038    public String toString() {
1039      return Multisets.immutableEntry(getElement(), getCount()).toString();
1040    }
1041  }
1042
1043  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1044    a.succ = b;
1045    b.pred = a;
1046  }
1047
1048  private static <T extends @Nullable Object> void successor(
1049      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1050    successor(a, b);
1051    successor(b, c);
1052  }
1053
1054  /*
1055   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1056   * calls the comparator to compare the two keys. If that change is made,
1057   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1058   */
1059
1060  /**
1061   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1062   *     second element, its count, and so on
1063   */
1064  @GwtIncompatible // java.io.ObjectOutputStream
1065  private void writeObject(ObjectOutputStream stream) throws IOException {
1066    stream.defaultWriteObject();
1067    stream.writeObject(elementSet().comparator());
1068    Serialization.writeMultiset(this, stream);
1069  }
1070
1071  @GwtIncompatible // java.io.ObjectInputStream
1072  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
1073    stream.defaultReadObject();
1074    @SuppressWarnings("unchecked")
1075    // reading data stored by writeObject
1076    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
1077    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
1078    Serialization.getFieldSetter(TreeMultiset.class, "range")
1079        .set(this, GeneralRange.all(comparator));
1080    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
1081        .set(this, new Reference<AvlNode<E>>());
1082    AvlNode<E> header = new AvlNode<>();
1083    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
1084    successor(header, header);
1085    Serialization.populateMultiset(this, stream);
1086  }
1087
1088  @GwtIncompatible // not needed in emulated source
1089  private static final long serialVersionUID = 1;
1090}