001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Preconditions.checkState;
022import static com.google.common.collect.CollectPreconditions.checkNonnegative;
023import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
024import static java.util.Objects.requireNonNull;
025
026import com.google.common.annotations.GwtCompatible;
027import com.google.common.annotations.GwtIncompatible;
028import com.google.common.annotations.J2ktIncompatible;
029import com.google.common.base.MoreObjects;
030import com.google.common.primitives.Ints;
031import com.google.errorprone.annotations.CanIgnoreReturnValue;
032import java.io.IOException;
033import java.io.ObjectInputStream;
034import java.io.ObjectOutputStream;
035import java.io.Serializable;
036import java.util.Comparator;
037import java.util.ConcurrentModificationException;
038import java.util.Iterator;
039import java.util.NoSuchElementException;
040import java.util.function.ObjIntConsumer;
041import javax.annotation.CheckForNull;
042import org.checkerframework.checker.nullness.qual.Nullable;
043
044/**
045 * A multiset which maintains the ordering of its elements, according to either their natural order
046 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
047 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
048 * equivalence of instances.
049 *
050 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
051 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
052 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
053 *
054 * <p>See the Guava User Guide article on <a href=
055 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
056 *
057 * @author Louis Wasserman
058 * @author Jared Levy
059 * @since 2.0
060 */
061@GwtCompatible(emulated = true)
062@ElementTypesAreNonnullByDefault
063public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
064    implements Serializable {
065
066  /**
067   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
068   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
069   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
070   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
071   * user attempts to add an element to the multiset that violates this constraint (for example, the
072   * user attempts to add a string element to a set whose elements are integers), the {@code
073   * add(Object)} call will throw a {@code ClassCastException}.
074   *
075   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
076   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
077   */
078  public static <E extends Comparable> TreeMultiset<E> create() {
079    return new TreeMultiset<E>(Ordering.natural());
080  }
081
082  /**
083   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
084   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
085   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
086   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
087   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
088   * ClassCastException}.
089   *
090   * @param comparator the comparator that will be used to sort this multiset. A null value
091   *     indicates that the elements' <i>natural ordering</i> should be used.
092   */
093  @SuppressWarnings("unchecked")
094  public static <E extends @Nullable Object> TreeMultiset<E> create(
095      @CheckForNull Comparator<? super E> comparator) {
096    return (comparator == null)
097        ? new TreeMultiset<E>((Comparator) Ordering.natural())
098        : new TreeMultiset<E>(comparator);
099  }
100
101  /**
102   * Creates an empty multiset containing the given initial elements, sorted according to the
103   * elements' natural order.
104   *
105   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
106   *
107   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
108   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
109   */
110  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
111    TreeMultiset<E> multiset = create();
112    Iterables.addAll(multiset, elements);
113    return multiset;
114  }
115
  // Mutable holder for the root of the AVL tree. The view constructor (used by headMultiset and
  // tailMultiset) is handed this same Reference, so a multiset and all of its views operate on one
  // shared tree. Transient presumably because serialization is handled by custom
  // writeObject/readObject methods outside this chunk -- TODO confirm.
  private final transient Reference<AvlNode<E>> rootReference;
  // Elements visible through this multiset: GeneralRange.all(comparator) for a multiset built by
  // the comparator constructor, or the intersected range for a head/tail view.
  private final transient GeneralRange<E> range;
  // Sentinel node of the circular pred/succ list that threads every tree node in order
  // (header.succ() is the first node, header.pred() the last). Shared with all views.
  private final transient AvlNode<E> header;
119
120  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
121    super(range.comparator());
122    this.rootReference = rootReference;
123    this.range = range;
124    this.header = endLink;
125  }
126
127  TreeMultiset(Comparator<? super E> comparator) {
128    super(comparator);
129    this.range = GeneralRange.all(comparator);
130    this.header = new AvlNode<>();
131    successor(header, header);
132    this.rootReference = new Reference<>();
133  }
134
135  /** A function which can be summed across a subtree. */
136  private enum Aggregate {
137    SIZE {
138      @Override
139      int nodeAggregate(AvlNode<?> node) {
140        return node.elemCount;
141      }
142
143      @Override
144      long treeAggregate(@CheckForNull AvlNode<?> root) {
145        return (root == null) ? 0 : root.totalCount;
146      }
147    },
148    DISTINCT {
149      @Override
150      int nodeAggregate(AvlNode<?> node) {
151        return 1;
152      }
153
154      @Override
155      long treeAggregate(@CheckForNull AvlNode<?> root) {
156        return (root == null) ? 0 : root.distinctElements;
157      }
158    };
159
160    abstract int nodeAggregate(AvlNode<?> node);
161
162    abstract long treeAggregate(@CheckForNull AvlNode<?> root);
163  }
164
165  private long aggregateForEntries(Aggregate aggr) {
166    AvlNode<E> root = rootReference.get();
167    long total = aggr.treeAggregate(root);
168    if (range.hasLowerBound()) {
169      total -= aggregateBelowRange(aggr, root);
170    }
171    if (range.hasUpperBound()) {
172      total -= aggregateAboveRange(aggr, root);
173    }
174    return total;
175  }
176
177  private long aggregateBelowRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
178    if (node == null) {
179      return 0;
180    }
181    // The cast is safe because we call this method only if hasLowerBound().
182    int cmp =
183        comparator()
184            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
185    if (cmp < 0) {
186      return aggregateBelowRange(aggr, node.left);
187    } else if (cmp == 0) {
188      switch (range.getLowerBoundType()) {
189        case OPEN:
190          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
191        case CLOSED:
192          return aggr.treeAggregate(node.left);
193        default:
194          throw new AssertionError();
195      }
196    } else {
197      return aggr.treeAggregate(node.left)
198          + aggr.nodeAggregate(node)
199          + aggregateBelowRange(aggr, node.right);
200    }
201  }
202
203  private long aggregateAboveRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
204    if (node == null) {
205      return 0;
206    }
207    // The cast is safe because we call this method only if hasUpperBound().
208    int cmp =
209        comparator()
210            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
211    if (cmp > 0) {
212      return aggregateAboveRange(aggr, node.right);
213    } else if (cmp == 0) {
214      switch (range.getUpperBoundType()) {
215        case OPEN:
216          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
217        case CLOSED:
218          return aggr.treeAggregate(node.right);
219        default:
220          throw new AssertionError();
221      }
222    } else {
223      return aggr.treeAggregate(node.right)
224          + aggr.nodeAggregate(node)
225          + aggregateAboveRange(aggr, node.left);
226    }
227  }
228
229  @Override
230  public int size() {
231    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
232  }
233
234  @Override
235  int distinctElements() {
236    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
237  }
238
239  static int distinctElements(@CheckForNull AvlNode<?> node) {
240    return (node == null) ? 0 : node.distinctElements;
241  }
242
243  @Override
244  public int count(@CheckForNull Object element) {
245    try {
246      @SuppressWarnings("unchecked")
247      E e = (E) element;
248      AvlNode<E> root = rootReference.get();
249      if (!range.contains(e) || root == null) {
250        return 0;
251      }
252      return root.count(comparator(), e);
253    } catch (ClassCastException | NullPointerException e) {
254      return 0;
255    }
256  }
257
258  @CanIgnoreReturnValue
259  @Override
260  public int add(@ParametricNullness E element, int occurrences) {
261    checkNonnegative(occurrences, "occurrences");
262    if (occurrences == 0) {
263      return count(element);
264    }
265    checkArgument(range.contains(element));
266    AvlNode<E> root = rootReference.get();
267    if (root == null) {
268      int unused = comparator().compare(element, element);
269      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
270      successor(header, newRoot, header);
271      rootReference.checkAndSet(root, newRoot);
272      return 0;
273    }
274    int[] result = new int[1]; // used as a mutable int reference to hold result
275    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
276    rootReference.checkAndSet(root, newRoot);
277    return result[0];
278  }
279
280  @CanIgnoreReturnValue
281  @Override
282  public int remove(@CheckForNull Object element, int occurrences) {
283    checkNonnegative(occurrences, "occurrences");
284    if (occurrences == 0) {
285      return count(element);
286    }
287    AvlNode<E> root = rootReference.get();
288    int[] result = new int[1]; // used as a mutable int reference to hold result
289    AvlNode<E> newRoot;
290    try {
291      @SuppressWarnings("unchecked")
292      E e = (E) element;
293      if (!range.contains(e) || root == null) {
294        return 0;
295      }
296      newRoot = root.remove(comparator(), e, occurrences, result);
297    } catch (ClassCastException | NullPointerException e) {
298      return 0;
299    }
300    rootReference.checkAndSet(root, newRoot);
301    return result[0];
302  }
303
304  @CanIgnoreReturnValue
305  @Override
306  public int setCount(@ParametricNullness E element, int count) {
307    checkNonnegative(count, "count");
308    if (!range.contains(element)) {
309      checkArgument(count == 0);
310      return 0;
311    }
312
313    AvlNode<E> root = rootReference.get();
314    if (root == null) {
315      if (count > 0) {
316        add(element, count);
317      }
318      return 0;
319    }
320    int[] result = new int[1]; // used as a mutable int reference to hold result
321    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
322    rootReference.checkAndSet(root, newRoot);
323    return result[0];
324  }
325
326  @CanIgnoreReturnValue
327  @Override
328  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
329    checkNonnegative(newCount, "newCount");
330    checkNonnegative(oldCount, "oldCount");
331    checkArgument(range.contains(element));
332
333    AvlNode<E> root = rootReference.get();
334    if (root == null) {
335      if (oldCount == 0) {
336        if (newCount > 0) {
337          add(element, newCount);
338        }
339        return true;
340      } else {
341        return false;
342      }
343    }
344    int[] result = new int[1]; // used as a mutable int reference to hold result
345    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
346    rootReference.checkAndSet(root, newRoot);
347    return result[0] == oldCount;
348  }
349
350  @Override
351  public void clear() {
352    if (!range.hasLowerBound() && !range.hasUpperBound()) {
353      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
354      for (AvlNode<E> current = header.succ(); current != header; ) {
355        AvlNode<E> next = current.succ();
356
357        current.elemCount = 0;
358        // Also clear these fields so that one deleted Entry doesn't retain all elements.
359        current.left = null;
360        current.right = null;
361        current.pred = null;
362        current.succ = null;
363
364        current = next;
365      }
366      successor(header, header);
367      rootReference.clear();
368    } else {
369      // TODO(cpovirk): Perhaps we can optimize in this case, too?
370      Iterators.clear(entryIterator());
371    }
372  }
373
374  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
375    return new Multisets.AbstractEntry<E>() {
376      @Override
377      @ParametricNullness
378      public E getElement() {
379        return baseEntry.getElement();
380      }
381
382      @Override
383      public int getCount() {
384        int result = baseEntry.getCount();
385        if (result == 0) {
386          return count(getElement());
387        } else {
388          return result;
389        }
390      }
391    };
392  }
393
394  /** Returns the first node in the tree that is in range. */
395  @CheckForNull
396  private AvlNode<E> firstNode() {
397    AvlNode<E> root = rootReference.get();
398    if (root == null) {
399      return null;
400    }
401    AvlNode<E> node;
402    if (range.hasLowerBound()) {
403      // The cast is safe because of the hasLowerBound check.
404      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
405      node = root.ceiling(comparator(), endpoint);
406      if (node == null) {
407        return null;
408      }
409      if (range.getLowerBoundType() == BoundType.OPEN
410          && comparator().compare(endpoint, node.getElement()) == 0) {
411        node = node.succ();
412      }
413    } else {
414      node = header.succ();
415    }
416    return (node == header || !range.contains(node.getElement())) ? null : node;
417  }
418
419  @CheckForNull
420  private AvlNode<E> lastNode() {
421    AvlNode<E> root = rootReference.get();
422    if (root == null) {
423      return null;
424    }
425    AvlNode<E> node;
426    if (range.hasUpperBound()) {
427      // The cast is safe because of the hasUpperBound check.
428      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
429      node = root.floor(comparator(), endpoint);
430      if (node == null) {
431        return null;
432      }
433      if (range.getUpperBoundType() == BoundType.OPEN
434          && comparator().compare(endpoint, node.getElement()) == 0) {
435        node = node.pred();
436      }
437    } else {
438      node = header.pred();
439    }
440    return (node == header || !range.contains(node.getElement())) ? null : node;
441  }
442
443  @Override
444  Iterator<E> elementIterator() {
445    return Multisets.elementIterator(entryIterator());
446  }
447
448  @Override
449  Iterator<Entry<E>> entryIterator() {
450    return new Iterator<Entry<E>>() {
451      @CheckForNull AvlNode<E> current = firstNode();
452      @CheckForNull Entry<E> prevEntry;
453
454      @Override
455      public boolean hasNext() {
456        if (current == null) {
457          return false;
458        } else if (range.tooHigh(current.getElement())) {
459          current = null;
460          return false;
461        } else {
462          return true;
463        }
464      }
465
466      @Override
467      public Entry<E> next() {
468        if (!hasNext()) {
469          throw new NoSuchElementException();
470        }
471        // requireNonNull is safe because current is only nulled out after iteration is complete.
472        Entry<E> result = wrapEntry(requireNonNull(current));
473        prevEntry = result;
474        if (current.succ() == header) {
475          current = null;
476        } else {
477          current = current.succ();
478        }
479        return result;
480      }
481
482      @Override
483      public void remove() {
484        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
485        setCount(prevEntry.getElement(), 0);
486        prevEntry = null;
487      }
488    };
489  }
490
491  @Override
492  Iterator<Entry<E>> descendingEntryIterator() {
493    return new Iterator<Entry<E>>() {
494      @CheckForNull AvlNode<E> current = lastNode();
495      @CheckForNull Entry<E> prevEntry = null;
496
497      @Override
498      public boolean hasNext() {
499        if (current == null) {
500          return false;
501        } else if (range.tooLow(current.getElement())) {
502          current = null;
503          return false;
504        } else {
505          return true;
506        }
507      }
508
509      @Override
510      public Entry<E> next() {
511        if (!hasNext()) {
512          throw new NoSuchElementException();
513        }
514        // requireNonNull is safe because current is only nulled out after iteration is complete.
515        requireNonNull(current);
516        Entry<E> result = wrapEntry(current);
517        prevEntry = result;
518        if (current.pred() == header) {
519          current = null;
520        } else {
521          current = current.pred();
522        }
523        return result;
524      }
525
526      @Override
527      public void remove() {
528        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
529        setCount(prevEntry.getElement(), 0);
530        prevEntry = null;
531      }
532    };
533  }
534
535  @Override
536  public void forEachEntry(ObjIntConsumer<? super E> action) {
537    checkNotNull(action);
538    for (AvlNode<E> node = firstNode();
539        node != header && node != null && !range.tooHigh(node.getElement());
540        node = node.succ()) {
541      action.accept(node.getElement(), node.getCount());
542    }
543  }
544
545  @Override
546  public Iterator<E> iterator() {
547    return Multisets.iteratorImpl(this);
548  }
549
550  @Override
551  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
552    return new TreeMultiset<E>(
553        rootReference,
554        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
555        header);
556  }
557
558  @Override
559  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
560    return new TreeMultiset<E>(
561        rootReference,
562        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
563        header);
564  }
565
566  private static final class Reference<T> {
567    @CheckForNull private T value;
568
569    @CheckForNull
570    public T get() {
571      return value;
572    }
573
574    public void checkAndSet(@CheckForNull T expected, @CheckForNull T newValue) {
575      if (value != expected) {
576        throw new ConcurrentModificationException();
577      }
578      value = newValue;
579    }
580
581    void clear() {
582      value = null;
583    }
584  }
585
586  private static final class AvlNode<E extends @Nullable Object> {
    /*
     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
     * type that can include null, as in a TreeMultiset<@Nullable String>).
     *
     * For the header node, though, this field contains `null`, regardless of the type of the
     * multiset.
     *
     * Most code that operates on an AvlNode never operates on the header node. Such code can access
     * the elem field without a null check by calling getElement().
     */
    @CheckForNull private final E elem;

    // Number of occurrences of elem stored in this node.
    // elemCount is 0 iff this node has been deleted.
    // (The header node is constructed with elemCount 1 -- presumably so it never looks deleted.)
    private int elemCount;

    // Number of nodes (distinct elements) in the subtree rooted here, including this node; kept in
    // sync as 1 + distinct(left) + distinct(right) (see recomputeMultiset()).
    private int distinctElements;
    // Sum of elemCount over the subtree rooted here. A long because per-node int counts can add up
    // past Integer.MAX_VALUE (size() saturates the result).
    private long totalCount;
    // AVL height of the subtree rooted here: 1 for a leaf, 1 + max(child heights) otherwise.
    private int height;
    // Tree children: elements that compare below elem go left, those that compare above go right.
    @CheckForNull private AvlNode<E> left;
    @CheckForNull private AvlNode<E> right;
    /*
     * pred and succ are nullable after construction, but we always call successor() to initialize
     * them immediately thereafter.
     *
     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
     * only under concurrent modification).
     *
     * To access these fields when you know that they are not null, call the pred() and succ()
     * methods, which perform null checks before returning the fields.
     */
    @CheckForNull private AvlNode<E> pred;
    @CheckForNull private AvlNode<E> succ;
620
621    AvlNode(@ParametricNullness E elem, int elemCount) {
622      checkArgument(elemCount > 0);
623      this.elem = elem;
624      this.elemCount = elemCount;
625      this.totalCount = elemCount;
626      this.distinctElements = 1;
627      this.height = 1;
628      this.left = null;
629      this.right = null;
630    }
631
632    /** Constructor for the header node. */
633    AvlNode() {
634      this.elem = null;
635      this.elemCount = 1;
636    }
637
638    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
639
640    private AvlNode<E> pred() {
641      return requireNonNull(pred);
642    }
643
644    private AvlNode<E> succ() {
645      return requireNonNull(succ);
646    }
647
648    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
649      int cmp = comparator.compare(e, getElement());
650      if (cmp < 0) {
651        return (left == null) ? 0 : left.count(comparator, e);
652      } else if (cmp > 0) {
653        return (right == null) ? 0 : right.count(comparator, e);
654      } else {
655        return elemCount;
656      }
657    }
658
659    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
660      right = new AvlNode<E>(e, count);
661      successor(this, right, succ());
662      height = Math.max(2, height);
663      distinctElements++;
664      totalCount += count;
665      return this;
666    }
667
668    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
669      left = new AvlNode<E>(e, count);
670      successor(pred(), left, this);
671      height = Math.max(2, height);
672      distinctElements++;
673      totalCount += count;
674      return this;
675    }
676
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node, creating a new
     * leaf if {@code e} is not yet present.
     *
     * <p>On return, {@code result[0]} holds the number of occurrences of {@code e} present before the
     * call (0 if a new node was created).
     *
     * @return the new root of this subtree, which differs from {@code this} if rebalancing rotated
     *     it
     * @throws IllegalArgumentException if the element's count would exceed {@code
     *     Integer.MAX_VALUE}
     */
    AvlNode<E> add(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // Aggregates are updated only after the recursive call returns, so an overflow thrown
        // deeper down leaves this node's totals untouched.
        if (result[0] == 0) {
          // e was absent before, so a new distinct element now exists in the left subtree.
          distinctElements++;
        }
        this.totalCount += count;
        // If the child's height is unchanged, this node's balance is unaffected.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        if (result[0] == 0) {
          // e was absent before, so a new distinct element now exists in the right subtree.
          distinctElements++;
        }
        this.totalCount += count;
        // If the child's height is unchanged, this node's balance is unaffected.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      // Check for int overflow using long arithmetic before mutating anything.
      long resultCount = (long) elemCount + count;
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
722
723    @CheckForNull
724    AvlNode<E> remove(
725        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
726      int cmp = comparator.compare(e, getElement());
727      if (cmp < 0) {
728        AvlNode<E> initLeft = left;
729        if (initLeft == null) {
730          result[0] = 0;
731          return this;
732        }
733
734        left = initLeft.remove(comparator, e, count, result);
735
736        if (result[0] > 0) {
737          if (count >= result[0]) {
738            this.distinctElements--;
739            this.totalCount -= result[0];
740          } else {
741            this.totalCount -= count;
742          }
743        }
744        return (result[0] == 0) ? this : rebalance();
745      } else if (cmp > 0) {
746        AvlNode<E> initRight = right;
747        if (initRight == null) {
748          result[0] = 0;
749          return this;
750        }
751
752        right = initRight.remove(comparator, e, count, result);
753
754        if (result[0] > 0) {
755          if (count >= result[0]) {
756            this.distinctElements--;
757            this.totalCount -= result[0];
758          } else {
759            this.totalCount -= count;
760          }
761        }
762        return rebalance();
763      }
764
765      // removing count from me!
766      result[0] = elemCount;
767      if (count >= elemCount) {
768        return deleteMe();
769      } else {
770        this.elemCount -= count;
771        this.totalCount -= count;
772        return this;
773      }
774    }
775
    /**
     * Sets the number of occurrences of {@code e} in the subtree rooted at this node to {@code
     * count}. A new leaf is inserted if {@code e} is absent and {@code count > 0`}; the node is
     * deleted if {@code count == 0}.
     *
     * <p>On return, {@code result[0]} holds the previous count of {@code e} (0 if it was absent).
     *
     * @return the new root of this subtree, or {@code null} if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent: insert it only when a positive count is requested.
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // The distinct count changes only when e goes from present to absent or vice versa.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent: insert it only when a positive count is requested.
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // The distinct count changes only when e goes from present to absent or vice versa.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
825
    /**
     * Conditionally sets the count of {@code e} in the subtree rooted at this node to {@code
     * newCount}. The update happens only if the current count (0 when {@code e} is absent) equals
     * {@code expectedCount}.
     *
     * <p>On return, {@code result[0]} holds the count that was observed. The caller compares it with
     * {@code expectedCount} to learn whether the update was applied.
     *
     * @return the new root of this subtree, or {@code null} if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent (observed count 0): insert only if 0 was expected and newCount is positive.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust aggregates only if the child actually applied the update.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent (observed count 0): insert only if 0 was expected and newCount is positive.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust aggregates only if the child actually applied the update.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
889
    /**
     * Deletes this node from its subtree: marks it deleted (elemCount = 0), unlinks it from the
     * in-order pred/succ list, and returns the subtree that should take its place.
     *
     * <p>With at most one child, that child is returned. With two children, this node's in-order
     * predecessor or successor is promoted into its position. The predecessor is used if the left
     * subtree is at least as tall, otherwise the successor.
     *
     * @return the root of the replacement subtree, or {@code null} if this node had no children
     */
    @CheckForNull
    private AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Splice this node out of the pred/succ list. This node's own pred/succ fields are presumably
      // left intact by successor(a, b) (defined outside this chunk), which is what lets pred() and
      // succ() below still return the former neighbors -- TODO confirm.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        // newTop now roots everything this node did, minus this node itself.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // newTop is the minimum node in my right subtree.
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
916
917    // Removes the minimum node from this subtree to be reused elsewhere
918    @CheckForNull
919    private AvlNode<E> removeMin(AvlNode<E> node) {
920      if (left == null) {
921        return right;
922      } else {
923        left = left.removeMin(node);
924        distinctElements--;
925        totalCount -= node.elemCount;
926        return rebalance();
927      }
928    }
929
930    // Removes the maximum node from this subtree to be reused elsewhere
931    @CheckForNull
932    private AvlNode<E> removeMax(AvlNode<E> node) {
933      if (right == null) {
934        return left;
935      } else {
936        right = right.removeMax(node);
937        distinctElements--;
938        totalCount -= node.elemCount;
939        return rebalance();
940      }
941    }
942
943    private void recomputeMultiset() {
944      this.distinctElements =
945          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
946      this.totalCount = elemCount + totalCount(left) + totalCount(right);
947    }
948
949    private void recomputeHeight() {
950      this.height = 1 + Math.max(height(left), height(right));
951    }
952
953    private void recompute() {
954      recomputeMultiset();
955      recomputeHeight();
956    }
957
958    private AvlNode<E> rebalance() {
959      switch (balanceFactor()) {
960        case -2:
961          // requireNonNull is safe because right must exist in order to get a negative factor.
962          requireNonNull(right);
963          if (right.balanceFactor() > 0) {
964            right = right.rotateRight();
965          }
966          return rotateLeft();
967        case 2:
968          // requireNonNull is safe because left must exist in order to get a positive factor.
969          requireNonNull(left);
970          if (left.balanceFactor() < 0) {
971            left = left.rotateLeft();
972          }
973          return rotateRight();
974        default:
975          recomputeHeight();
976          return this;
977      }
978    }
979
980    private int balanceFactor() {
981      return height(left) - height(right);
982    }
983
984    private AvlNode<E> rotateLeft() {
985      checkState(right != null);
986      AvlNode<E> newTop = right;
987      this.right = newTop.left;
988      newTop.left = this;
989      newTop.totalCount = this.totalCount;
990      newTop.distinctElements = this.distinctElements;
991      this.recompute();
992      newTop.recomputeHeight();
993      return newTop;
994    }
995
996    private AvlNode<E> rotateRight() {
997      checkState(left != null);
998      AvlNode<E> newTop = left;
999      this.left = newTop.right;
1000      newTop.right = this;
1001      newTop.totalCount = this.totalCount;
1002      newTop.distinctElements = this.distinctElements;
1003      this.recompute();
1004      newTop.recomputeHeight();
1005      return newTop;
1006    }
1007
1008    private static long totalCount(@CheckForNull AvlNode<?> node) {
1009      return (node == null) ? 0 : node.totalCount;
1010    }
1011
1012    private static int height(@CheckForNull AvlNode<?> node) {
1013      return (node == null) ? 0 : node.height;
1014    }
1015
1016    @CheckForNull
1017    private AvlNode<E> ceiling(Comparator<? super E> comparator, @ParametricNullness E e) {
1018      int cmp = comparator.compare(e, getElement());
1019      if (cmp < 0) {
1020        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1021      } else if (cmp == 0) {
1022        return this;
1023      } else {
1024        return (right == null) ? null : right.ceiling(comparator, e);
1025      }
1026    }
1027
1028    @CheckForNull
1029    private AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1030      int cmp = comparator.compare(e, getElement());
1031      if (cmp > 0) {
1032        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1033      } else if (cmp == 0) {
1034        return this;
1035      } else {
1036        return (left == null) ? null : left.floor(comparator, e);
1037      }
1038    }
1039
1040    @ParametricNullness
1041    E getElement() {
1042      // For discussion of this cast, see the comment on the elem field.
1043      return uncheckedCastNullableTToT(elem);
1044    }
1045
1046    int getCount() {
1047      return elemCount;
1048    }
1049
1050    @Override
1051    public String toString() {
1052      return Multisets.immutableEntry(getElement(), getCount()).toString();
1053    }
1054  }
1055
1056  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1057    a.succ = b;
1058    b.pred = a;
1059  }
1060
1061  private static <T extends @Nullable Object> void successor(
1062      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1063    successor(a, b);
1064    successor(b, c);
1065  }
1066
1067  /*
1068   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1069   * calls the comparator to compare the two keys. If that change is made,
1070   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1071   */
1072
1073  /**
1074   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1075   *     second element, its count, and so on
1076   */
1077  @J2ktIncompatible
1078  @GwtIncompatible // java.io.ObjectOutputStream
1079  private void writeObject(ObjectOutputStream stream) throws IOException {
1080    stream.defaultWriteObject();
1081    stream.writeObject(elementSet().comparator());
1082    Serialization.writeMultiset(this, stream);
1083  }
1084
  /**
   * Restores a multiset from the layout produced by {@code writeObject}. It reads the comparator,
   * re-initializes the comparator, range, root-reference and header fields via
   * {@code Serialization.getFieldSetter}, links the fresh header to itself, and then reads the
   * element/count pairs.
   */
  @J2ktIncompatible
  @GwtIncompatible // java.io.ObjectInputStream
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    @SuppressWarnings("unchecked")
    // reading data stored by writeObject
    Comparator<? super E> comparator = (Comparator<? super E>) requireNonNull(stream.readObject());
    // These fields are assigned through reflective setters rather than directly.
    // NOTE(review): presumably because they are final; confirm against the field declarations.
    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    Serialization.getFieldSetter(TreeMultiset.class, "range")
        .set(this, GeneralRange.all(comparator));
    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
        .set(this, new Reference<AvlNode<E>>());
    AvlNode<E> header = new AvlNode<>();
    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    // Link the header to itself (pred and succ both point at it) before any elements are read.
    successor(header, header);
    Serialization.populateMultiset(this, stream);
  }
1102
  // Version identifier of the serialized form (custom layout written by writeObject and read back
  // by readObject). Changing it makes previously serialized instances unreadable.
  @GwtIncompatible // not needed in emulated source
  @J2ktIncompatible
  private static final long serialVersionUID = 1;
1106}