001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
023import static java.lang.Math.max;
024import static java.util.Objects.requireNonNull;
025
026import com.google.common.annotations.GwtCompatible;
027import com.google.common.annotations.GwtIncompatible;
028import com.google.common.annotations.J2ktIncompatible;
029import com.google.common.base.MoreObjects;
030import com.google.common.primitives.Ints;
031import com.google.errorprone.annotations.CanIgnoreReturnValue;
032import java.io.IOException;
033import java.io.ObjectInputStream;
034import java.io.ObjectOutputStream;
035import java.io.Serializable;
036import java.util.Comparator;
037import java.util.ConcurrentModificationException;
038import java.util.Iterator;
039import java.util.NoSuchElementException;
040import org.jspecify.annotations.Nullable;
041
042/**
043 * A multiset which maintains the ordering of its elements, according to either their natural order
044 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
045 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
046 * equivalence of instances.
047 *
048 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
049 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
050 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
051 *
052 * <p>See the Guava User Guide article on <a href=
053 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
054 *
055 * @author Louis Wasserman
056 * @author Jared Levy
057 * @since 2.0
058 */
059@GwtCompatible(emulated = true)
060public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
061    implements Serializable {
062
063  /**
064   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
065   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
066   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
067   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
068   * user attempts to add an element to the multiset that violates this constraint (for example, the
069   * user attempts to add a string element to a set whose elements are integers), the {@code
070   * add(Object)} call will throw a {@code ClassCastException}.
071   *
072   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
073   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
074   */
075  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
076  public static <E extends Comparable> TreeMultiset<E> create() {
077    return new TreeMultiset<>(Ordering.natural());
078  }
079
080  /**
081   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
082   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
083   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
084   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
085   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
086   * ClassCastException}.
087   *
088   * @param comparator the comparator that will be used to sort this multiset. A null value
089   *     indicates that the elements' <i>natural ordering</i> should be used.
090   */
091  @SuppressWarnings("unchecked")
092  public static <E extends @Nullable Object> TreeMultiset<E> create(
093      @Nullable Comparator<? super E> comparator) {
094    return (comparator == null)
095        ? new TreeMultiset<E>((Comparator) Ordering.natural())
096        : new TreeMultiset<E>(comparator);
097  }
098
099  /**
100   * Creates an empty multiset containing the given initial elements, sorted according to the
101   * elements' natural order.
102   *
103   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
104   *
105   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
106   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
107   */
108  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
109  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
110    TreeMultiset<E> multiset = create();
111    Iterables.addAll(multiset, elements);
112    return multiset;
113  }
114
  // Holder for the root of the AVL tree. Views built by headMultiset/tailMultiset are given the
  // same Reference object, so every view observes (and updates) the same underlying tree.
  private final transient Reference<AvlNode<E>> rootReference;
  // The span of elements this (possibly range-restricted) multiset exposes; GeneralRange.all(...)
  // for a multiset built through the comparator constructor.
  private final transient GeneralRange<E> range;
  // Sentinel node of the circular pred/succ list threaded through the tree: header.succ() is the
  // first node and header.pred() the last; an empty tree has the header linked to itself.
  private final transient AvlNode<E> header;
118
119  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
120    super(range.comparator());
121    this.rootReference = rootReference;
122    this.range = range;
123    this.header = endLink;
124  }
125
126  TreeMultiset(Comparator<? super E> comparator) {
127    super(comparator);
128    this.range = GeneralRange.all(comparator);
129    this.header = new AvlNode<>();
130    successor(header, header);
131    this.rootReference = new Reference<>();
132  }
133
134  /** A function which can be summed across a subtree. */
135  private enum Aggregate {
136    SIZE {
137      @Override
138      int nodeAggregate(AvlNode<?> node) {
139        return node.elemCount;
140      }
141
142      @Override
143      long treeAggregate(@Nullable AvlNode<?> root) {
144        return (root == null) ? 0 : root.totalCount;
145      }
146    },
147    DISTINCT {
148      @Override
149      int nodeAggregate(AvlNode<?> node) {
150        return 1;
151      }
152
153      @Override
154      long treeAggregate(@Nullable AvlNode<?> root) {
155        return (root == null) ? 0 : root.distinctElements;
156      }
157    };
158
159    abstract int nodeAggregate(AvlNode<?> node);
160
161    abstract long treeAggregate(@Nullable AvlNode<?> root);
162  }
163
164  private long aggregateForEntries(Aggregate aggr) {
165    AvlNode<E> root = rootReference.get();
166    long total = aggr.treeAggregate(root);
167    if (range.hasLowerBound()) {
168      total -= aggregateBelowRange(aggr, root);
169    }
170    if (range.hasUpperBound()) {
171      total -= aggregateAboveRange(aggr, root);
172    }
173    return total;
174  }
175
176  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
177    if (node == null) {
178      return 0;
179    }
180    // The cast is safe because we call this method only if hasLowerBound().
181    int cmp =
182        comparator()
183            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
184    if (cmp < 0) {
185      return aggregateBelowRange(aggr, node.left);
186    } else if (cmp == 0) {
187      switch (range.getLowerBoundType()) {
188        case OPEN:
189          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
190        case CLOSED:
191          return aggr.treeAggregate(node.left);
192      }
193      throw new AssertionError();
194    } else {
195      return aggr.treeAggregate(node.left)
196          + aggr.nodeAggregate(node)
197          + aggregateBelowRange(aggr, node.right);
198    }
199  }
200
201  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
202    if (node == null) {
203      return 0;
204    }
205    // The cast is safe because we call this method only if hasUpperBound().
206    int cmp =
207        comparator()
208            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
209    if (cmp > 0) {
210      return aggregateAboveRange(aggr, node.right);
211    } else if (cmp == 0) {
212      switch (range.getUpperBoundType()) {
213        case OPEN:
214          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
215        case CLOSED:
216          return aggr.treeAggregate(node.right);
217      }
218      throw new AssertionError();
219    } else {
220      return aggr.treeAggregate(node.right)
221          + aggr.nodeAggregate(node)
222          + aggregateAboveRange(aggr, node.left);
223    }
224  }
225
226  @Override
227  public int size() {
228    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
229  }
230
231  @Override
232  int distinctElements() {
233    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
234  }
235
236  static int distinctElements(@Nullable AvlNode<?> node) {
237    return (node == null) ? 0 : node.distinctElements;
238  }
239
240  @Override
241  public int count(@Nullable Object element) {
242    try {
243      @SuppressWarnings("unchecked")
244      E e = (E) element;
245      AvlNode<E> root = rootReference.get();
246      if (!range.contains(e) || root == null) {
247        return 0;
248      }
249      return root.count(comparator(), e);
250    } catch (ClassCastException | NullPointerException e) {
251      return 0;
252    }
253  }
254
255  @CanIgnoreReturnValue
256  @Override
257  public int add(@ParametricNullness E element, int occurrences) {
258    checkNonnegative(occurrences, "occurrences");
259    if (occurrences == 0) {
260      return count(element);
261    }
262    checkArgument(range.contains(element));
263    AvlNode<E> root = rootReference.get();
264    if (root == null) {
265      int unused = comparator().compare(element, element);
266      AvlNode<E> newRoot = new AvlNode<>(element, occurrences);
267      successor(header, newRoot, header);
268      rootReference.checkAndSet(root, newRoot);
269      return 0;
270    }
271    int[] result = new int[1]; // used as a mutable int reference to hold result
272    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
273    rootReference.checkAndSet(root, newRoot);
274    return result[0];
275  }
276
277  @CanIgnoreReturnValue
278  @Override
279  public int remove(@Nullable Object element, int occurrences) {
280    checkNonnegative(occurrences, "occurrences");
281    if (occurrences == 0) {
282      return count(element);
283    }
284    AvlNode<E> root = rootReference.get();
285    int[] result = new int[1]; // used as a mutable int reference to hold result
286    AvlNode<E> newRoot;
287    try {
288      @SuppressWarnings("unchecked")
289      E e = (E) element;
290      if (!range.contains(e) || root == null) {
291        return 0;
292      }
293      newRoot = root.remove(comparator(), e, occurrences, result);
294    } catch (ClassCastException | NullPointerException e) {
295      return 0;
296    }
297    rootReference.checkAndSet(root, newRoot);
298    return result[0];
299  }
300
301  @CanIgnoreReturnValue
302  @Override
303  public int setCount(@ParametricNullness E element, int count) {
304    checkNonnegative(count, "count");
305    if (!range.contains(element)) {
306      checkArgument(count == 0);
307      return 0;
308    }
309
310    AvlNode<E> root = rootReference.get();
311    if (root == null) {
312      if (count > 0) {
313        add(element, count);
314      }
315      return 0;
316    }
317    int[] result = new int[1]; // used as a mutable int reference to hold result
318    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
319    rootReference.checkAndSet(root, newRoot);
320    return result[0];
321  }
322
323  @CanIgnoreReturnValue
324  @Override
325  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
326    checkNonnegative(newCount, "newCount");
327    checkNonnegative(oldCount, "oldCount");
328    checkArgument(range.contains(element));
329
330    AvlNode<E> root = rootReference.get();
331    if (root == null) {
332      if (oldCount == 0) {
333        if (newCount > 0) {
334          add(element, newCount);
335        }
336        return true;
337      } else {
338        return false;
339      }
340    }
341    int[] result = new int[1]; // used as a mutable int reference to hold result
342    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
343    rootReference.checkAndSet(root, newRoot);
344    return result[0] == oldCount;
345  }
346
347  @Override
348  public void clear() {
349    if (!range.hasLowerBound() && !range.hasUpperBound()) {
350      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
351      for (AvlNode<E> current = header.succ(); current != header; ) {
352        AvlNode<E> next = current.succ();
353
354        current.elemCount = 0;
355        // Also clear these fields so that one deleted Entry doesn't retain all elements.
356        current.left = null;
357        current.right = null;
358        current.pred = null;
359        current.succ = null;
360
361        current = next;
362      }
363      successor(header, header);
364      rootReference.clear();
365    } else {
366      // TODO(cpovirk): Perhaps we can optimize in this case, too?
367      Iterators.clear(entryIterator());
368    }
369  }
370
371  private Entry<E> wrapEntry(AvlNode<E> baseEntry) {
372    return new Multisets.AbstractEntry<E>() {
373      @Override
374      @ParametricNullness
375      public E getElement() {
376        return baseEntry.getElement();
377      }
378
379      @Override
380      public int getCount() {
381        int result = baseEntry.getCount();
382        if (result == 0) {
383          return count(getElement());
384        } else {
385          return result;
386        }
387      }
388    };
389  }
390
391  /** Returns the first node in the tree that is in range. */
392  private @Nullable AvlNode<E> firstNode() {
393    AvlNode<E> root = rootReference.get();
394    if (root == null) {
395      return null;
396    }
397    AvlNode<E> node;
398    if (range.hasLowerBound()) {
399      // The cast is safe because of the hasLowerBound check.
400      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
401      node = root.ceiling(comparator(), endpoint);
402      if (node == null) {
403        return null;
404      }
405      if (range.getLowerBoundType() == BoundType.OPEN
406          && comparator().compare(endpoint, node.getElement()) == 0) {
407        node = node.succ();
408      }
409    } else {
410      node = header.succ();
411    }
412    return (node == header || !range.contains(node.getElement())) ? null : node;
413  }
414
415  private @Nullable AvlNode<E> lastNode() {
416    AvlNode<E> root = rootReference.get();
417    if (root == null) {
418      return null;
419    }
420    AvlNode<E> node;
421    if (range.hasUpperBound()) {
422      // The cast is safe because of the hasUpperBound check.
423      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
424      node = root.floor(comparator(), endpoint);
425      if (node == null) {
426        return null;
427      }
428      if (range.getUpperBoundType() == BoundType.OPEN
429          && comparator().compare(endpoint, node.getElement()) == 0) {
430        node = node.pred();
431      }
432    } else {
433      node = header.pred();
434    }
435    return (node == header || !range.contains(node.getElement())) ? null : node;
436  }
437
438  @Override
439  Iterator<E> elementIterator() {
440    return Multisets.elementIterator(entryIterator());
441  }
442
443  @Override
444  Iterator<Entry<E>> entryIterator() {
445    return new Iterator<Entry<E>>() {
446      @Nullable AvlNode<E> current = firstNode();
447      @Nullable Entry<E> prevEntry;
448
449      @Override
450      public boolean hasNext() {
451        if (current == null) {
452          return false;
453        } else if (range.tooHigh(current.getElement())) {
454          current = null;
455          return false;
456        } else {
457          return true;
458        }
459      }
460
461      @Override
462      public Entry<E> next() {
463        if (!hasNext()) {
464          throw new NoSuchElementException();
465        }
466        // requireNonNull is safe because current is only nulled out after iteration is complete.
467        Entry<E> result = wrapEntry(requireNonNull(current));
468        prevEntry = result;
469        if (current.succ() == header) {
470          current = null;
471        } else {
472          current = current.succ();
473        }
474        return result;
475      }
476
477      @Override
478      public void remove() {
479        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
480        setCount(prevEntry.getElement(), 0);
481        prevEntry = null;
482      }
483    };
484  }
485
486  @Override
487  Iterator<Entry<E>> descendingEntryIterator() {
488    return new Iterator<Entry<E>>() {
489      @Nullable AvlNode<E> current = lastNode();
490      @Nullable Entry<E> prevEntry = null;
491
492      @Override
493      public boolean hasNext() {
494        if (current == null) {
495          return false;
496        } else if (range.tooLow(current.getElement())) {
497          current = null;
498          return false;
499        } else {
500          return true;
501        }
502      }
503
504      @Override
505      public Entry<E> next() {
506        if (!hasNext()) {
507          throw new NoSuchElementException();
508        }
509        // requireNonNull is safe because current is only nulled out after iteration is complete.
510        requireNonNull(current);
511        Entry<E> result = wrapEntry(current);
512        prevEntry = result;
513        if (current.pred() == header) {
514          current = null;
515        } else {
516          current = current.pred();
517        }
518        return result;
519      }
520
521      @Override
522      public void remove() {
523        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
524        setCount(prevEntry.getElement(), 0);
525        prevEntry = null;
526      }
527    };
528  }
529
530  @Override
531  public Iterator<E> iterator() {
532    return Multisets.iteratorImpl(this);
533  }
534
535  @Override
536  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
537    return new TreeMultiset<>(
538        rootReference,
539        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
540        header);
541  }
542
543  @Override
544  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
545    return new TreeMultiset<>(
546        rootReference,
547        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
548        header);
549  }
550
551  private static final class Reference<T> {
552    private @Nullable T value;
553
554    public @Nullable T get() {
555      return value;
556    }
557
558    public void checkAndSet(@Nullable T expected, @Nullable T newValue) {
559      if (value != expected) {
560        throw new ConcurrentModificationException();
561      }
562      value = newValue;
563    }
564
565    void clear() {
566      value = null;
567    }
568  }
569
570  private static final class AvlNode<E extends @Nullable Object> {
    /*
     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
     * type that can include null, as in a TreeMultiset<@Nullable String>).
     *
     * For the header node, though, this field contains `null`, regardless of the type of the
     * multiset.
     *
     * Most code that operates on an AvlNode never operates on the header node. Such code can access
     * the elem field without a null check by calling getElement().
     */
    private final @Nullable E elem;

    // elemCount is 0 iff this node has been deleted.
    private int elemCount;

    // Number of nodes (i.e. distinct elements) in the subtree rooted here, this node included.
    // Like totalCount and height, it is not initialized by the header-node constructor.
    private int distinctElements;
    // Sum of elemCount over every node in the subtree rooted here.
    private long totalCount;
    // Height of the subtree rooted here; a freshly created element node has height 1.
    private int height;
    // Child subtrees: per the comparator, left holds smaller elements and right holds larger ones.
    private @Nullable AvlNode<E> left;
    private @Nullable AvlNode<E> right;
    /*
     * pred and succ are nullable after construction, but we always call successor() to initialize
     * them immediately thereafter.
     *
     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
     * only under concurrent modification).
     *
     * To access these fields when you know that they are not null, call the pred() and succ()
     * methods, which perform null checks before returning the fields.
     */
    private @Nullable AvlNode<E> pred;
    private @Nullable AvlNode<E> succ;
604
605    AvlNode(@ParametricNullness E elem, int elemCount) {
606      checkArgument(elemCount > 0);
607      this.elem = elem;
608      this.elemCount = elemCount;
609      this.totalCount = elemCount;
610      this.distinctElements = 1;
611      this.height = 1;
612      this.left = null;
613      this.right = null;
614    }
615
616    /** Constructor for the header node. */
617    AvlNode() {
618      this.elem = null;
619      this.elemCount = 1;
620    }
621
622    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
623
624    private AvlNode<E> pred() {
625      return requireNonNull(pred);
626    }
627
628    private AvlNode<E> succ() {
629      return requireNonNull(succ);
630    }
631
632    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
633      int cmp = comparator.compare(e, getElement());
634      if (cmp < 0) {
635        return (left == null) ? 0 : left.count(comparator, e);
636      } else if (cmp > 0) {
637        return (right == null) ? 0 : right.count(comparator, e);
638      } else {
639        return elemCount;
640      }
641    }
642
643    @CanIgnoreReturnValue
644    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
645      right = new AvlNode<>(e, count);
646      successor(this, right, succ());
647      height = max(2, height);
648      distinctElements++;
649      totalCount += count;
650      return this;
651    }
652
653    @CanIgnoreReturnValue
654    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
655      left = new AvlNode<>(e, count);
656      successor(pred(), left, this);
657      height = max(2, height);
658      distinctElements++;
659      totalCount += count;
660      return this;
661    }
662
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node. If {@code e}
     * is not yet present, a new leaf node is created for it.
     *
     * @param result single-slot out-parameter; receives the count of {@code e} before the addition
     *     (0 if it was absent)
     * @return the new root of this subtree, which may differ from {@code this} after rebalancing
     * @throws IllegalArgumentException if the resulting count of {@code e} would exceed {@code
     *     Integer.MAX_VALUE}
     */
    AvlNode<E> add(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // A previous count of 0 means a new node was created somewhere below.
        if (result[0] == 0) {
          distinctElements++;
        }
        this.totalCount += count;
        // Only a change in the child's height can unbalance this node.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        // A previous count of 0 means a new node was created somewhere below.
        if (result[0] == 0) {
          distinctElements++;
        }
        this.totalCount += count;
        // Only a change in the child's height can unbalance this node.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      long resultCount = (long) elemCount + count;
      // Check for int overflow before mutating anything, so a failure leaves the tree unchanged.
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
708
709    @Nullable AvlNode<E> remove(
710        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
711      int cmp = comparator.compare(e, getElement());
712      if (cmp < 0) {
713        AvlNode<E> initLeft = left;
714        if (initLeft == null) {
715          result[0] = 0;
716          return this;
717        }
718
719        left = initLeft.remove(comparator, e, count, result);
720
721        if (result[0] > 0) {
722          if (count >= result[0]) {
723            this.distinctElements--;
724            this.totalCount -= result[0];
725          } else {
726            this.totalCount -= count;
727          }
728        }
729        return (result[0] == 0) ? this : rebalance();
730      } else if (cmp > 0) {
731        AvlNode<E> initRight = right;
732        if (initRight == null) {
733          result[0] = 0;
734          return this;
735        }
736
737        right = initRight.remove(comparator, e, count, result);
738
739        if (result[0] > 0) {
740          if (count >= result[0]) {
741            this.distinctElements--;
742            this.totalCount -= result[0];
743          } else {
744            this.totalCount -= count;
745          }
746        }
747        return rebalance();
748      }
749
750      // removing count from me!
751      result[0] = elemCount;
752      if (count >= elemCount) {
753        return deleteMe();
754      } else {
755        this.elemCount -= count;
756        this.totalCount -= count;
757        return this;
758      }
759    }
760
    /**
     * Sets the count of {@code e} in the subtree rooted at this node to {@code count}. A node is
     * added if {@code e} is absent and {@code count > 0}; the node is deleted if {@code count == 0}.
     *
     * @param result single-slot out-parameter; receives the previous count of {@code e} (0 if it
     *     was absent)
     * @return the new root of this subtree, or null if the subtree became empty
     */
    @Nullable AvlNode<E> setCount(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // Adjust the distinct count when a node was deleted (count 0) or created (was absent).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // Adjust the distinct count when a node was deleted (count 0) or created (was absent).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
809
    /**
     * Conditionally sets the count of {@code e} in the subtree rooted at this node. The change to
     * {@code newCount} is made only if the current count of {@code e} equals {@code expectedCount}.
     *
     * @param result single-slot out-parameter; receives the count of {@code e} seen before any
     *     change (0 if it was absent). The update was applied iff this equals {@code expectedCount}.
     * @return the new root of this subtree, or null if the subtree became empty
     */
    @Nullable AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the expectation matched and the update was applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the expectation matched and the update was applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
872
    /**
     * Removes this node from the tree and returns the root of the subtree that replaces it (null if
     * this node had no children). The node is marked deleted (elemCount = 0) and unlinked from the
     * pred/succ list. When both children exist, this node's in-order neighbor from the taller
     * subtree (the left one on ties) is moved into its place.
     */
    private @Nullable AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Link the neighbors to each other. This node's own pred/succ fields are read again below to
      // find the replacement node.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        // newTop now roots everything this node did, minus this node itself.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // newTop is the minimum node in my right subtree
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
898
899    // Removes the minimum node from this subtree to be reused elsewhere
900    private @Nullable AvlNode<E> removeMin(AvlNode<E> node) {
901      if (left == null) {
902        return right;
903      } else {
904        left = left.removeMin(node);
905        distinctElements--;
906        totalCount -= node.elemCount;
907        return rebalance();
908      }
909    }
910
911    // Removes the maximum node from this subtree to be reused elsewhere
912    private @Nullable AvlNode<E> removeMax(AvlNode<E> node) {
913      if (right == null) {
914        return left;
915      } else {
916        right = right.removeMax(node);
917        distinctElements--;
918        totalCount -= node.elemCount;
919        return rebalance();
920      }
921    }
922
923    private void recomputeMultiset() {
924      this.distinctElements =
925          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
926      this.totalCount = elemCount + totalCount(left) + totalCount(right);
927    }
928
929    private void recomputeHeight() {
930      this.height = 1 + max(height(left), height(right));
931    }
932
933    private void recompute() {
934      recomputeMultiset();
935      recomputeHeight();
936    }
937
938    private AvlNode<E> rebalance() {
939      switch (balanceFactor()) {
940        case -2:
941          // requireNonNull is safe because right must exist in order to get a negative factor.
942          requireNonNull(right);
943          if (right.balanceFactor() > 0) {
944            right = right.rotateRight();
945          }
946          return rotateLeft();
947        case 2:
948          // requireNonNull is safe because left must exist in order to get a positive factor.
949          requireNonNull(left);
950          if (left.balanceFactor() < 0) {
951            left = left.rotateLeft();
952          }
953          return rotateRight();
954        default:
955          recomputeHeight();
956          return this;
957      }
958    }
959
960    private int balanceFactor() {
961      return height(left) - height(right);
962    }
963
964    private AvlNode<E> rotateLeft() {
965      checkState(right != null);
966      AvlNode<E> newTop = right;
967      this.right = newTop.left;
968      newTop.left = this;
969      newTop.totalCount = this.totalCount;
970      newTop.distinctElements = this.distinctElements;
971      this.recompute();
972      newTop.recomputeHeight();
973      return newTop;
974    }
975
976    private AvlNode<E> rotateRight() {
977      checkState(left != null);
978      AvlNode<E> newTop = left;
979      this.left = newTop.right;
980      newTop.right = this;
981      newTop.totalCount = this.totalCount;
982      newTop.distinctElements = this.distinctElements;
983      this.recompute();
984      newTop.recomputeHeight();
985      return newTop;
986    }
987
988    private static long totalCount(@Nullable AvlNode<?> node) {
989      return (node == null) ? 0 : node.totalCount;
990    }
991
992    private static int height(@Nullable AvlNode<?> node) {
993      return (node == null) ? 0 : node.height;
994    }
995
996    private @Nullable AvlNode<E> ceiling(
997        Comparator<? super E> comparator, @ParametricNullness E e) {
998      int cmp = comparator.compare(e, getElement());
999      if (cmp < 0) {
1000        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1001      } else if (cmp == 0) {
1002        return this;
1003      } else {
1004        return (right == null) ? null : right.ceiling(comparator, e);
1005      }
1006    }
1007
1008    private @Nullable AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1009      int cmp = comparator.compare(e, getElement());
1010      if (cmp > 0) {
1011        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1012      } else if (cmp == 0) {
1013        return this;
1014      } else {
1015        return (left == null) ? null : left.floor(comparator, e);
1016      }
1017    }
1018
1019    @ParametricNullness
1020    E getElement() {
1021      // For discussion of this cast, see the comment on the elem field.
1022      return uncheckedCastNullableTToT(elem);
1023    }
1024
1025    int getCount() {
1026      return elemCount;
1027    }
1028
1029    @Override
1030    public String toString() {
1031      return Multisets.immutableEntry(getElement(), getCount()).toString();
1032    }
1033  }
1034
1035  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1036    a.succ = b;
1037    b.pred = a;
1038  }
1039
1040  private static <T extends @Nullable Object> void successor(
1041      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1042    successor(a, b);
1043    successor(b, c);
1044  }
1045
1046  /*
1047   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1048   * calls the comparator to compare the two keys. If that change is made,
1049   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1050   */
1051
1052  /**
1053   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1054   *     second element, its count, and so on
1055   */
1056  @J2ktIncompatible
1057  @GwtIncompatible // java.io.ObjectOutputStream
1058  private void writeObject(ObjectOutputStream stream) throws IOException {
1059    stream.defaultWriteObject();
1060    stream.writeObject(elementSet().comparator());
1061    Serialization.writeMultiset(this, stream);
1062  }
1063
  // Rebuilds a TreeMultiset from the form written by writeObject. The comparator, range,
  // rootReference and header fields are assigned reflectively through
  // Serialization.getFieldSetter; presumably these fields are final and cannot be assigned
  // directly here (NOTE(review): confirm against the field declarations).
  @J2ktIncompatible
  @GwtIncompatible // java.io.ObjectInputStream
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    @SuppressWarnings("unchecked")
    // reading data stored by writeObject
    // (writeObject writes the comparator immediately after defaultWriteObject)
    Comparator<? super E> comparator = (Comparator<? super E>) requireNonNull(stream.readObject());
    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    // The deserialized multiset covers the comparator's entire range.
    Serialization.getFieldSetter(TreeMultiset.class, "range")
        .set(this, GeneralRange.all(comparator));
    // Start from an empty root reference and a header node whose pred/succ point to itself.
    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
        .set(this, new Reference<AvlNode<E>>());
    AvlNode<E> header = new AvlNode<>();
    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    successor(header, header);
    // Presumably reads back and re-adds the entries that Serialization.writeMultiset wrote.
    Serialization.populateMultiset(this, stream);
  }
1081
1082  @GwtIncompatible @J2ktIncompatible private static final long serialVersionUID = 1;
1083}