001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
023import static java.util.Objects.requireNonNull;
024
025import com.google.common.annotations.GwtCompatible;
026import com.google.common.annotations.GwtIncompatible;
027import com.google.common.annotations.J2ktIncompatible;
028import com.google.common.base.MoreObjects;
029import com.google.common.primitives.Ints;
030import com.google.errorprone.annotations.CanIgnoreReturnValue;
031import java.io.IOException;
032import java.io.ObjectInputStream;
033import java.io.ObjectOutputStream;
034import java.io.Serializable;
035import java.util.Comparator;
036import java.util.ConcurrentModificationException;
037import java.util.Iterator;
038import java.util.NoSuchElementException;
039import javax.annotation.CheckForNull;
040import org.checkerframework.checker.nullness.qual.Nullable;
041
042/**
043 * A multiset which maintains the ordering of its elements, according to either their natural order
044 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
045 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
046 * equivalence of instances.
047 *
048 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
049 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
050 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
051 *
052 * <p>See the Guava User Guide article on <a href=
053 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
054 *
055 * @author Louis Wasserman
056 * @author Jared Levy
057 * @since 2.0
058 */
059@GwtCompatible(emulated = true)
060@ElementTypesAreNonnullByDefault
061public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
062    implements Serializable {
063
064  /**
065   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
066   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
067   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
068   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
069   * user attempts to add an element to the multiset that violates this constraint (for example, the
070   * user attempts to add a string element to a set whose elements are integers), the {@code
071   * add(Object)} call will throw a {@code ClassCastException}.
072   *
073   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
074   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
075   */
076  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
077  public static <E extends Comparable> TreeMultiset<E> create() {
078    return new TreeMultiset<E>(Ordering.natural());
079  }
080
081  /**
082   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
083   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
084   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
085   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
086   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
087   * ClassCastException}.
088   *
089   * @param comparator the comparator that will be used to sort this multiset. A null value
090   *     indicates that the elements' <i>natural ordering</i> should be used.
091   */
092  @SuppressWarnings("unchecked")
093  public static <E extends @Nullable Object> TreeMultiset<E> create(
094      @CheckForNull Comparator<? super E> comparator) {
095    return (comparator == null)
096        ? new TreeMultiset<E>((Comparator) Ordering.natural())
097        : new TreeMultiset<E>(comparator);
098  }
099
100  /**
101   * Creates an empty multiset containing the given initial elements, sorted according to the
102   * elements' natural order.
103   *
104   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
105   *
106   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
107   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
108   */
109  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
110  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
111    TreeMultiset<E> multiset = create();
112    Iterables.addAll(multiset, elements);
113    return multiset;
114  }
115
116  private final transient Reference<AvlNode<E>> rootReference;
117  private final transient GeneralRange<E> range;
118  private final transient AvlNode<E> header;
119
120  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
121    super(range.comparator());
122    this.rootReference = rootReference;
123    this.range = range;
124    this.header = endLink;
125  }
126
127  TreeMultiset(Comparator<? super E> comparator) {
128    super(comparator);
129    this.range = GeneralRange.all(comparator);
130    this.header = new AvlNode<>();
131    successor(header, header);
132    this.rootReference = new Reference<>();
133  }
134
135  /** A function which can be summed across a subtree. */
136  private enum Aggregate {
137    SIZE {
138      @Override
139      int nodeAggregate(AvlNode<?> node) {
140        return node.elemCount;
141      }
142
143      @Override
144      long treeAggregate(@CheckForNull AvlNode<?> root) {
145        return (root == null) ? 0 : root.totalCount;
146      }
147    },
148    DISTINCT {
149      @Override
150      int nodeAggregate(AvlNode<?> node) {
151        return 1;
152      }
153
154      @Override
155      long treeAggregate(@CheckForNull AvlNode<?> root) {
156        return (root == null) ? 0 : root.distinctElements;
157      }
158    };
159
160    abstract int nodeAggregate(AvlNode<?> node);
161
162    abstract long treeAggregate(@CheckForNull AvlNode<?> root);
163  }
164
165  private long aggregateForEntries(Aggregate aggr) {
166    AvlNode<E> root = rootReference.get();
167    long total = aggr.treeAggregate(root);
168    if (range.hasLowerBound()) {
169      total -= aggregateBelowRange(aggr, root);
170    }
171    if (range.hasUpperBound()) {
172      total -= aggregateAboveRange(aggr, root);
173    }
174    return total;
175  }
176
  /**
   * Returns the sum of {@code aggr} over the nodes in the subtree rooted at {@code node} whose
   * elements are excluded by this view's lower bound (they lie below it, or on it when the bound is
   * open). Must only be called when {@code range.hasLowerBound()}.
   */
  private long aggregateBelowRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
    if (node == null) {
      return 0;
    }
    // The cast is safe because we call this method only if hasLowerBound().
    int cmp =
        comparator()
            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
    if (cmp < 0) {
      // The endpoint is below this node, so this node and its right subtree are not excluded by
      // the lower bound; only the left subtree can contain excluded elements.
      return aggregateBelowRange(aggr, node.left);
    } else if (cmp == 0) {
      // This node sits exactly on the endpoint: its whole left subtree is excluded, and the node
      // itself is excluded only when the bound is open.
      switch (range.getLowerBoundType()) {
        case OPEN:
          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
        case CLOSED:
          return aggr.treeAggregate(node.left);
        default:
          throw new AssertionError();
      }
    } else {
      // The endpoint is above this node: this node and its left subtree are excluded, and part of
      // the right subtree may be as well.
      return aggr.treeAggregate(node.left)
          + aggr.nodeAggregate(node)
          + aggregateBelowRange(aggr, node.right);
    }
  }
202
  /**
   * Returns the sum of {@code aggr} over the nodes in the subtree rooted at {@code node} whose
   * elements are excluded by this view's upper bound (they lie above it, or on it when the bound is
   * open). Must only be called when {@code range.hasUpperBound()}.
   */
  private long aggregateAboveRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
    if (node == null) {
      return 0;
    }
    // The cast is safe because we call this method only if hasUpperBound().
    int cmp =
        comparator()
            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
    if (cmp > 0) {
      // The endpoint is above this node, so this node and its left subtree are not excluded by
      // the upper bound; only the right subtree can contain excluded elements.
      return aggregateAboveRange(aggr, node.right);
    } else if (cmp == 0) {
      // This node sits exactly on the endpoint: its whole right subtree is excluded, and the node
      // itself is excluded only when the bound is open.
      switch (range.getUpperBoundType()) {
        case OPEN:
          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
        case CLOSED:
          return aggr.treeAggregate(node.right);
        default:
          throw new AssertionError();
      }
    } else {
      // The endpoint is below this node: this node and its right subtree are excluded, and part of
      // the left subtree may be as well.
      return aggr.treeAggregate(node.right)
          + aggr.nodeAggregate(node)
          + aggregateAboveRange(aggr, node.left);
    }
  }
228
229  @Override
230  public int size() {
231    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
232  }
233
234  @Override
235  int distinctElements() {
236    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
237  }
238
239  static int distinctElements(@CheckForNull AvlNode<?> node) {
240    return (node == null) ? 0 : node.distinctElements;
241  }
242
243  @Override
244  public int count(@CheckForNull Object element) {
245    try {
246      @SuppressWarnings("unchecked")
247      E e = (E) element;
248      AvlNode<E> root = rootReference.get();
249      if (!range.contains(e) || root == null) {
250        return 0;
251      }
252      return root.count(comparator(), e);
253    } catch (ClassCastException | NullPointerException e) {
254      return 0;
255    }
256  }
257
258  @CanIgnoreReturnValue
259  @Override
260  public int add(@ParametricNullness E element, int occurrences) {
261    checkNonnegative(occurrences, "occurrences");
262    if (occurrences == 0) {
263      return count(element);
264    }
265    checkArgument(range.contains(element));
266    AvlNode<E> root = rootReference.get();
267    if (root == null) {
268      int unused = comparator().compare(element, element);
269      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
270      successor(header, newRoot, header);
271      rootReference.checkAndSet(root, newRoot);
272      return 0;
273    }
274    int[] result = new int[1]; // used as a mutable int reference to hold result
275    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
276    rootReference.checkAndSet(root, newRoot);
277    return result[0];
278  }
279
280  @CanIgnoreReturnValue
281  @Override
282  public int remove(@CheckForNull Object element, int occurrences) {
283    checkNonnegative(occurrences, "occurrences");
284    if (occurrences == 0) {
285      return count(element);
286    }
287    AvlNode<E> root = rootReference.get();
288    int[] result = new int[1]; // used as a mutable int reference to hold result
289    AvlNode<E> newRoot;
290    try {
291      @SuppressWarnings("unchecked")
292      E e = (E) element;
293      if (!range.contains(e) || root == null) {
294        return 0;
295      }
296      newRoot = root.remove(comparator(), e, occurrences, result);
297    } catch (ClassCastException | NullPointerException e) {
298      return 0;
299    }
300    rootReference.checkAndSet(root, newRoot);
301    return result[0];
302  }
303
304  @CanIgnoreReturnValue
305  @Override
306  public int setCount(@ParametricNullness E element, int count) {
307    checkNonnegative(count, "count");
308    if (!range.contains(element)) {
309      checkArgument(count == 0);
310      return 0;
311    }
312
313    AvlNode<E> root = rootReference.get();
314    if (root == null) {
315      if (count > 0) {
316        add(element, count);
317      }
318      return 0;
319    }
320    int[] result = new int[1]; // used as a mutable int reference to hold result
321    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
322    rootReference.checkAndSet(root, newRoot);
323    return result[0];
324  }
325
326  @CanIgnoreReturnValue
327  @Override
328  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
329    checkNonnegative(newCount, "newCount");
330    checkNonnegative(oldCount, "oldCount");
331    checkArgument(range.contains(element));
332
333    AvlNode<E> root = rootReference.get();
334    if (root == null) {
335      if (oldCount == 0) {
336        if (newCount > 0) {
337          add(element, newCount);
338        }
339        return true;
340      } else {
341        return false;
342      }
343    }
344    int[] result = new int[1]; // used as a mutable int reference to hold result
345    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
346    rootReference.checkAndSet(root, newRoot);
347    return result[0] == oldCount;
348  }
349
350  @Override
351  public void clear() {
352    if (!range.hasLowerBound() && !range.hasUpperBound()) {
353      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
354      for (AvlNode<E> current = header.succ(); current != header; ) {
355        AvlNode<E> next = current.succ();
356
357        current.elemCount = 0;
358        // Also clear these fields so that one deleted Entry doesn't retain all elements.
359        current.left = null;
360        current.right = null;
361        current.pred = null;
362        current.succ = null;
363
364        current = next;
365      }
366      successor(header, header);
367      rootReference.clear();
368    } else {
369      // TODO(cpovirk): Perhaps we can optimize in this case, too?
370      Iterators.clear(entryIterator());
371    }
372  }
373
374  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
375    return new Multisets.AbstractEntry<E>() {
376      @Override
377      @ParametricNullness
378      public E getElement() {
379        return baseEntry.getElement();
380      }
381
382      @Override
383      public int getCount() {
384        int result = baseEntry.getCount();
385        if (result == 0) {
386          return count(getElement());
387        } else {
388          return result;
389        }
390      }
391    };
392  }
393
394  /** Returns the first node in the tree that is in range. */
395  @CheckForNull
396  private AvlNode<E> firstNode() {
397    AvlNode<E> root = rootReference.get();
398    if (root == null) {
399      return null;
400    }
401    AvlNode<E> node;
402    if (range.hasLowerBound()) {
403      // The cast is safe because of the hasLowerBound check.
404      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
405      node = root.ceiling(comparator(), endpoint);
406      if (node == null) {
407        return null;
408      }
409      if (range.getLowerBoundType() == BoundType.OPEN
410          && comparator().compare(endpoint, node.getElement()) == 0) {
411        node = node.succ();
412      }
413    } else {
414      node = header.succ();
415    }
416    return (node == header || !range.contains(node.getElement())) ? null : node;
417  }
418
419  @CheckForNull
420  private AvlNode<E> lastNode() {
421    AvlNode<E> root = rootReference.get();
422    if (root == null) {
423      return null;
424    }
425    AvlNode<E> node;
426    if (range.hasUpperBound()) {
427      // The cast is safe because of the hasUpperBound check.
428      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
429      node = root.floor(comparator(), endpoint);
430      if (node == null) {
431        return null;
432      }
433      if (range.getUpperBoundType() == BoundType.OPEN
434          && comparator().compare(endpoint, node.getElement()) == 0) {
435        node = node.pred();
436      }
437    } else {
438      node = header.pred();
439    }
440    return (node == header || !range.contains(node.getElement())) ? null : node;
441  }
442
443  @Override
444  Iterator<E> elementIterator() {
445    return Multisets.elementIterator(entryIterator());
446  }
447
448  @Override
449  Iterator<Entry<E>> entryIterator() {
450    return new Iterator<Entry<E>>() {
451      @CheckForNull AvlNode<E> current = firstNode();
452      @CheckForNull Entry<E> prevEntry;
453
454      @Override
455      public boolean hasNext() {
456        if (current == null) {
457          return false;
458        } else if (range.tooHigh(current.getElement())) {
459          current = null;
460          return false;
461        } else {
462          return true;
463        }
464      }
465
466      @Override
467      public Entry<E> next() {
468        if (!hasNext()) {
469          throw new NoSuchElementException();
470        }
471        // requireNonNull is safe because current is only nulled out after iteration is complete.
472        Entry<E> result = wrapEntry(requireNonNull(current));
473        prevEntry = result;
474        if (current.succ() == header) {
475          current = null;
476        } else {
477          current = current.succ();
478        }
479        return result;
480      }
481
482      @Override
483      public void remove() {
484        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
485        setCount(prevEntry.getElement(), 0);
486        prevEntry = null;
487      }
488    };
489  }
490
491  @Override
492  Iterator<Entry<E>> descendingEntryIterator() {
493    return new Iterator<Entry<E>>() {
494      @CheckForNull AvlNode<E> current = lastNode();
495      @CheckForNull Entry<E> prevEntry = null;
496
497      @Override
498      public boolean hasNext() {
499        if (current == null) {
500          return false;
501        } else if (range.tooLow(current.getElement())) {
502          current = null;
503          return false;
504        } else {
505          return true;
506        }
507      }
508
509      @Override
510      public Entry<E> next() {
511        if (!hasNext()) {
512          throw new NoSuchElementException();
513        }
514        // requireNonNull is safe because current is only nulled out after iteration is complete.
515        requireNonNull(current);
516        Entry<E> result = wrapEntry(current);
517        prevEntry = result;
518        if (current.pred() == header) {
519          current = null;
520        } else {
521          current = current.pred();
522        }
523        return result;
524      }
525
526      @Override
527      public void remove() {
528        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
529        setCount(prevEntry.getElement(), 0);
530        prevEntry = null;
531      }
532    };
533  }
534
535  @Override
536  public Iterator<E> iterator() {
537    return Multisets.iteratorImpl(this);
538  }
539
540  @Override
541  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
542    return new TreeMultiset<E>(
543        rootReference,
544        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
545        header);
546  }
547
548  @Override
549  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
550    return new TreeMultiset<E>(
551        rootReference,
552        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
553        header);
554  }
555
556  private static final class Reference<T> {
557    @CheckForNull private T value;
558
559    @CheckForNull
560    public T get() {
561      return value;
562    }
563
564    public void checkAndSet(@CheckForNull T expected, @CheckForNull T newValue) {
565      if (value != expected) {
566        throw new ConcurrentModificationException();
567      }
568      value = newValue;
569    }
570
571    void clear() {
572      value = null;
573    }
574  }
575
576  private static final class AvlNode<E extends @Nullable Object> {
577    /*
578     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
579     * type that can include null, as in a TreeMultiset<@Nullable String>).
580     *
581     * For the header node, though, this field contains `null`, regardless of the type of the
582     * multiset.
583     *
584     * Most code that operates on an AvlNode never operates on the header node. Such code can access
585     * the elem field without a null check by calling getElement().
586     */
587    @CheckForNull private final E elem;
588
589    // elemCount is 0 iff this node has been deleted.
590    private int elemCount;
591
592    private int distinctElements;
593    private long totalCount;
594    private int height;
595    @CheckForNull private AvlNode<E> left;
596    @CheckForNull private AvlNode<E> right;
597    /*
598     * pred and succ are nullable after construction, but we always call successor() to initialize
599     * them immediately thereafter.
600     *
601     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
602     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
603     * only under concurrent modification).
604     *
605     * To access these fields when you know that they are not null, call the pred() and succ()
606     * methods, which perform null checks before returning the fields.
607     */
608    @CheckForNull private AvlNode<E> pred;
609    @CheckForNull private AvlNode<E> succ;
610
611    AvlNode(@ParametricNullness E elem, int elemCount) {
612      checkArgument(elemCount > 0);
613      this.elem = elem;
614      this.elemCount = elemCount;
615      this.totalCount = elemCount;
616      this.distinctElements = 1;
617      this.height = 1;
618      this.left = null;
619      this.right = null;
620    }
621
622    /** Constructor for the header node. */
623    AvlNode() {
624      this.elem = null;
625      this.elemCount = 1;
626    }
627
628    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
629
630    private AvlNode<E> pred() {
631      return requireNonNull(pred);
632    }
633
634    private AvlNode<E> succ() {
635      return requireNonNull(succ);
636    }
637
638    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
639      int cmp = comparator.compare(e, getElement());
640      if (cmp < 0) {
641        return (left == null) ? 0 : left.count(comparator, e);
642      } else if (cmp > 0) {
643        return (right == null) ? 0 : right.count(comparator, e);
644      } else {
645        return elemCount;
646      }
647    }
648
649    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
650      right = new AvlNode<E>(e, count);
651      successor(this, right, succ());
652      height = Math.max(2, height);
653      distinctElements++;
654      totalCount += count;
655      return this;
656    }
657
658    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
659      left = new AvlNode<E>(e, count);
660      successor(pred(), left, this);
661      height = Math.max(2, height);
662      distinctElements++;
663      totalCount += count;
664      return this;
665    }
666
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node, creating a new
     * leaf if {@code e} is not yet present.
     *
     * @param comparator the multiset's comparator
     * @param e the element to add
     * @param count the number of occurrences to add
     * @param result out-parameter: {@code result[0]} is set to the count of {@code e} before the
     *     addition (0 if it was absent)
     * @return the root of this subtree after the addition, which may differ from {@code this} if a
     *     rebalancing rotation occurred
     * @throws IllegalArgumentException if the resulting count of {@code e} would exceed {@code
     *     Integer.MAX_VALUE}
     */
    AvlNode<E> add(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        if (result[0] == 0) {
          // e was absent before, so a new node was inserted somewhere in the left subtree.
          distinctElements++;
        }
        this.totalCount += count;
        // Rebalancing can only be needed if the left subtree's height changed.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        if (result[0] == 0) {
          // e was absent before, so a new node was inserted somewhere in the right subtree.
          distinctElements++;
        }
        this.totalCount += count;
        // Rebalancing can only be needed if the right subtree's height changed.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      long resultCount = (long) elemCount + count;
      // Checked before any field is modified, so an overflowing add leaves this node unchanged.
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
712
713    @CheckForNull
714    AvlNode<E> remove(
715        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
716      int cmp = comparator.compare(e, getElement());
717      if (cmp < 0) {
718        AvlNode<E> initLeft = left;
719        if (initLeft == null) {
720          result[0] = 0;
721          return this;
722        }
723
724        left = initLeft.remove(comparator, e, count, result);
725
726        if (result[0] > 0) {
727          if (count >= result[0]) {
728            this.distinctElements--;
729            this.totalCount -= result[0];
730          } else {
731            this.totalCount -= count;
732          }
733        }
734        return (result[0] == 0) ? this : rebalance();
735      } else if (cmp > 0) {
736        AvlNode<E> initRight = right;
737        if (initRight == null) {
738          result[0] = 0;
739          return this;
740        }
741
742        right = initRight.remove(comparator, e, count, result);
743
744        if (result[0] > 0) {
745          if (count >= result[0]) {
746            this.distinctElements--;
747            this.totalCount -= result[0];
748          } else {
749            this.totalCount -= count;
750          }
751        }
752        return rebalance();
753      }
754
755      // removing count from me!
756      result[0] = elemCount;
757      if (count >= elemCount) {
758        return deleteMe();
759      } else {
760        this.elemCount -= count;
761        this.totalCount -= count;
762        return this;
763      }
764    }
765
    /**
     * Sets the count of {@code e} in the subtree rooted at this node to {@code count}, inserting a
     * new node or deleting the existing one as needed.
     *
     * @param comparator the multiset's comparator
     * @param e the element whose count is set
     * @param count the new count (0 removes the element)
     * @param result out-parameter: {@code result[0]} is set to the previous count of {@code e} (0 if
     *     it was absent)
     * @return the root of this subtree after the update, or null if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // The distinct count changes only if e disappeared (positive -> 0) or appeared (0 -> positive).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // Net change in e's count.
        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // The distinct count changes only if e disappeared (positive -> 0) or appeared (0 -> positive).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // Net change in e's count.
        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        // A count of zero removes this node from the tree.
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
815
    /**
     * Sets the count of {@code e} to {@code newCount}, but only if its current count in the subtree
     * rooted at this node is exactly {@code expectedCount}.
     *
     * @param comparator the multiset's comparator
     * @param e the element whose count is conditionally set
     * @param expectedCount the count {@code e} must currently have for the update to apply
     * @param newCount the count to assign (0 removes the element)
     * @param result out-parameter: {@code result[0]} is set to the actual current count of {@code e}
     *     (0 if absent); the caller treats the update as successful iff this equals {@code
     *     expectedCount}
     * @return the root of this subtree after any update, or null if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          // e is absent (count 0); insert it only if 0 was expected and a positive count requested.
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the conditional update actually took effect below.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          // e is absent (count 0); insert it only if 0 was expected and a positive count requested.
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the conditional update actually took effect below.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      // Apply the update only when the current count matches the expectation.
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
879
    /**
     * Deletes this node: marks it deleted (elemCount becomes 0), unlinks it from the pred/succ chain,
     * and removes it from the subtree it roots.
     *
     * @return the new root of that subtree, or null if this node was the subtree's only node
     */
    @CheckForNull
    private AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Presumably links this node's neighbors directly to each other. This method never reassigns
      // this node's own pred/succ fields, so pred()/succ() below still return its former neighbors.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        // Replace this node with its in-order predecessor, taken from the (not shorter) left subtree.
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        // newTop inherits this subtree's aggregates, minus the deleted node's contribution.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // Replace this node with its in-order successor, taken from the taller right subtree.
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        // newTop inherits this subtree's aggregates, minus the deleted node's contribution.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
906
    /**
     * Removes the minimum node from this subtree to be reused elsewhere (deleteMe() uses it to
     * replace a deleted node).
     *
     * @param node the minimum node of this subtree, which is being detached; its elemCount is
     *     subtracted from the cached totalCount of each ancestor on the path down to it
     * @return the root of this subtree without {@code node}, or null if nothing remains
     */
    @CheckForNull
    private AvlNode<E> removeMin(AvlNode<E> node) {
      if (left == null) {
        // No left child: this node is the minimum, so its right subtree takes its place.
        return right;
      } else {
        left = left.removeMin(node);
        distinctElements--;
        totalCount -= node.elemCount;
        return rebalance();
      }
    }
919
    /**
     * Removes the maximum node from this subtree to be reused elsewhere (deleteMe() uses it to
     * replace a deleted node).
     *
     * @param node the maximum node of this subtree, which is being detached; its elemCount is
     *     subtracted from the cached totalCount of each ancestor on the path down to it
     * @return the root of this subtree without {@code node}, or null if nothing remains
     */
    @CheckForNull
    private AvlNode<E> removeMax(AvlNode<E> node) {
      if (right == null) {
        // No right child: this node is the maximum, so its left subtree takes its place.
        return left;
      } else {
        right = right.removeMax(node);
        distinctElements--;
        totalCount -= node.elemCount;
        return rebalance();
      }
    }
932
933    private void recomputeMultiset() {
934      this.distinctElements =
935          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
936      this.totalCount = elemCount + totalCount(left) + totalCount(right);
937    }
938
939    private void recomputeHeight() {
940      this.height = 1 + Math.max(height(left), height(right));
941    }
942
943    private void recompute() {
944      recomputeMultiset();
945      recomputeHeight();
946    }
947
    /**
     * Restores AVL balance at this node and returns the root of the resulting subtree, which may be
     * a different node if a rotation was performed.
     *
     * <p>Only a balance factor of exactly -2 or 2 triggers rotation. Any other value just refreshes
     * this node's height. NOTE(review): callers are presumably relied on never to produce a factor
     * outside [-2, 2] -- confirm.
     *
     * <p>This method does not recompute this node's distinctElements/totalCount. Callers update
     * those before calling, and rotations carry them over to the new root.
     */
    private AvlNode<E> rebalance() {
      switch (balanceFactor()) {
        case -2:
          // Right side is two levels taller.
          // requireNonNull is safe because right must exist in order to get a negative factor.
          requireNonNull(right);
          if (right.balanceFactor() > 0) {
            // Right child leans left: rotate it first (right-left double rotation).
            right = right.rotateRight();
          }
          return rotateLeft();
        case 2:
          // Left side is two levels taller.
          // requireNonNull is safe because left must exist in order to get a positive factor.
          requireNonNull(left);
          if (left.balanceFactor() < 0) {
            // Left child leans right: rotate it first (left-right double rotation).
            left = left.rotateLeft();
          }
          return rotateRight();
        default:
          // Already balanced; only the height may have changed.
          recomputeHeight();
          return this;
      }
    }
969
970    private int balanceFactor() {
971      return height(left) - height(right);
972    }
973
    /**
     * Rotates this subtree to the left and returns its new root, the former right child.
     *
     * <p>The right child's left subtree becomes this node's right subtree, and this node becomes
     * the new root's left child. The subtree still holds the same nodes, so the new root takes over
     * this node's aggregates unchanged. This node's aggregates and height are then recomputed from
     * its new children, and finally the new root's height is recomputed.
     *
     * @throws IllegalStateException if this node has no right child
     */
    private AvlNode<E> rotateLeft() {
      checkState(right != null);
      AvlNode<E> newTop = right;
      this.right = newTop.left;
      newTop.left = this;
      // Same set of nodes under the new root, so its aggregates equal the old root's.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // Recompute the demoted node first: the new root's height depends on it.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
985
    /**
     * Rotates this subtree to the right and returns its new root, the former left child.
     *
     * <p>The left child's right subtree becomes this node's left subtree, and this node becomes the
     * new root's right child. The subtree still holds the same nodes, so the new root takes over
     * this node's aggregates unchanged. This node's aggregates and height are then recomputed from
     * its new children, and finally the new root's height is recomputed.
     *
     * @throws IllegalStateException if this node has no left child
     */
    private AvlNode<E> rotateRight() {
      checkState(left != null);
      AvlNode<E> newTop = left;
      this.left = newTop.right;
      newTop.right = this;
      // Same set of nodes under the new root, so its aggregates equal the old root's.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // Recompute the demoted node first: the new root's height depends on it.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
997
998    private static long totalCount(@CheckForNull AvlNode<?> node) {
999      return (node == null) ? 0 : node.totalCount;
1000    }
1001
1002    private static int height(@CheckForNull AvlNode<?> node) {
1003      return (node == null) ? 0 : node.height;
1004    }
1005
1006    @CheckForNull
1007    private AvlNode<E> ceiling(Comparator<? super E> comparator, @ParametricNullness E e) {
1008      int cmp = comparator.compare(e, getElement());
1009      if (cmp < 0) {
1010        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1011      } else if (cmp == 0) {
1012        return this;
1013      } else {
1014        return (right == null) ? null : right.ceiling(comparator, e);
1015      }
1016    }
1017
1018    @CheckForNull
1019    private AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1020      int cmp = comparator.compare(e, getElement());
1021      if (cmp > 0) {
1022        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1023      } else if (cmp == 0) {
1024        return this;
1025      } else {
1026        return (left == null) ? null : left.floor(comparator, e);
1027      }
1028    }
1029
1030    @ParametricNullness
1031    E getElement() {
1032      // For discussion of this cast, see the comment on the elem field.
1033      return uncheckedCastNullableTToT(elem);
1034    }
1035
1036    int getCount() {
1037      return elemCount;
1038    }
1039
1040    @Override
1041    public String toString() {
1042      return Multisets.immutableEntry(getElement(), getCount()).toString();
1043    }
1044  }
1045
1046  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1047    a.succ = b;
1048    b.pred = a;
1049  }
1050
1051  private static <T extends @Nullable Object> void successor(
1052      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1053    successor(a, b);
1054    successor(b, c);
1055  }
1056
1057  /*
1058   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1059   * calls the comparator to compare the two keys. If that change is made,
1060   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1061   */
1062
1063  /**
1064   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1065   *     second element, its count, and so on
1066   */
1067  @J2ktIncompatible
1068  @GwtIncompatible // java.io.ObjectOutputStream
1069  private void writeObject(ObjectOutputStream stream) throws IOException {
1070    stream.defaultWriteObject();
1071    stream.writeObject(elementSet().comparator());
1072    Serialization.writeMultiset(this, stream);
1073  }
1074
  /**
   * Restores a multiset written by {@code writeObject}.
   *
   * <p>It reads the comparator first, then sets the comparator, range, rootReference and header
   * fields through {@code Serialization} field setters, and finally reads back the element/count
   * pairs.
   *
   * @throws IOException if reading from {@code stream} fails
   * @throws ClassNotFoundException if a serialized class cannot be found
   * @throws NullPointerException if the serialized comparator is null
   */
  @J2ktIncompatible
  @GwtIncompatible // java.io.ObjectInputStream
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    @SuppressWarnings("unchecked")
    // reading data stored by writeObject
    Comparator<? super E> comparator = (Comparator<? super E>) requireNonNull(stream.readObject());
    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    // The restored multiset covers the full range of the comparator and starts with an empty root.
    Serialization.getFieldSetter(TreeMultiset.class, "range")
        .set(this, GeneralRange.all(comparator));
    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
        .set(this, new Reference<AvlNode<E>>());
    // A fresh header node, linked to itself, marks an empty pred/succ chain.
    AvlNode<E> header = new AvlNode<>();
    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    successor(header, header);
    // Must run last: it adds the serialized entries to the now-initialized structure.
    Serialization.populateMultiset(this, stream);
  }
1092
  // Version identifier for the custom serialized form produced by writeObject/readObject.
  @GwtIncompatible // not needed in emulated source
  @J2ktIncompatible
  private static final long serialVersionUID = 1;
1096}