001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Preconditions.checkState;
022import static com.google.common.collect.CollectPreconditions.checkNonnegative;
023import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
024import static java.lang.Math.max;
025import static java.util.Objects.requireNonNull;
026
027import com.google.common.annotations.GwtCompatible;
028import com.google.common.annotations.GwtIncompatible;
029import com.google.common.annotations.J2ktIncompatible;
030import com.google.common.base.MoreObjects;
031import com.google.common.primitives.Ints;
032import com.google.errorprone.annotations.CanIgnoreReturnValue;
033import java.io.IOException;
034import java.io.ObjectInputStream;
035import java.io.ObjectOutputStream;
036import java.io.Serializable;
037import java.util.Comparator;
038import java.util.ConcurrentModificationException;
039import java.util.Iterator;
040import java.util.NoSuchElementException;
041import java.util.function.ObjIntConsumer;
042import javax.annotation.CheckForNull;
043import org.checkerframework.checker.nullness.qual.Nullable;
044
045/**
046 * A multiset which maintains the ordering of its elements, according to either their natural order
047 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
048 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
049 * equivalence of instances.
050 *
051 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
052 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
053 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
054 *
055 * <p>See the Guava User Guide article on <a href=
056 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
057 *
058 * @author Louis Wasserman
059 * @author Jared Levy
060 * @since 2.0
061 */
062@GwtCompatible(emulated = true)
063@ElementTypesAreNonnullByDefault
064public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
065    implements Serializable {
066
067  /**
068   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
069   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
070   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
071   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
072   * user attempts to add an element to the multiset that violates this constraint (for example, the
073   * user attempts to add a string element to a set whose elements are integers), the {@code
074   * add(Object)} call will throw a {@code ClassCastException}.
075   *
076   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
077   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
078   */
079  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
080  public static <E extends Comparable> TreeMultiset<E> create() {
081    return new TreeMultiset<>(Ordering.natural());
082  }
083
084  /**
085   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
086   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
087   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
088   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
089   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
090   * ClassCastException}.
091   *
092   * @param comparator the comparator that will be used to sort this multiset. A null value
093   *     indicates that the elements' <i>natural ordering</i> should be used.
094   */
095  @SuppressWarnings("unchecked")
096  public static <E extends @Nullable Object> TreeMultiset<E> create(
097      @CheckForNull Comparator<? super E> comparator) {
098    return (comparator == null)
099        ? new TreeMultiset<E>((Comparator) Ordering.natural())
100        : new TreeMultiset<E>(comparator);
101  }
102
103  /**
104   * Creates an empty multiset containing the given initial elements, sorted according to the
105   * elements' natural order.
106   *
107   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
108   *
109   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
110   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
111   */
112  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
113  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
114    TreeMultiset<E> multiset = create();
115    Iterables.addAll(multiset, elements);
116    return multiset;
117  }
118
  // Mutable holder for the root of the AVL tree. headMultiset/tailMultiset pass this same Reference
  // to the views they create, so a view and its parent see each other's modifications.
  private final transient Reference<AvlNode<E>> rootReference;
  // Range of elements visible through this multiset; GeneralRange.all(comparator) for a multiset
  // created by the comparator constructor, narrowed for head/tail views.
  private final transient GeneralRange<E> range;
  // Sentinel node of the in-order pred/succ list: header.succ() is the first node and
  // header.pred() the last. It is shared with head/tail views.
  private final transient AvlNode<E> header;
122
123  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
124    super(range.comparator());
125    this.rootReference = rootReference;
126    this.range = range;
127    this.header = endLink;
128  }
129
130  TreeMultiset(Comparator<? super E> comparator) {
131    super(comparator);
132    this.range = GeneralRange.all(comparator);
133    this.header = new AvlNode<>();
134    successor(header, header);
135    this.rootReference = new Reference<>();
136  }
137
138  /** A function which can be summed across a subtree. */
139  private enum Aggregate {
140    SIZE {
141      @Override
142      int nodeAggregate(AvlNode<?> node) {
143        return node.elemCount;
144      }
145
146      @Override
147      long treeAggregate(@CheckForNull AvlNode<?> root) {
148        return (root == null) ? 0 : root.totalCount;
149      }
150    },
151    DISTINCT {
152      @Override
153      int nodeAggregate(AvlNode<?> node) {
154        return 1;
155      }
156
157      @Override
158      long treeAggregate(@CheckForNull AvlNode<?> root) {
159        return (root == null) ? 0 : root.distinctElements;
160      }
161    };
162
163    abstract int nodeAggregate(AvlNode<?> node);
164
165    abstract long treeAggregate(@CheckForNull AvlNode<?> root);
166  }
167
168  private long aggregateForEntries(Aggregate aggr) {
169    AvlNode<E> root = rootReference.get();
170    long total = aggr.treeAggregate(root);
171    if (range.hasLowerBound()) {
172      total -= aggregateBelowRange(aggr, root);
173    }
174    if (range.hasUpperBound()) {
175      total -= aggregateAboveRange(aggr, root);
176    }
177    return total;
178  }
179
180  private long aggregateBelowRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
181    if (node == null) {
182      return 0;
183    }
184    // The cast is safe because we call this method only if hasLowerBound().
185    int cmp =
186        comparator()
187            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
188    if (cmp < 0) {
189      return aggregateBelowRange(aggr, node.left);
190    } else if (cmp == 0) {
191      switch (range.getLowerBoundType()) {
192        case OPEN:
193          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
194        case CLOSED:
195          return aggr.treeAggregate(node.left);
196      }
197      throw new AssertionError();
198    } else {
199      return aggr.treeAggregate(node.left)
200          + aggr.nodeAggregate(node)
201          + aggregateBelowRange(aggr, node.right);
202    }
203  }
204
205  private long aggregateAboveRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
206    if (node == null) {
207      return 0;
208    }
209    // The cast is safe because we call this method only if hasUpperBound().
210    int cmp =
211        comparator()
212            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
213    if (cmp > 0) {
214      return aggregateAboveRange(aggr, node.right);
215    } else if (cmp == 0) {
216      switch (range.getUpperBoundType()) {
217        case OPEN:
218          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
219        case CLOSED:
220          return aggr.treeAggregate(node.right);
221      }
222      throw new AssertionError();
223    } else {
224      return aggr.treeAggregate(node.right)
225          + aggr.nodeAggregate(node)
226          + aggregateAboveRange(aggr, node.left);
227    }
228  }
229
230  @Override
231  public int size() {
232    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
233  }
234
235  @Override
236  int distinctElements() {
237    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
238  }
239
240  static int distinctElements(@CheckForNull AvlNode<?> node) {
241    return (node == null) ? 0 : node.distinctElements;
242  }
243
244  @Override
245  public int count(@CheckForNull Object element) {
246    try {
247      @SuppressWarnings("unchecked")
248      E e = (E) element;
249      AvlNode<E> root = rootReference.get();
250      if (!range.contains(e) || root == null) {
251        return 0;
252      }
253      return root.count(comparator(), e);
254    } catch (ClassCastException | NullPointerException e) {
255      return 0;
256    }
257  }
258
259  @CanIgnoreReturnValue
260  @Override
261  public int add(@ParametricNullness E element, int occurrences) {
262    checkNonnegative(occurrences, "occurrences");
263    if (occurrences == 0) {
264      return count(element);
265    }
266    checkArgument(range.contains(element));
267    AvlNode<E> root = rootReference.get();
268    if (root == null) {
269      int unused = comparator().compare(element, element);
270      AvlNode<E> newRoot = new AvlNode<>(element, occurrences);
271      successor(header, newRoot, header);
272      rootReference.checkAndSet(root, newRoot);
273      return 0;
274    }
275    int[] result = new int[1]; // used as a mutable int reference to hold result
276    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
277    rootReference.checkAndSet(root, newRoot);
278    return result[0];
279  }
280
281  @CanIgnoreReturnValue
282  @Override
283  public int remove(@CheckForNull Object element, int occurrences) {
284    checkNonnegative(occurrences, "occurrences");
285    if (occurrences == 0) {
286      return count(element);
287    }
288    AvlNode<E> root = rootReference.get();
289    int[] result = new int[1]; // used as a mutable int reference to hold result
290    AvlNode<E> newRoot;
291    try {
292      @SuppressWarnings("unchecked")
293      E e = (E) element;
294      if (!range.contains(e) || root == null) {
295        return 0;
296      }
297      newRoot = root.remove(comparator(), e, occurrences, result);
298    } catch (ClassCastException | NullPointerException e) {
299      return 0;
300    }
301    rootReference.checkAndSet(root, newRoot);
302    return result[0];
303  }
304
305  @CanIgnoreReturnValue
306  @Override
307  public int setCount(@ParametricNullness E element, int count) {
308    checkNonnegative(count, "count");
309    if (!range.contains(element)) {
310      checkArgument(count == 0);
311      return 0;
312    }
313
314    AvlNode<E> root = rootReference.get();
315    if (root == null) {
316      if (count > 0) {
317        add(element, count);
318      }
319      return 0;
320    }
321    int[] result = new int[1]; // used as a mutable int reference to hold result
322    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
323    rootReference.checkAndSet(root, newRoot);
324    return result[0];
325  }
326
327  @CanIgnoreReturnValue
328  @Override
329  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
330    checkNonnegative(newCount, "newCount");
331    checkNonnegative(oldCount, "oldCount");
332    checkArgument(range.contains(element));
333
334    AvlNode<E> root = rootReference.get();
335    if (root == null) {
336      if (oldCount == 0) {
337        if (newCount > 0) {
338          add(element, newCount);
339        }
340        return true;
341      } else {
342        return false;
343      }
344    }
345    int[] result = new int[1]; // used as a mutable int reference to hold result
346    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
347    rootReference.checkAndSet(root, newRoot);
348    return result[0] == oldCount;
349  }
350
351  @Override
352  public void clear() {
353    if (!range.hasLowerBound() && !range.hasUpperBound()) {
354      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
355      for (AvlNode<E> current = header.succ(); current != header; ) {
356        AvlNode<E> next = current.succ();
357
358        current.elemCount = 0;
359        // Also clear these fields so that one deleted Entry doesn't retain all elements.
360        current.left = null;
361        current.right = null;
362        current.pred = null;
363        current.succ = null;
364
365        current = next;
366      }
367      successor(header, header);
368      rootReference.clear();
369    } else {
370      // TODO(cpovirk): Perhaps we can optimize in this case, too?
371      Iterators.clear(entryIterator());
372    }
373  }
374
375  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
376    return new Multisets.AbstractEntry<E>() {
377      @Override
378      @ParametricNullness
379      public E getElement() {
380        return baseEntry.getElement();
381      }
382
383      @Override
384      public int getCount() {
385        int result = baseEntry.getCount();
386        if (result == 0) {
387          return count(getElement());
388        } else {
389          return result;
390        }
391      }
392    };
393  }
394
395  /** Returns the first node in the tree that is in range. */
396  @CheckForNull
397  private AvlNode<E> firstNode() {
398    AvlNode<E> root = rootReference.get();
399    if (root == null) {
400      return null;
401    }
402    AvlNode<E> node;
403    if (range.hasLowerBound()) {
404      // The cast is safe because of the hasLowerBound check.
405      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
406      node = root.ceiling(comparator(), endpoint);
407      if (node == null) {
408        return null;
409      }
410      if (range.getLowerBoundType() == BoundType.OPEN
411          && comparator().compare(endpoint, node.getElement()) == 0) {
412        node = node.succ();
413      }
414    } else {
415      node = header.succ();
416    }
417    return (node == header || !range.contains(node.getElement())) ? null : node;
418  }
419
420  @CheckForNull
421  private AvlNode<E> lastNode() {
422    AvlNode<E> root = rootReference.get();
423    if (root == null) {
424      return null;
425    }
426    AvlNode<E> node;
427    if (range.hasUpperBound()) {
428      // The cast is safe because of the hasUpperBound check.
429      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
430      node = root.floor(comparator(), endpoint);
431      if (node == null) {
432        return null;
433      }
434      if (range.getUpperBoundType() == BoundType.OPEN
435          && comparator().compare(endpoint, node.getElement()) == 0) {
436        node = node.pred();
437      }
438    } else {
439      node = header.pred();
440    }
441    return (node == header || !range.contains(node.getElement())) ? null : node;
442  }
443
444  @Override
445  Iterator<E> elementIterator() {
446    return Multisets.elementIterator(entryIterator());
447  }
448
449  @Override
450  Iterator<Entry<E>> entryIterator() {
451    return new Iterator<Entry<E>>() {
452      @CheckForNull AvlNode<E> current = firstNode();
453      @CheckForNull Entry<E> prevEntry;
454
455      @Override
456      public boolean hasNext() {
457        if (current == null) {
458          return false;
459        } else if (range.tooHigh(current.getElement())) {
460          current = null;
461          return false;
462        } else {
463          return true;
464        }
465      }
466
467      @Override
468      public Entry<E> next() {
469        if (!hasNext()) {
470          throw new NoSuchElementException();
471        }
472        // requireNonNull is safe because current is only nulled out after iteration is complete.
473        Entry<E> result = wrapEntry(requireNonNull(current));
474        prevEntry = result;
475        if (current.succ() == header) {
476          current = null;
477        } else {
478          current = current.succ();
479        }
480        return result;
481      }
482
483      @Override
484      public void remove() {
485        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
486        setCount(prevEntry.getElement(), 0);
487        prevEntry = null;
488      }
489    };
490  }
491
492  @Override
493  Iterator<Entry<E>> descendingEntryIterator() {
494    return new Iterator<Entry<E>>() {
495      @CheckForNull AvlNode<E> current = lastNode();
496      @CheckForNull Entry<E> prevEntry = null;
497
498      @Override
499      public boolean hasNext() {
500        if (current == null) {
501          return false;
502        } else if (range.tooLow(current.getElement())) {
503          current = null;
504          return false;
505        } else {
506          return true;
507        }
508      }
509
510      @Override
511      public Entry<E> next() {
512        if (!hasNext()) {
513          throw new NoSuchElementException();
514        }
515        // requireNonNull is safe because current is only nulled out after iteration is complete.
516        requireNonNull(current);
517        Entry<E> result = wrapEntry(current);
518        prevEntry = result;
519        if (current.pred() == header) {
520          current = null;
521        } else {
522          current = current.pred();
523        }
524        return result;
525      }
526
527      @Override
528      public void remove() {
529        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
530        setCount(prevEntry.getElement(), 0);
531        prevEntry = null;
532      }
533    };
534  }
535
536  @Override
537  public void forEachEntry(ObjIntConsumer<? super E> action) {
538    checkNotNull(action);
539    for (AvlNode<E> node = firstNode();
540        node != header && node != null && !range.tooHigh(node.getElement());
541        node = node.succ()) {
542      action.accept(node.getElement(), node.getCount());
543    }
544  }
545
546  @Override
547  public Iterator<E> iterator() {
548    return Multisets.iteratorImpl(this);
549  }
550
551  @Override
552  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
553    return new TreeMultiset<>(
554        rootReference,
555        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
556        header);
557  }
558
559  @Override
560  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
561    return new TreeMultiset<>(
562        rootReference,
563        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
564        header);
565  }
566
567  private static final class Reference<T> {
568    @CheckForNull private T value;
569
570    @CheckForNull
571    public T get() {
572      return value;
573    }
574
575    public void checkAndSet(@CheckForNull T expected, @CheckForNull T newValue) {
576      if (value != expected) {
577        throw new ConcurrentModificationException();
578      }
579      value = newValue;
580    }
581
582    void clear() {
583      value = null;
584    }
585  }
586
587  private static final class AvlNode<E extends @Nullable Object> {
588    /*
589     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
590     * type that can include null, as in a TreeMultiset<@Nullable String>).
591     *
592     * For the header node, though, this field contains `null`, regardless of the type of the
593     * multiset.
594     *
595     * Most code that operates on an AvlNode never operates on the header node. Such code can access
596     * the elem field without a null check by calling getElement().
597     */
598    @CheckForNull private final E elem;
599
600    // elemCount is 0 iff this node has been deleted.
601    private int elemCount;
602
603    private int distinctElements;
604    private long totalCount;
605    private int height;
606    @CheckForNull private AvlNode<E> left;
607    @CheckForNull private AvlNode<E> right;
608    /*
609     * pred and succ are nullable after construction, but we always call successor() to initialize
610     * them immediately thereafter.
611     *
612     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
613     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
614     * only under concurrent modification).
615     *
616     * To access these fields when you know that they are not null, call the pred() and succ()
617     * methods, which perform null checks before returning the fields.
618     */
619    @CheckForNull private AvlNode<E> pred;
620    @CheckForNull private AvlNode<E> succ;
621
622    AvlNode(@ParametricNullness E elem, int elemCount) {
623      checkArgument(elemCount > 0);
624      this.elem = elem;
625      this.elemCount = elemCount;
626      this.totalCount = elemCount;
627      this.distinctElements = 1;
628      this.height = 1;
629      this.left = null;
630      this.right = null;
631    }
632
633    /** Constructor for the header node. */
634    AvlNode() {
635      this.elem = null;
636      this.elemCount = 1;
637    }
638
639    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
640
641    private AvlNode<E> pred() {
642      return requireNonNull(pred);
643    }
644
645    private AvlNode<E> succ() {
646      return requireNonNull(succ);
647    }
648
649    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
650      int cmp = comparator.compare(e, getElement());
651      if (cmp < 0) {
652        return (left == null) ? 0 : left.count(comparator, e);
653      } else if (cmp > 0) {
654        return (right == null) ? 0 : right.count(comparator, e);
655      } else {
656        return elemCount;
657      }
658    }
659
660    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
661      right = new AvlNode<>(e, count);
662      successor(this, right, succ());
663      height = max(2, height);
664      distinctElements++;
665      totalCount += count;
666      return this;
667    }
668
669    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
670      left = new AvlNode<>(e, count);
671      successor(pred(), left, this);
672      height = max(2, height);
673      distinctElements++;
674      totalCount += count;
675      return this;
676    }
677
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node, creating a new
     * node if {@code e} is not already present.
     *
     * @param comparator the multiset's comparator, used to locate {@code e}
     * @param e the element to add
     * @param count the number of occurrences to add; TreeMultiset.add only calls this with a
     *     positive value
     * @param result out-parameter; {@code result[0]} is set to the count {@code e} had before the
     *     call (0 if it was absent)
     * @return the root of the updated subtree, which may differ from {@code this} after rebalancing
     * @throws IllegalArgumentException if the element's resulting count would exceed {@code
     *     Integer.MAX_VALUE}
     */
    AvlNode<E> add(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // A previous count of 0 means a new node was created somewhere in the left subtree.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Updated only after the recursive call succeeded, keeping this node consistent on failure.
        this.totalCount += count;
        // Rebalancing is skipped when the left subtree's height did not change.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        // A previous count of 0 means a new node was created somewhere in the right subtree.
        if (result[0] == 0) {
          distinctElements++;
        }
        this.totalCount += count;
        // Rebalancing is skipped when the right subtree's height did not change.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      // Check for int overflow in long arithmetic before mutating anything.
      long resultCount = (long) elemCount + count;
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
723
724    @CheckForNull
725    AvlNode<E> remove(
726        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
727      int cmp = comparator.compare(e, getElement());
728      if (cmp < 0) {
729        AvlNode<E> initLeft = left;
730        if (initLeft == null) {
731          result[0] = 0;
732          return this;
733        }
734
735        left = initLeft.remove(comparator, e, count, result);
736
737        if (result[0] > 0) {
738          if (count >= result[0]) {
739            this.distinctElements--;
740            this.totalCount -= result[0];
741          } else {
742            this.totalCount -= count;
743          }
744        }
745        return (result[0] == 0) ? this : rebalance();
746      } else if (cmp > 0) {
747        AvlNode<E> initRight = right;
748        if (initRight == null) {
749          result[0] = 0;
750          return this;
751        }
752
753        right = initRight.remove(comparator, e, count, result);
754
755        if (result[0] > 0) {
756          if (count >= result[0]) {
757            this.distinctElements--;
758            this.totalCount -= result[0];
759          } else {
760            this.totalCount -= count;
761          }
762        }
763        return rebalance();
764      }
765
766      // removing count from me!
767      result[0] = elemCount;
768      if (count >= elemCount) {
769        return deleteMe();
770      } else {
771        this.elemCount -= count;
772        this.totalCount -= count;
773        return this;
774      }
775    }
776
    /**
     * Sets the count of {@code e} in the subtree rooted at this node to {@code count}. Adds a node if
     * {@code e} is absent and {@code count > 0}, and deletes its node if {@code count == 0}.
     *
     * @param comparator the multiset's comparator, used to locate {@code e}
     * @param e the element whose count is being set
     * @param count the new count; TreeMultiset.setCount checks it is nonnegative before calling
     * @param result out-parameter; {@code result[0]} is set to the count {@code e} had before the
     *     call (0 if it was absent)
     * @return the root of the updated subtree, or {@code null} if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // Keep the distinct-element tally in sync: an existing node deleted, or a new one created.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // Keep the distinct-element tally in sync: an existing node deleted, or a new one created.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
826
    /**
     * Conditionally sets the count of {@code e} in the subtree rooted at this node. The count
     * becomes {@code newCount} only if it is currently {@code expectedCount}; otherwise no counts are
     * changed.
     *
     * @param comparator the multiset's comparator, used to locate {@code e}
     * @param e the element whose count is being set
     * @param expectedCount the count {@code e} must currently have for the update to happen
     * @param newCount the count to store when the expectation holds
     * @param result out-parameter; {@code result[0]} is set to the actual count of {@code e} before
     *     the call (0 if absent). Callers compare it with {@code expectedCount} to learn whether the
     *     update happened.
     * @return the root of the subtree after the (possible) update, or {@code null} if it became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          // e is absent (count 0); create it only if 0 was the expected count.
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the child actually applied the update.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          // e is absent (count 0); create it only if 0 was the expected count.
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the child actually applied the update.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
890
    /**
     * Removes this node from the subtree it roots and returns the root of what remains, or {@code
     * null} if this node had no children.
     *
     * <p>Sets {@code elemCount} to 0, which marks this node as deleted. When both children exist,
     * the in-order neighbor on the taller side replaces this node as the subtree root: the
     * predecessor when the heights are equal. This node's own left/right/pred/succ fields are not
     * cleared here.
     */
    @CheckForNull
    private AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Presumably links pred and succ directly to each other, splicing this node out of the list.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // newTop is the minimum node in my right subtree
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
917
918    // Removes the minimum node from this subtree to be reused elsewhere
919    @CheckForNull
920    private AvlNode<E> removeMin(AvlNode<E> node) {
921      if (left == null) {
922        return right;
923      } else {
924        left = left.removeMin(node);
925        distinctElements--;
926        totalCount -= node.elemCount;
927        return rebalance();
928      }
929    }
930
931    // Removes the maximum node from this subtree to be reused elsewhere
932    @CheckForNull
933    private AvlNode<E> removeMax(AvlNode<E> node) {
934      if (right == null) {
935        return left;
936      } else {
937        right = right.removeMax(node);
938        distinctElements--;
939        totalCount -= node.elemCount;
940        return rebalance();
941      }
942    }
943
944    private void recomputeMultiset() {
945      this.distinctElements =
946          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
947      this.totalCount = elemCount + totalCount(left) + totalCount(right);
948    }
949
950    private void recomputeHeight() {
951      this.height = 1 + max(height(left), height(right));
952    }
953
954    private void recompute() {
955      recomputeMultiset();
956      recomputeHeight();
957    }
958
959    private AvlNode<E> rebalance() {
960      switch (balanceFactor()) {
961        case -2:
962          // requireNonNull is safe because right must exist in order to get a negative factor.
963          requireNonNull(right);
964          if (right.balanceFactor() > 0) {
965            right = right.rotateRight();
966          }
967          return rotateLeft();
968        case 2:
969          // requireNonNull is safe because left must exist in order to get a positive factor.
970          requireNonNull(left);
971          if (left.balanceFactor() < 0) {
972            left = left.rotateLeft();
973          }
974          return rotateRight();
975        default:
976          recomputeHeight();
977          return this;
978      }
979    }
980
981    private int balanceFactor() {
982      return height(left) - height(right);
983    }
984
    /**
     * Left rotation: the right child becomes the root of this subtree and this node becomes its
     * left child. Fails via checkState if there is no right child.
     *
     * @return the new root of this subtree
     */
    private AvlNode<E> rotateLeft() {
      checkState(right != null);
      AvlNode<E> newTop = right;
      // newTop's former left subtree moves under this node as its new right subtree.
      this.right = newTop.left;
      newTop.left = this;
      // The rotated subtree holds the same elements as before, so newTop inherits the totals.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // This node's children changed, so refresh it first; newTop's height then reads this
      // node's updated height (this is now newTop.left).
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
996
    /**
     * Right rotation: the left child becomes the root of this subtree and this node becomes its
     * right child. Fails via checkState if there is no left child.
     *
     * @return the new root of this subtree
     */
    private AvlNode<E> rotateRight() {
      checkState(left != null);
      AvlNode<E> newTop = left;
      // newTop's former right subtree moves under this node as its new left subtree.
      this.left = newTop.right;
      newTop.right = this;
      // The rotated subtree holds the same elements as before, so newTop inherits the totals.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // This node's children changed, so refresh it first; newTop's height then reads this
      // node's updated height (this is now newTop.right).
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
1008
1009    private static long totalCount(@CheckForNull AvlNode<?> node) {
1010      return (node == null) ? 0 : node.totalCount;
1011    }
1012
1013    private static int height(@CheckForNull AvlNode<?> node) {
1014      return (node == null) ? 0 : node.height;
1015    }
1016
1017    @CheckForNull
1018    private AvlNode<E> ceiling(Comparator<? super E> comparator, @ParametricNullness E e) {
1019      int cmp = comparator.compare(e, getElement());
1020      if (cmp < 0) {
1021        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1022      } else if (cmp == 0) {
1023        return this;
1024      } else {
1025        return (right == null) ? null : right.ceiling(comparator, e);
1026      }
1027    }
1028
1029    @CheckForNull
1030    private AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1031      int cmp = comparator.compare(e, getElement());
1032      if (cmp > 0) {
1033        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1034      } else if (cmp == 0) {
1035        return this;
1036      } else {
1037        return (left == null) ? null : left.floor(comparator, e);
1038      }
1039    }
1040
1041    @ParametricNullness
1042    E getElement() {
1043      // For discussion of this cast, see the comment on the elem field.
1044      return uncheckedCastNullableTToT(elem);
1045    }
1046
1047    int getCount() {
1048      return elemCount;
1049    }
1050
1051    @Override
1052    public String toString() {
1053      return Multisets.immutableEntry(getElement(), getCount()).toString();
1054    }
1055  }
1056
1057  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1058    a.succ = b;
1059    b.pred = a;
1060  }
1061
1062  private static <T extends @Nullable Object> void successor(
1063      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1064    successor(a, b);
1065    successor(b, c);
1066  }
1067
1068  /*
1069   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1070   * calls the comparator to compare the two keys. If that change is made,
1071   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1072   */
1073
1074  /**
1075   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1076   *     second element, its count, and so on
1077   */
1078  @J2ktIncompatible
1079  @GwtIncompatible // java.io.ObjectOutputStream
1080  private void writeObject(ObjectOutputStream stream) throws IOException {
1081    stream.defaultWriteObject();
1082    stream.writeObject(elementSet().comparator());
1083    Serialization.writeMultiset(this, stream);
1084  }
1085
  /**
   * Custom deserialization mirroring writeObject: reads the comparator, reinitializes the
   * comparator, range, rootReference and header fields through reflective field setters, links an
   * empty header list, then reads the multiset contents back via Serialization.populateMultiset.
   *
   * <p>Throws NullPointerException if the serialized comparator is null.
   */
  @J2ktIncompatible
  @GwtIncompatible // java.io.ObjectInputStream
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    @SuppressWarnings("unchecked")
    // reading data stored by writeObject
    Comparator<? super E> comparator = (Comparator<? super E>) requireNonNull(stream.readObject());
    // NOTE(review): these fields are set via field setters rather than plain assignment,
    // presumably because they are final; confirm against the field declarations.
    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    Serialization.getFieldSetter(TreeMultiset.class, "range")
        .set(this, GeneralRange.all(comparator));
    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
        .set(this, new Reference<AvlNode<E>>());
    AvlNode<E> header = new AvlNode<>();
    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    // Empty in-order list: the header is its own predecessor and successor.
    successor(header, header);
    Serialization.populateMultiset(this, stream);
  }
1103
  // Version identifier for the custom serialized form written by writeObject and read by
  // readObject.
  @GwtIncompatible // not needed in emulated source
  @J2ktIncompatible
  private static final long serialVersionUID = 1;
1107}