001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Preconditions.checkState;
022import static com.google.common.collect.CollectPreconditions.checkNonnegative;
023import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
024import static java.lang.Math.max;
025import static java.util.Objects.requireNonNull;
026
027import com.google.common.annotations.GwtCompatible;
028import com.google.common.annotations.GwtIncompatible;
029import com.google.common.annotations.J2ktIncompatible;
030import com.google.common.base.MoreObjects;
031import com.google.common.primitives.Ints;
032import com.google.errorprone.annotations.CanIgnoreReturnValue;
033import java.io.IOException;
034import java.io.ObjectInputStream;
035import java.io.ObjectOutputStream;
036import java.io.Serializable;
037import java.util.Comparator;
038import java.util.ConcurrentModificationException;
039import java.util.Iterator;
040import java.util.NoSuchElementException;
041import java.util.function.ObjIntConsumer;
042import org.jspecify.annotations.Nullable;
043
044/**
045 * A multiset which maintains the ordering of its elements, according to either their natural order
046 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
047 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
048 * equivalence of instances.
049 *
050 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
051 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
052 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
053 *
054 * <p>See the Guava User Guide article on <a href=
055 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
056 *
057 * @author Louis Wasserman
058 * @author Jared Levy
059 * @since 2.0
060 */
061@GwtCompatible(emulated = true)
062public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
063    implements Serializable {
064
065  /**
066   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
067   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
068   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
069   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
070   * user attempts to add an element to the multiset that violates this constraint (for example, the
071   * user attempts to add a string element to a set whose elements are integers), the {@code
072   * add(Object)} call will throw a {@code ClassCastException}.
073   *
074   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
075   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
076   */
077  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
078  public static <E extends Comparable> TreeMultiset<E> create() {
079    return new TreeMultiset<>(Ordering.natural());
080  }
081
082  /**
083   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
084   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
085   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
086   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
087   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
088   * ClassCastException}.
089   *
090   * @param comparator the comparator that will be used to sort this multiset. A null value
091   *     indicates that the elements' <i>natural ordering</i> should be used.
092   */
093  @SuppressWarnings("unchecked")
094  public static <E extends @Nullable Object> TreeMultiset<E> create(
095      @Nullable Comparator<? super E> comparator) {
096    return (comparator == null)
097        ? new TreeMultiset<E>((Comparator) Ordering.natural())
098        : new TreeMultiset<E>(comparator);
099  }
100
101  /**
102   * Creates an empty multiset containing the given initial elements, sorted according to the
103   * elements' natural order.
104   *
105   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
106   *
107   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
108   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
109   */
110  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
111  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
112    TreeMultiset<E> multiset = create();
113    Iterables.addAll(multiset, elements);
114    return multiset;
115  }
116
117  private final transient Reference<AvlNode<E>> rootReference;
118  private final transient GeneralRange<E> range;
119  private final transient AvlNode<E> header;
120
121  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
122    super(range.comparator());
123    this.rootReference = rootReference;
124    this.range = range;
125    this.header = endLink;
126  }
127
128  TreeMultiset(Comparator<? super E> comparator) {
129    super(comparator);
130    this.range = GeneralRange.all(comparator);
131    this.header = new AvlNode<>();
132    successor(header, header);
133    this.rootReference = new Reference<>();
134  }
135
136  /** A function which can be summed across a subtree. */
137  private enum Aggregate {
138    SIZE {
139      @Override
140      int nodeAggregate(AvlNode<?> node) {
141        return node.elemCount;
142      }
143
144      @Override
145      long treeAggregate(@Nullable AvlNode<?> root) {
146        return (root == null) ? 0 : root.totalCount;
147      }
148    },
149    DISTINCT {
150      @Override
151      int nodeAggregate(AvlNode<?> node) {
152        return 1;
153      }
154
155      @Override
156      long treeAggregate(@Nullable AvlNode<?> root) {
157        return (root == null) ? 0 : root.distinctElements;
158      }
159    };
160
161    abstract int nodeAggregate(AvlNode<?> node);
162
163    abstract long treeAggregate(@Nullable AvlNode<?> root);
164  }
165
166  private long aggregateForEntries(Aggregate aggr) {
167    AvlNode<E> root = rootReference.get();
168    long total = aggr.treeAggregate(root);
169    if (range.hasLowerBound()) {
170      total -= aggregateBelowRange(aggr, root);
171    }
172    if (range.hasUpperBound()) {
173      total -= aggregateAboveRange(aggr, root);
174    }
175    return total;
176  }
177
178  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
179    if (node == null) {
180      return 0;
181    }
182    // The cast is safe because we call this method only if hasLowerBound().
183    int cmp =
184        comparator()
185            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
186    if (cmp < 0) {
187      return aggregateBelowRange(aggr, node.left);
188    } else if (cmp == 0) {
189      switch (range.getLowerBoundType()) {
190        case OPEN:
191          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
192        case CLOSED:
193          return aggr.treeAggregate(node.left);
194      }
195      throw new AssertionError();
196    } else {
197      return aggr.treeAggregate(node.left)
198          + aggr.nodeAggregate(node)
199          + aggregateBelowRange(aggr, node.right);
200    }
201  }
202
203  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
204    if (node == null) {
205      return 0;
206    }
207    // The cast is safe because we call this method only if hasUpperBound().
208    int cmp =
209        comparator()
210            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
211    if (cmp > 0) {
212      return aggregateAboveRange(aggr, node.right);
213    } else if (cmp == 0) {
214      switch (range.getUpperBoundType()) {
215        case OPEN:
216          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
217        case CLOSED:
218          return aggr.treeAggregate(node.right);
219      }
220      throw new AssertionError();
221    } else {
222      return aggr.treeAggregate(node.right)
223          + aggr.nodeAggregate(node)
224          + aggregateAboveRange(aggr, node.left);
225    }
226  }
227
228  @Override
229  public int size() {
230    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
231  }
232
233  @Override
234  int distinctElements() {
235    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
236  }
237
238  static int distinctElements(@Nullable AvlNode<?> node) {
239    return (node == null) ? 0 : node.distinctElements;
240  }
241
242  @Override
243  public int count(@Nullable Object element) {
244    try {
245      @SuppressWarnings("unchecked")
246      E e = (E) element;
247      AvlNode<E> root = rootReference.get();
248      if (!range.contains(e) || root == null) {
249        return 0;
250      }
251      return root.count(comparator(), e);
252    } catch (ClassCastException | NullPointerException e) {
253      return 0;
254    }
255  }
256
257  @CanIgnoreReturnValue
258  @Override
259  public int add(@ParametricNullness E element, int occurrences) {
260    checkNonnegative(occurrences, "occurrences");
261    if (occurrences == 0) {
262      return count(element);
263    }
264    checkArgument(range.contains(element));
265    AvlNode<E> root = rootReference.get();
266    if (root == null) {
267      int unused = comparator().compare(element, element);
268      AvlNode<E> newRoot = new AvlNode<>(element, occurrences);
269      successor(header, newRoot, header);
270      rootReference.checkAndSet(root, newRoot);
271      return 0;
272    }
273    int[] result = new int[1]; // used as a mutable int reference to hold result
274    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
275    rootReference.checkAndSet(root, newRoot);
276    return result[0];
277  }
278
279  @CanIgnoreReturnValue
280  @Override
281  public int remove(@Nullable Object element, int occurrences) {
282    checkNonnegative(occurrences, "occurrences");
283    if (occurrences == 0) {
284      return count(element);
285    }
286    AvlNode<E> root = rootReference.get();
287    int[] result = new int[1]; // used as a mutable int reference to hold result
288    AvlNode<E> newRoot;
289    try {
290      @SuppressWarnings("unchecked")
291      E e = (E) element;
292      if (!range.contains(e) || root == null) {
293        return 0;
294      }
295      newRoot = root.remove(comparator(), e, occurrences, result);
296    } catch (ClassCastException | NullPointerException e) {
297      return 0;
298    }
299    rootReference.checkAndSet(root, newRoot);
300    return result[0];
301  }
302
303  @CanIgnoreReturnValue
304  @Override
305  public int setCount(@ParametricNullness E element, int count) {
306    checkNonnegative(count, "count");
307    if (!range.contains(element)) {
308      checkArgument(count == 0);
309      return 0;
310    }
311
312    AvlNode<E> root = rootReference.get();
313    if (root == null) {
314      if (count > 0) {
315        add(element, count);
316      }
317      return 0;
318    }
319    int[] result = new int[1]; // used as a mutable int reference to hold result
320    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
321    rootReference.checkAndSet(root, newRoot);
322    return result[0];
323  }
324
325  @CanIgnoreReturnValue
326  @Override
327  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
328    checkNonnegative(newCount, "newCount");
329    checkNonnegative(oldCount, "oldCount");
330    checkArgument(range.contains(element));
331
332    AvlNode<E> root = rootReference.get();
333    if (root == null) {
334      if (oldCount == 0) {
335        if (newCount > 0) {
336          add(element, newCount);
337        }
338        return true;
339      } else {
340        return false;
341      }
342    }
343    int[] result = new int[1]; // used as a mutable int reference to hold result
344    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
345    rootReference.checkAndSet(root, newRoot);
346    return result[0] == oldCount;
347  }
348
349  @Override
350  public void clear() {
351    if (!range.hasLowerBound() && !range.hasUpperBound()) {
352      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
353      for (AvlNode<E> current = header.succ(); current != header; ) {
354        AvlNode<E> next = current.succ();
355
356        current.elemCount = 0;
357        // Also clear these fields so that one deleted Entry doesn't retain all elements.
358        current.left = null;
359        current.right = null;
360        current.pred = null;
361        current.succ = null;
362
363        current = next;
364      }
365      successor(header, header);
366      rootReference.clear();
367    } else {
368      // TODO(cpovirk): Perhaps we can optimize in this case, too?
369      Iterators.clear(entryIterator());
370    }
371  }
372
373  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
374    return new Multisets.AbstractEntry<E>() {
375      @Override
376      @ParametricNullness
377      public E getElement() {
378        return baseEntry.getElement();
379      }
380
381      @Override
382      public int getCount() {
383        int result = baseEntry.getCount();
384        if (result == 0) {
385          return count(getElement());
386        } else {
387          return result;
388        }
389      }
390    };
391  }
392
393  /** Returns the first node in the tree that is in range. */
394  private @Nullable AvlNode<E> firstNode() {
395    AvlNode<E> root = rootReference.get();
396    if (root == null) {
397      return null;
398    }
399    AvlNode<E> node;
400    if (range.hasLowerBound()) {
401      // The cast is safe because of the hasLowerBound check.
402      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
403      node = root.ceiling(comparator(), endpoint);
404      if (node == null) {
405        return null;
406      }
407      if (range.getLowerBoundType() == BoundType.OPEN
408          && comparator().compare(endpoint, node.getElement()) == 0) {
409        node = node.succ();
410      }
411    } else {
412      node = header.succ();
413    }
414    return (node == header || !range.contains(node.getElement())) ? null : node;
415  }
416
417  private @Nullable AvlNode<E> lastNode() {
418    AvlNode<E> root = rootReference.get();
419    if (root == null) {
420      return null;
421    }
422    AvlNode<E> node;
423    if (range.hasUpperBound()) {
424      // The cast is safe because of the hasUpperBound check.
425      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
426      node = root.floor(comparator(), endpoint);
427      if (node == null) {
428        return null;
429      }
430      if (range.getUpperBoundType() == BoundType.OPEN
431          && comparator().compare(endpoint, node.getElement()) == 0) {
432        node = node.pred();
433      }
434    } else {
435      node = header.pred();
436    }
437    return (node == header || !range.contains(node.getElement())) ? null : node;
438  }
439
440  @Override
441  Iterator<E> elementIterator() {
442    return Multisets.elementIterator(entryIterator());
443  }
444
445  @Override
446  Iterator<Entry<E>> entryIterator() {
447    return new Iterator<Entry<E>>() {
448      @Nullable AvlNode<E> current = firstNode();
449      @Nullable Entry<E> prevEntry;
450
451      @Override
452      public boolean hasNext() {
453        if (current == null) {
454          return false;
455        } else if (range.tooHigh(current.getElement())) {
456          current = null;
457          return false;
458        } else {
459          return true;
460        }
461      }
462
463      @Override
464      public Entry<E> next() {
465        if (!hasNext()) {
466          throw new NoSuchElementException();
467        }
468        // requireNonNull is safe because current is only nulled out after iteration is complete.
469        Entry<E> result = wrapEntry(requireNonNull(current));
470        prevEntry = result;
471        if (current.succ() == header) {
472          current = null;
473        } else {
474          current = current.succ();
475        }
476        return result;
477      }
478
479      @Override
480      public void remove() {
481        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
482        setCount(prevEntry.getElement(), 0);
483        prevEntry = null;
484      }
485    };
486  }
487
488  @Override
489  Iterator<Entry<E>> descendingEntryIterator() {
490    return new Iterator<Entry<E>>() {
491      @Nullable AvlNode<E> current = lastNode();
492      @Nullable Entry<E> prevEntry = null;
493
494      @Override
495      public boolean hasNext() {
496        if (current == null) {
497          return false;
498        } else if (range.tooLow(current.getElement())) {
499          current = null;
500          return false;
501        } else {
502          return true;
503        }
504      }
505
506      @Override
507      public Entry<E> next() {
508        if (!hasNext()) {
509          throw new NoSuchElementException();
510        }
511        // requireNonNull is safe because current is only nulled out after iteration is complete.
512        requireNonNull(current);
513        Entry<E> result = wrapEntry(current);
514        prevEntry = result;
515        if (current.pred() == header) {
516          current = null;
517        } else {
518          current = current.pred();
519        }
520        return result;
521      }
522
523      @Override
524      public void remove() {
525        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
526        setCount(prevEntry.getElement(), 0);
527        prevEntry = null;
528      }
529    };
530  }
531
532  @Override
533  public void forEachEntry(ObjIntConsumer<? super E> action) {
534    checkNotNull(action);
535    for (AvlNode<E> node = firstNode();
536        node != header && node != null && !range.tooHigh(node.getElement());
537        node = node.succ()) {
538      action.accept(node.getElement(), node.getCount());
539    }
540  }
541
542  @Override
543  public Iterator<E> iterator() {
544    return Multisets.iteratorImpl(this);
545  }
546
547  @Override
548  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
549    return new TreeMultiset<>(
550        rootReference,
551        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
552        header);
553  }
554
555  @Override
556  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
557    return new TreeMultiset<>(
558        rootReference,
559        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
560        header);
561  }
562
563  private static final class Reference<T> {
564    private @Nullable T value;
565
566    public @Nullable T get() {
567      return value;
568    }
569
570    public void checkAndSet(@Nullable T expected, @Nullable T newValue) {
571      if (value != expected) {
572        throw new ConcurrentModificationException();
573      }
574      value = newValue;
575    }
576
577    void clear() {
578      value = null;
579    }
580  }
581
582  private static final class AvlNode<E extends @Nullable Object> {
583    /*
584     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
585     * type that can include null, as in a TreeMultiset<@Nullable String>).
586     *
587     * For the header node, though, this field contains `null`, regardless of the type of the
588     * multiset.
589     *
590     * Most code that operates on an AvlNode never operates on the header node. Such code can access
591     * the elem field without a null check by calling getElement().
592     */
593    private final @Nullable E elem;
594
595    // elemCount is 0 iff this node has been deleted.
596    private int elemCount;
597
598    private int distinctElements;
599    private long totalCount;
600    private int height;
601    private @Nullable AvlNode<E> left;
602    private @Nullable AvlNode<E> right;
603    /*
604     * pred and succ are nullable after construction, but we always call successor() to initialize
605     * them immediately thereafter.
606     *
607     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
608     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
609     * only under concurrent modification).
610     *
611     * To access these fields when you know that they are not null, call the pred() and succ()
612     * methods, which perform null checks before returning the fields.
613     */
614    private @Nullable AvlNode<E> pred;
615    private @Nullable AvlNode<E> succ;
616
617    AvlNode(@ParametricNullness E elem, int elemCount) {
618      checkArgument(elemCount > 0);
619      this.elem = elem;
620      this.elemCount = elemCount;
621      this.totalCount = elemCount;
622      this.distinctElements = 1;
623      this.height = 1;
624      this.left = null;
625      this.right = null;
626    }
627
628    /** Constructor for the header node. */
629    AvlNode() {
630      this.elem = null;
631      this.elemCount = 1;
632    }
633
634    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
635
636    private AvlNode<E> pred() {
637      return requireNonNull(pred);
638    }
639
640    private AvlNode<E> succ() {
641      return requireNonNull(succ);
642    }
643
644    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
645      int cmp = comparator.compare(e, getElement());
646      if (cmp < 0) {
647        return (left == null) ? 0 : left.count(comparator, e);
648      } else if (cmp > 0) {
649        return (right == null) ? 0 : right.count(comparator, e);
650      } else {
651        return elemCount;
652      }
653    }
654
655    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
656      right = new AvlNode<>(e, count);
657      successor(this, right, succ());
658      height = max(2, height);
659      distinctElements++;
660      totalCount += count;
661      return this;
662    }
663
664    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
665      left = new AvlNode<>(e, count);
666      successor(pred(), left, this);
667      height = max(2, height);
668      distinctElements++;
669      totalCount += count;
670      return this;
671    }
672
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node.
     *
     * @param result single-element out-parameter; {@code result[0]} receives the number of
     *     occurrences of {@code e} present before the call (0 if a new node was created)
     * @return the root of this subtree after the insertion (rebalanced if a child's height changed)
     * @throws IllegalArgumentException if the element's count would exceed Integer.MAX_VALUE
     */
    AvlNode<E> add(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // A previous count of 0 means the recursive call created a new node.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Updated only after the recursive call returned normally (see the note above).
        this.totalCount += count;
        // Rebalancing is needed only if the child subtree grew.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        // A previous count of 0 means the recursive call created a new node.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Updated only after the recursive call returned normally (see the note above).
        this.totalCount += count;
        // Rebalancing is needed only if the child subtree grew.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      // Overflow check happens before any field is modified, so a failure leaves the node intact.
      long resultCount = (long) elemCount + count;
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
718
719    @Nullable AvlNode<E> remove(
720        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
721      int cmp = comparator.compare(e, getElement());
722      if (cmp < 0) {
723        AvlNode<E> initLeft = left;
724        if (initLeft == null) {
725          result[0] = 0;
726          return this;
727        }
728
729        left = initLeft.remove(comparator, e, count, result);
730
731        if (result[0] > 0) {
732          if (count >= result[0]) {
733            this.distinctElements--;
734            this.totalCount -= result[0];
735          } else {
736            this.totalCount -= count;
737          }
738        }
739        return (result[0] == 0) ? this : rebalance();
740      } else if (cmp > 0) {
741        AvlNode<E> initRight = right;
742        if (initRight == null) {
743          result[0] = 0;
744          return this;
745        }
746
747        right = initRight.remove(comparator, e, count, result);
748
749        if (result[0] > 0) {
750          if (count >= result[0]) {
751            this.distinctElements--;
752            this.totalCount -= result[0];
753          } else {
754            this.totalCount -= count;
755          }
756        }
757        return rebalance();
758      }
759
760      // removing count from me!
761      result[0] = elemCount;
762      if (count >= elemCount) {
763        return deleteMe();
764      } else {
765        this.elemCount -= count;
766        this.totalCount -= count;
767        return this;
768      }
769    }
770
    /**
     * Sets the number of occurrences of {@code e} in the subtree rooted at this node to {@code
     * count}, inserting a new leaf or deleting the element's node as needed.
     *
     * @param result single-element out-parameter; {@code result[0]} receives the number of
     *     occurrences of {@code e} present before the call
     * @return the root of this subtree after the update, or {@code null} if it became empty
     */
    @Nullable AvlNode<E> setCount(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // Adjust the distinct count if the element disappeared (present -> 0) or appeared
        // (0 -> present).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // Adjust the distinct count if the element disappeared (present -> 0) or appeared
        // (0 -> present).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
819
    /**
     * Conditionally sets the number of occurrences of {@code e} in the subtree rooted at this node:
     * the update to {@code newCount} happens only if the current count equals {@code
     * expectedCount}.
     *
     * @param result single-element out-parameter; {@code result[0]} receives the count observed
     *     before the call, which callers compare with {@code expectedCount} to detect success
     * @return the root of this subtree after the call, or {@code null} if it became empty
     */
    @Nullable AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Only update statistics if the child actually applied the change.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Only update statistics if the child actually applied the change.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
882
    /**
     * Deletes this node: marks it deleted (elemCount = 0), unlinks it from the pred/succ list, and
     * returns the subtree that should take its place. That is either its only child (possibly
     * null) or, when both children exist, a rebalanced tree rooted at this node's in-order
     * predecessor or successor.
     */
    private @Nullable AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Link pred and succ to each other; this node's own pred/succ fields are left as they were,
      // so pred() and succ() below still return its former neighbors.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        // The replacement comes from the taller subtree (the left one on ties).
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // newTop is the minimum node in my right subtree
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
908
909    // Removes the minimum node from this subtree to be reused elsewhere
910    private @Nullable AvlNode<E> removeMin(AvlNode<E> node) {
911      if (left == null) {
912        return right;
913      } else {
914        left = left.removeMin(node);
915        distinctElements--;
916        totalCount -= node.elemCount;
917        return rebalance();
918      }
919    }
920
921    // Removes the maximum node from this subtree to be reused elsewhere
922    private @Nullable AvlNode<E> removeMax(AvlNode<E> node) {
923      if (right == null) {
924        return left;
925      } else {
926        right = right.removeMax(node);
927        distinctElements--;
928        totalCount -= node.elemCount;
929        return rebalance();
930      }
931    }
932
933    private void recomputeMultiset() {
934      this.distinctElements =
935          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
936      this.totalCount = elemCount + totalCount(left) + totalCount(right);
937    }
938
939    private void recomputeHeight() {
940      this.height = 1 + max(height(left), height(right));
941    }
942
943    private void recompute() {
944      recomputeMultiset();
945      recomputeHeight();
946    }
947
948    private AvlNode<E> rebalance() {
949      switch (balanceFactor()) {
950        case -2:
951          // requireNonNull is safe because right must exist in order to get a negative factor.
952          requireNonNull(right);
953          if (right.balanceFactor() > 0) {
954            right = right.rotateRight();
955          }
956          return rotateLeft();
957        case 2:
958          // requireNonNull is safe because left must exist in order to get a positive factor.
959          requireNonNull(left);
960          if (left.balanceFactor() < 0) {
961            left = left.rotateLeft();
962          }
963          return rotateRight();
964        default:
965          recomputeHeight();
966          return this;
967      }
968    }
969
970    private int balanceFactor() {
971      return height(left) - height(right);
972    }
973
974    private AvlNode<E> rotateLeft() {
975      checkState(right != null);
976      AvlNode<E> newTop = right;
977      this.right = newTop.left;
978      newTop.left = this;
979      newTop.totalCount = this.totalCount;
980      newTop.distinctElements = this.distinctElements;
981      this.recompute();
982      newTop.recomputeHeight();
983      return newTop;
984    }
985
986    private AvlNode<E> rotateRight() {
987      checkState(left != null);
988      AvlNode<E> newTop = left;
989      this.left = newTop.right;
990      newTop.right = this;
991      newTop.totalCount = this.totalCount;
992      newTop.distinctElements = this.distinctElements;
993      this.recompute();
994      newTop.recomputeHeight();
995      return newTop;
996    }
997
998    private static long totalCount(@Nullable AvlNode<?> node) {
999      return (node == null) ? 0 : node.totalCount;
1000    }
1001
1002    private static int height(@Nullable AvlNode<?> node) {
1003      return (node == null) ? 0 : node.height;
1004    }
1005
1006    private @Nullable AvlNode<E> ceiling(
1007        Comparator<? super E> comparator, @ParametricNullness E e) {
1008      int cmp = comparator.compare(e, getElement());
1009      if (cmp < 0) {
1010        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1011      } else if (cmp == 0) {
1012        return this;
1013      } else {
1014        return (right == null) ? null : right.ceiling(comparator, e);
1015      }
1016    }
1017
1018    private @Nullable AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1019      int cmp = comparator.compare(e, getElement());
1020      if (cmp > 0) {
1021        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1022      } else if (cmp == 0) {
1023        return this;
1024      } else {
1025        return (left == null) ? null : left.floor(comparator, e);
1026      }
1027    }
1028
1029    @ParametricNullness
1030    E getElement() {
1031      // For discussion of this cast, see the comment on the elem field.
1032      return uncheckedCastNullableTToT(elem);
1033    }
1034
1035    int getCount() {
1036      return elemCount;
1037    }
1038
1039    @Override
1040    public String toString() {
1041      return Multisets.immutableEntry(getElement(), getCount()).toString();
1042    }
1043  }
1044
1045  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1046    a.succ = b;
1047    b.pred = a;
1048  }
1049
1050  private static <T extends @Nullable Object> void successor(
1051      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1052    successor(a, b);
1053    successor(b, c);
1054  }
1055
1056  /*
1057   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1058   * calls the comparator to compare the two keys. If that change is made,
1059   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1060   */
1061
1062  /**
1063   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1064   *     second element, its count, and so on
1065   */
1066  @J2ktIncompatible
1067  @GwtIncompatible // java.io.ObjectOutputStream
1068  private void writeObject(ObjectOutputStream stream) throws IOException {
1069    stream.defaultWriteObject();
1070    stream.writeObject(elementSet().comparator());
1071    Serialization.writeMultiset(this, stream);
1072  }
1073
  /**
   * Restores state written by {@code writeObject}: reads the comparator, rebuilds an empty tree
   * (full range, empty root reference, self-linked header), then reads back the element/count
   * pairs.
   */
  @J2ktIncompatible
  @GwtIncompatible // java.io.ObjectInputStream
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    @SuppressWarnings("unchecked")
    // reading data stored by writeObject
    Comparator<? super E> comparator = (Comparator<? super E>) requireNonNull(stream.readObject());
    // The structural fields are assigned reflectively rather than directly, presumably because
    // they are final -- NOTE(review): confirm against the field declarations.
    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    // A deserialized multiset always covers the full range of the comparator.
    Serialization.getFieldSetter(TreeMultiset.class, "range")
        .set(this, GeneralRange.all(comparator));
    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
        .set(this, new Reference<AvlNode<E>>());
    // The header is linked to itself, so the in-order chain starts out empty.
    AvlNode<E> header = new AvlNode<>();
    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    successor(header, header);
    // Reads the remaining element/count pairs (see @serialData on writeObject) back into this
    // multiset.
    Serialization.populateMultiset(this, stream);
  }
1091
  // Serialization version identifier for this class; the stream contents themselves are defined by
  // writeObject/readObject.
  @GwtIncompatible // not needed in emulated source
  @J2ktIncompatible
  private static final long serialVersionUID = 1;
1095}