001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Preconditions.checkState;
022import static com.google.common.collect.CollectPreconditions.checkNonnegative;
023import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
024import static java.util.Objects.requireNonNull;
025
026import com.google.common.annotations.GwtCompatible;
027import com.google.common.annotations.GwtIncompatible;
028import com.google.common.annotations.J2ktIncompatible;
029import com.google.common.base.MoreObjects;
030import com.google.common.primitives.Ints;
031import com.google.errorprone.annotations.CanIgnoreReturnValue;
032import java.io.IOException;
033import java.io.ObjectInputStream;
034import java.io.ObjectOutputStream;
035import java.io.Serializable;
036import java.util.Comparator;
037import java.util.ConcurrentModificationException;
038import java.util.Iterator;
039import java.util.NoSuchElementException;
040import java.util.function.ObjIntConsumer;
041import javax.annotation.CheckForNull;
042import org.checkerframework.checker.nullness.qual.Nullable;
043
044/**
045 * A multiset which maintains the ordering of its elements, according to either their natural order
046 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
047 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
048 * equivalence of instances.
049 *
050 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
051 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
052 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
053 *
054 * <p>See the Guava User Guide article on <a href=
055 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
056 *
057 * @author Louis Wasserman
058 * @author Jared Levy
059 * @since 2.0
060 */
061@GwtCompatible(emulated = true)
062@ElementTypesAreNonnullByDefault
063public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
064    implements Serializable {
065
066  /**
067   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
068   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
069   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
070   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
071   * user attempts to add an element to the multiset that violates this constraint (for example, the
072   * user attempts to add a string element to a set whose elements are integers), the {@code
073   * add(Object)} call will throw a {@code ClassCastException}.
074   *
075   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
076   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
077   */
078  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
079  public static <E extends Comparable> TreeMultiset<E> create() {
080    return new TreeMultiset<>(Ordering.natural());
081  }
082
083  /**
084   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
085   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
086   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
087   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
088   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
089   * ClassCastException}.
090   *
091   * @param comparator the comparator that will be used to sort this multiset. A null value
092   *     indicates that the elements' <i>natural ordering</i> should be used.
093   */
094  @SuppressWarnings("unchecked")
095  public static <E extends @Nullable Object> TreeMultiset<E> create(
096      @CheckForNull Comparator<? super E> comparator) {
097    return (comparator == null)
098        ? new TreeMultiset<E>((Comparator) Ordering.natural())
099        : new TreeMultiset<E>(comparator);
100  }
101
102  /**
103   * Creates an empty multiset containing the given initial elements, sorted according to the
104   * elements' natural order.
105   *
106   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
107   *
108   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
109   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
110   */
111  @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989
112  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
113    TreeMultiset<E> multiset = create();
114    Iterables.addAll(multiset, elements);
115    return multiset;
116  }
117
  // Holder for the root of the AVL tree. Range views built by headMultiset/tailMultiset are handed
  // this same holder, so they all observe (and modify) one shared tree.
  private final transient Reference<AvlNode<E>> rootReference;
  // The range of elements visible through this multiset; GeneralRange.all(...) for a multiset
  // created through the comparator constructor.
  private final transient GeneralRange<E> range;
  // Sentinel node with a null element that closes the circular pred/succ list of tree nodes
  // (an empty tree has header linked to itself). Shared with range views.
  private final transient AvlNode<E> header;
121
122  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
123    super(range.comparator());
124    this.rootReference = rootReference;
125    this.range = range;
126    this.header = endLink;
127  }
128
129  TreeMultiset(Comparator<? super E> comparator) {
130    super(comparator);
131    this.range = GeneralRange.all(comparator);
132    this.header = new AvlNode<>();
133    successor(header, header);
134    this.rootReference = new Reference<>();
135  }
136
137  /** A function which can be summed across a subtree. */
138  private enum Aggregate {
139    SIZE {
140      @Override
141      int nodeAggregate(AvlNode<?> node) {
142        return node.elemCount;
143      }
144
145      @Override
146      long treeAggregate(@CheckForNull AvlNode<?> root) {
147        return (root == null) ? 0 : root.totalCount;
148      }
149    },
150    DISTINCT {
151      @Override
152      int nodeAggregate(AvlNode<?> node) {
153        return 1;
154      }
155
156      @Override
157      long treeAggregate(@CheckForNull AvlNode<?> root) {
158        return (root == null) ? 0 : root.distinctElements;
159      }
160    };
161
162    abstract int nodeAggregate(AvlNode<?> node);
163
164    abstract long treeAggregate(@CheckForNull AvlNode<?> root);
165  }
166
167  private long aggregateForEntries(Aggregate aggr) {
168    AvlNode<E> root = rootReference.get();
169    long total = aggr.treeAggregate(root);
170    if (range.hasLowerBound()) {
171      total -= aggregateBelowRange(aggr, root);
172    }
173    if (range.hasUpperBound()) {
174      total -= aggregateAboveRange(aggr, root);
175    }
176    return total;
177  }
178
179  private long aggregateBelowRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
180    if (node == null) {
181      return 0;
182    }
183    // The cast is safe because we call this method only if hasLowerBound().
184    int cmp =
185        comparator()
186            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
187    if (cmp < 0) {
188      return aggregateBelowRange(aggr, node.left);
189    } else if (cmp == 0) {
190      switch (range.getLowerBoundType()) {
191        case OPEN:
192          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
193        case CLOSED:
194          return aggr.treeAggregate(node.left);
195        default:
196          throw new AssertionError();
197      }
198    } else {
199      return aggr.treeAggregate(node.left)
200          + aggr.nodeAggregate(node)
201          + aggregateBelowRange(aggr, node.right);
202    }
203  }
204
205  private long aggregateAboveRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
206    if (node == null) {
207      return 0;
208    }
209    // The cast is safe because we call this method only if hasUpperBound().
210    int cmp =
211        comparator()
212            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
213    if (cmp > 0) {
214      return aggregateAboveRange(aggr, node.right);
215    } else if (cmp == 0) {
216      switch (range.getUpperBoundType()) {
217        case OPEN:
218          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
219        case CLOSED:
220          return aggr.treeAggregate(node.right);
221        default:
222          throw new AssertionError();
223      }
224    } else {
225      return aggr.treeAggregate(node.right)
226          + aggr.nodeAggregate(node)
227          + aggregateAboveRange(aggr, node.left);
228    }
229  }
230
231  @Override
232  public int size() {
233    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
234  }
235
236  @Override
237  int distinctElements() {
238    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
239  }
240
241  static int distinctElements(@CheckForNull AvlNode<?> node) {
242    return (node == null) ? 0 : node.distinctElements;
243  }
244
245  @Override
246  public int count(@CheckForNull Object element) {
247    try {
248      @SuppressWarnings("unchecked")
249      E e = (E) element;
250      AvlNode<E> root = rootReference.get();
251      if (!range.contains(e) || root == null) {
252        return 0;
253      }
254      return root.count(comparator(), e);
255    } catch (ClassCastException | NullPointerException e) {
256      return 0;
257    }
258  }
259
260  @CanIgnoreReturnValue
261  @Override
262  public int add(@ParametricNullness E element, int occurrences) {
263    checkNonnegative(occurrences, "occurrences");
264    if (occurrences == 0) {
265      return count(element);
266    }
267    checkArgument(range.contains(element));
268    AvlNode<E> root = rootReference.get();
269    if (root == null) {
270      int unused = comparator().compare(element, element);
271      AvlNode<E> newRoot = new AvlNode<>(element, occurrences);
272      successor(header, newRoot, header);
273      rootReference.checkAndSet(root, newRoot);
274      return 0;
275    }
276    int[] result = new int[1]; // used as a mutable int reference to hold result
277    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
278    rootReference.checkAndSet(root, newRoot);
279    return result[0];
280  }
281
282  @CanIgnoreReturnValue
283  @Override
284  public int remove(@CheckForNull Object element, int occurrences) {
285    checkNonnegative(occurrences, "occurrences");
286    if (occurrences == 0) {
287      return count(element);
288    }
289    AvlNode<E> root = rootReference.get();
290    int[] result = new int[1]; // used as a mutable int reference to hold result
291    AvlNode<E> newRoot;
292    try {
293      @SuppressWarnings("unchecked")
294      E e = (E) element;
295      if (!range.contains(e) || root == null) {
296        return 0;
297      }
298      newRoot = root.remove(comparator(), e, occurrences, result);
299    } catch (ClassCastException | NullPointerException e) {
300      return 0;
301    }
302    rootReference.checkAndSet(root, newRoot);
303    return result[0];
304  }
305
306  @CanIgnoreReturnValue
307  @Override
308  public int setCount(@ParametricNullness E element, int count) {
309    checkNonnegative(count, "count");
310    if (!range.contains(element)) {
311      checkArgument(count == 0);
312      return 0;
313    }
314
315    AvlNode<E> root = rootReference.get();
316    if (root == null) {
317      if (count > 0) {
318        add(element, count);
319      }
320      return 0;
321    }
322    int[] result = new int[1]; // used as a mutable int reference to hold result
323    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
324    rootReference.checkAndSet(root, newRoot);
325    return result[0];
326  }
327
328  @CanIgnoreReturnValue
329  @Override
330  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
331    checkNonnegative(newCount, "newCount");
332    checkNonnegative(oldCount, "oldCount");
333    checkArgument(range.contains(element));
334
335    AvlNode<E> root = rootReference.get();
336    if (root == null) {
337      if (oldCount == 0) {
338        if (newCount > 0) {
339          add(element, newCount);
340        }
341        return true;
342      } else {
343        return false;
344      }
345    }
346    int[] result = new int[1]; // used as a mutable int reference to hold result
347    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
348    rootReference.checkAndSet(root, newRoot);
349    return result[0] == oldCount;
350  }
351
352  @Override
353  public void clear() {
354    if (!range.hasLowerBound() && !range.hasUpperBound()) {
355      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
356      for (AvlNode<E> current = header.succ(); current != header; ) {
357        AvlNode<E> next = current.succ();
358
359        current.elemCount = 0;
360        // Also clear these fields so that one deleted Entry doesn't retain all elements.
361        current.left = null;
362        current.right = null;
363        current.pred = null;
364        current.succ = null;
365
366        current = next;
367      }
368      successor(header, header);
369      rootReference.clear();
370    } else {
371      // TODO(cpovirk): Perhaps we can optimize in this case, too?
372      Iterators.clear(entryIterator());
373    }
374  }
375
376  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
377    return new Multisets.AbstractEntry<E>() {
378      @Override
379      @ParametricNullness
380      public E getElement() {
381        return baseEntry.getElement();
382      }
383
384      @Override
385      public int getCount() {
386        int result = baseEntry.getCount();
387        if (result == 0) {
388          return count(getElement());
389        } else {
390          return result;
391        }
392      }
393    };
394  }
395
396  /** Returns the first node in the tree that is in range. */
397  @CheckForNull
398  private AvlNode<E> firstNode() {
399    AvlNode<E> root = rootReference.get();
400    if (root == null) {
401      return null;
402    }
403    AvlNode<E> node;
404    if (range.hasLowerBound()) {
405      // The cast is safe because of the hasLowerBound check.
406      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
407      node = root.ceiling(comparator(), endpoint);
408      if (node == null) {
409        return null;
410      }
411      if (range.getLowerBoundType() == BoundType.OPEN
412          && comparator().compare(endpoint, node.getElement()) == 0) {
413        node = node.succ();
414      }
415    } else {
416      node = header.succ();
417    }
418    return (node == header || !range.contains(node.getElement())) ? null : node;
419  }
420
421  @CheckForNull
422  private AvlNode<E> lastNode() {
423    AvlNode<E> root = rootReference.get();
424    if (root == null) {
425      return null;
426    }
427    AvlNode<E> node;
428    if (range.hasUpperBound()) {
429      // The cast is safe because of the hasUpperBound check.
430      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
431      node = root.floor(comparator(), endpoint);
432      if (node == null) {
433        return null;
434      }
435      if (range.getUpperBoundType() == BoundType.OPEN
436          && comparator().compare(endpoint, node.getElement()) == 0) {
437        node = node.pred();
438      }
439    } else {
440      node = header.pred();
441    }
442    return (node == header || !range.contains(node.getElement())) ? null : node;
443  }
444
445  @Override
446  Iterator<E> elementIterator() {
447    return Multisets.elementIterator(entryIterator());
448  }
449
450  @Override
451  Iterator<Entry<E>> entryIterator() {
452    return new Iterator<Entry<E>>() {
453      @CheckForNull AvlNode<E> current = firstNode();
454      @CheckForNull Entry<E> prevEntry;
455
456      @Override
457      public boolean hasNext() {
458        if (current == null) {
459          return false;
460        } else if (range.tooHigh(current.getElement())) {
461          current = null;
462          return false;
463        } else {
464          return true;
465        }
466      }
467
468      @Override
469      public Entry<E> next() {
470        if (!hasNext()) {
471          throw new NoSuchElementException();
472        }
473        // requireNonNull is safe because current is only nulled out after iteration is complete.
474        Entry<E> result = wrapEntry(requireNonNull(current));
475        prevEntry = result;
476        if (current.succ() == header) {
477          current = null;
478        } else {
479          current = current.succ();
480        }
481        return result;
482      }
483
484      @Override
485      public void remove() {
486        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
487        setCount(prevEntry.getElement(), 0);
488        prevEntry = null;
489      }
490    };
491  }
492
493  @Override
494  Iterator<Entry<E>> descendingEntryIterator() {
495    return new Iterator<Entry<E>>() {
496      @CheckForNull AvlNode<E> current = lastNode();
497      @CheckForNull Entry<E> prevEntry = null;
498
499      @Override
500      public boolean hasNext() {
501        if (current == null) {
502          return false;
503        } else if (range.tooLow(current.getElement())) {
504          current = null;
505          return false;
506        } else {
507          return true;
508        }
509      }
510
511      @Override
512      public Entry<E> next() {
513        if (!hasNext()) {
514          throw new NoSuchElementException();
515        }
516        // requireNonNull is safe because current is only nulled out after iteration is complete.
517        requireNonNull(current);
518        Entry<E> result = wrapEntry(current);
519        prevEntry = result;
520        if (current.pred() == header) {
521          current = null;
522        } else {
523          current = current.pred();
524        }
525        return result;
526      }
527
528      @Override
529      public void remove() {
530        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
531        setCount(prevEntry.getElement(), 0);
532        prevEntry = null;
533      }
534    };
535  }
536
537  @Override
538  public void forEachEntry(ObjIntConsumer<? super E> action) {
539    checkNotNull(action);
540    for (AvlNode<E> node = firstNode();
541        node != header && node != null && !range.tooHigh(node.getElement());
542        node = node.succ()) {
543      action.accept(node.getElement(), node.getCount());
544    }
545  }
546
547  @Override
548  public Iterator<E> iterator() {
549    return Multisets.iteratorImpl(this);
550  }
551
552  @Override
553  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
554    return new TreeMultiset<>(
555        rootReference,
556        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
557        header);
558  }
559
560  @Override
561  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
562    return new TreeMultiset<>(
563        rootReference,
564        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
565        header);
566  }
567
568  private static final class Reference<T> {
569    @CheckForNull private T value;
570
571    @CheckForNull
572    public T get() {
573      return value;
574    }
575
576    public void checkAndSet(@CheckForNull T expected, @CheckForNull T newValue) {
577      if (value != expected) {
578        throw new ConcurrentModificationException();
579      }
580      value = newValue;
581    }
582
583    void clear() {
584      value = null;
585    }
586  }
587
588  private static final class AvlNode<E extends @Nullable Object> {
589    /*
590     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
591     * type that can include null, as in a TreeMultiset<@Nullable String>).
592     *
593     * For the header node, though, this field contains `null`, regardless of the type of the
594     * multiset.
595     *
596     * Most code that operates on an AvlNode never operates on the header node. Such code can access
597     * the elem field without a null check by calling getElement().
598     */
599    @CheckForNull private final E elem;
600
601    // elemCount is 0 iff this node has been deleted.
602    private int elemCount;
603
604    private int distinctElements;
605    private long totalCount;
606    private int height;
607    @CheckForNull private AvlNode<E> left;
608    @CheckForNull private AvlNode<E> right;
609    /*
610     * pred and succ are nullable after construction, but we always call successor() to initialize
611     * them immediately thereafter.
612     *
613     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
614     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
615     * only under concurrent modification).
616     *
617     * To access these fields when you know that they are not null, call the pred() and succ()
618     * methods, which perform null checks before returning the fields.
619     */
620    @CheckForNull private AvlNode<E> pred;
621    @CheckForNull private AvlNode<E> succ;
622
623    AvlNode(@ParametricNullness E elem, int elemCount) {
624      checkArgument(elemCount > 0);
625      this.elem = elem;
626      this.elemCount = elemCount;
627      this.totalCount = elemCount;
628      this.distinctElements = 1;
629      this.height = 1;
630      this.left = null;
631      this.right = null;
632    }
633
634    /** Constructor for the header node. */
635    AvlNode() {
636      this.elem = null;
637      this.elemCount = 1;
638    }
639
640    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
641
642    private AvlNode<E> pred() {
643      return requireNonNull(pred);
644    }
645
646    private AvlNode<E> succ() {
647      return requireNonNull(succ);
648    }
649
650    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
651      int cmp = comparator.compare(e, getElement());
652      if (cmp < 0) {
653        return (left == null) ? 0 : left.count(comparator, e);
654      } else if (cmp > 0) {
655        return (right == null) ? 0 : right.count(comparator, e);
656      } else {
657        return elemCount;
658      }
659    }
660
661    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
662      right = new AvlNode<>(e, count);
663      successor(this, right, succ());
664      height = Math.max(2, height);
665      distinctElements++;
666      totalCount += count;
667      return this;
668    }
669
670    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
671      left = new AvlNode<>(e, count);
672      successor(pred(), left, this);
673      height = Math.max(2, height);
674      distinctElements++;
675      totalCount += count;
676      return this;
677    }
678
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node, creating a
     * new node if {@code e} is not present.
     *
     * @param result single-slot out-parameter that receives the count of {@code e} before this
     *     call (0 if {@code e} was not present)
     * @return the root of the updated subtree, which may differ from this node after rebalancing
     * @throws IllegalArgumentException if the element's count would exceed
     *     {@code Integer.MAX_VALUE}
     */
    AvlNode<E> add(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // A previous count of 0 means a new node was created somewhere below.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Only updated after the recursive call returns, so an overflow failure below leaves this
        // node's totals untouched.
        this.totalCount += count;
        // Rebalancing can only be needed if the left subtree's height changed.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        // A previous count of 0 means a new node was created somewhere below.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Only updated after the recursive call returns (see the left branch).
        this.totalCount += count;
        // Rebalancing can only be needed if the right subtree's height changed.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      // Computed in long arithmetic so the overflow check cannot itself overflow.
      long resultCount = (long) elemCount + count;
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
724
725    @CheckForNull
726    AvlNode<E> remove(
727        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
728      int cmp = comparator.compare(e, getElement());
729      if (cmp < 0) {
730        AvlNode<E> initLeft = left;
731        if (initLeft == null) {
732          result[0] = 0;
733          return this;
734        }
735
736        left = initLeft.remove(comparator, e, count, result);
737
738        if (result[0] > 0) {
739          if (count >= result[0]) {
740            this.distinctElements--;
741            this.totalCount -= result[0];
742          } else {
743            this.totalCount -= count;
744          }
745        }
746        return (result[0] == 0) ? this : rebalance();
747      } else if (cmp > 0) {
748        AvlNode<E> initRight = right;
749        if (initRight == null) {
750          result[0] = 0;
751          return this;
752        }
753
754        right = initRight.remove(comparator, e, count, result);
755
756        if (result[0] > 0) {
757          if (count >= result[0]) {
758            this.distinctElements--;
759            this.totalCount -= result[0];
760          } else {
761            this.totalCount -= count;
762          }
763        }
764        return rebalance();
765      }
766
767      // removing count from me!
768      result[0] = elemCount;
769      if (count >= elemCount) {
770        return deleteMe();
771      } else {
772        this.elemCount -= count;
773        this.totalCount -= count;
774        return this;
775      }
776    }
777
    /**
     * Sets the count of {@code e} in the subtree rooted at this node to {@code count}, creating
     * the element's node if needed, or deleting it when {@code count} is 0.
     *
     * @param result single-slot out-parameter that receives the previous count of {@code e}
     *     (0 if it was absent)
     * @return the new root of this subtree, or null if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // Going to 0 from non-zero deletes a node; going from 0 to non-zero creates one.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // Net change in occurrences within the left subtree.
        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // Going to 0 from non-zero deletes a node; going from 0 to non-zero creates one.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // Net change in occurrences within the right subtree.
        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
827
    /**
     * Conditionally sets the count of {@code e} in the subtree rooted at this node: the count
     * becomes {@code newCount} only if it currently equals {@code expectedCount}.
     *
     * @param result single-slot out-parameter that receives the count of {@code e} before this
     *     call (0 if it was absent); the update happened iff this equals {@code expectedCount}
     * @return the new root of this subtree, or null if the subtree became empty
     */
    @CheckForNull
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @ParametricNullness E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, getElement());
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          result[0] = 0;
          // Absent means a current count of 0, so only an expected count of 0 can succeed.
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust the cached totals only when the conditional update actually happened.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          result[0] = 0;
          // Absent means a current count of 0, so only an expected count of 0 can succeed.
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust the cached totals only when the conditional update actually happened.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
891
    /**
     * Removes this node from the tree and from the pred/succ list, marking it deleted by setting
     * its elemCount to 0.
     *
     * @return the root of the subtree that replaces this node (null if this node had no children)
     */
    @CheckForNull
    private AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Link this node's neighbours to each other. pred() and succ() are read again below, which
      // assumes successor() leaves this node's own links in place — confirm if successor changes.
      successor(pred(), succ());
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        // Replace this node with its in-order predecessor, taken from the left subtree (which is
        // at least as tall as the right one).
        AvlNode<E> newTop = pred();
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        // The replacement subtree holds everything here except this node.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // Replace this node with its in-order successor, the minimum of the taller right subtree.
        AvlNode<E> newTop = succ();
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        // The replacement subtree holds everything here except this node.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
918
919    // Removes the minimum node from this subtree to be reused elsewhere
920    @CheckForNull
921    private AvlNode<E> removeMin(AvlNode<E> node) {
922      if (left == null) {
923        return right;
924      } else {
925        left = left.removeMin(node);
926        distinctElements--;
927        totalCount -= node.elemCount;
928        return rebalance();
929      }
930    }
931
932    // Removes the maximum node from this subtree to be reused elsewhere
933    @CheckForNull
934    private AvlNode<E> removeMax(AvlNode<E> node) {
935      if (right == null) {
936        return left;
937      } else {
938        right = right.removeMax(node);
939        distinctElements--;
940        totalCount -= node.elemCount;
941        return rebalance();
942      }
943    }
944
945    private void recomputeMultiset() {
946      this.distinctElements =
947          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
948      this.totalCount = elemCount + totalCount(left) + totalCount(right);
949    }
950
951    private void recomputeHeight() {
952      this.height = 1 + Math.max(height(left), height(right));
953    }
954
955    private void recompute() {
956      recomputeMultiset();
957      recomputeHeight();
958    }
959
    /**
     * Restores the AVL height-balance invariant at this node and returns the (possibly new) root of
     * this subtree.
     *
     * <p>When no rotation is needed, only the height is refreshed. Rotations copy this node's
     * {@code totalCount} and {@code distinctElements} onto the new root, so callers must already
     * have brought those aggregates up to date before calling this.
     */
    private AvlNode<E> rebalance() {
      // NOTE(review): assumes the child subtrees are balanced and their heights differ by at most
      // 2, so only the factors -2 and 2 require rotation.
      switch (balanceFactor()) {
        case -2:
          // Right-heavy. If the right child leans left, rotate it right first (right-left case)
          // so that a single left rotation here restores balance.
          // requireNonNull is safe because right must exist in order to get a negative factor.
          requireNonNull(right);
          if (right.balanceFactor() > 0) {
            right = right.rotateRight();
          }
          return rotateLeft();
        case 2:
          // Left-heavy. If the left child leans right, rotate it left first (left-right case).
          // requireNonNull is safe because left must exist in order to get a positive factor.
          requireNonNull(left);
          if (left.balanceFactor() < 0) {
            left = left.rotateLeft();
          }
          return rotateRight();
        default:
          // Already balanced; a child's height may still have changed.
          recomputeHeight();
          return this;
      }
    }
981
982    private int balanceFactor() {
983      return height(left) - height(right);
984    }
985
    /**
     * Performs a left rotation: this node's right child becomes the root of this subtree and this
     * node becomes that child's left child.
     *
     * @return the new root of this subtree
     * @throws IllegalStateException if this node has no right child
     */
    private AvlNode<E> rotateLeft() {
      checkState(right != null);
      AvlNode<E> newTop = right;
      // newTop's old left subtree sorts between this node and newTop, so it moves under this.
      this.right = newTop.left;
      newTop.left = this;
      // The rotated subtree holds the same nodes, so newTop inherits this subtree's aggregates.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // this must be recomputed first: newTop's height depends on this node's new height.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
997
    /**
     * Performs a right rotation: this node's left child becomes the root of this subtree and this
     * node becomes that child's right child.
     *
     * @return the new root of this subtree
     * @throws IllegalStateException if this node has no left child
     */
    private AvlNode<E> rotateRight() {
      checkState(left != null);
      AvlNode<E> newTop = left;
      // newTop's old right subtree sorts between newTop and this node, so it moves under this.
      this.left = newTop.right;
      newTop.right = this;
      // The rotated subtree holds the same nodes, so newTop inherits this subtree's aggregates.
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // this must be recomputed first: newTop's height depends on this node's new height.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
1009
1010    private static long totalCount(@CheckForNull AvlNode<?> node) {
1011      return (node == null) ? 0 : node.totalCount;
1012    }
1013
1014    private static int height(@CheckForNull AvlNode<?> node) {
1015      return (node == null) ? 0 : node.height;
1016    }
1017
1018    @CheckForNull
1019    private AvlNode<E> ceiling(Comparator<? super E> comparator, @ParametricNullness E e) {
1020      int cmp = comparator.compare(e, getElement());
1021      if (cmp < 0) {
1022        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1023      } else if (cmp == 0) {
1024        return this;
1025      } else {
1026        return (right == null) ? null : right.ceiling(comparator, e);
1027      }
1028    }
1029
1030    @CheckForNull
1031    private AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1032      int cmp = comparator.compare(e, getElement());
1033      if (cmp > 0) {
1034        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1035      } else if (cmp == 0) {
1036        return this;
1037      } else {
1038        return (left == null) ? null : left.floor(comparator, e);
1039      }
1040    }
1041
1042    @ParametricNullness
1043    E getElement() {
1044      // For discussion of this cast, see the comment on the elem field.
1045      return uncheckedCastNullableTToT(elem);
1046    }
1047
1048    int getCount() {
1049      return elemCount;
1050    }
1051
1052    @Override
1053    public String toString() {
1054      return Multisets.immutableEntry(getElement(), getCount()).toString();
1055    }
1056  }
1057
1058  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1059    a.succ = b;
1060    b.pred = a;
1061  }
1062
1063  private static <T extends @Nullable Object> void successor(
1064      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1065    successor(a, b);
1066    successor(b, c);
1067  }
1068
1069  /*
1070   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1071   * calls the comparator to compare the two keys. If that change is made,
1072   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1073   */
1074
1075  /**
1076   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1077   *     second element, its count, and so on
1078   */
1079  @J2ktIncompatible
1080  @GwtIncompatible // java.io.ObjectOutputStream
1081  private void writeObject(ObjectOutputStream stream) throws IOException {
1082    stream.defaultWriteObject();
1083    stream.writeObject(elementSet().comparator());
1084    Serialization.writeMultiset(this, stream);
1085  }
1086
  /**
   * Restores a multiset written by {@code writeObject}. Reads the comparator, then installs the
   * comparator, range, rootReference and header fields. Finally it re-populates the contents from
   * the remaining stream data.
   *
   * @throws NullPointerException if the stream holds a null comparator
   */
  @J2ktIncompatible
  @GwtIncompatible // java.io.ObjectInputStream
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    @SuppressWarnings("unchecked")
    // reading data stored by writeObject
    Comparator<? super E> comparator = (Comparator<? super E>) requireNonNull(stream.readObject());
    // NOTE(review): these fields are set reflectively, presumably because they are final and
    // cannot be assigned outside a constructor; confirm against the field declarations.
    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    // An unbounded range over the restored comparator.
    Serialization.getFieldSetter(TreeMultiset.class, "range")
        .set(this, GeneralRange.all(comparator));
    // A fresh root reference (presumably empty) for the tree to be rebuilt into.
    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
        .set(this, new Reference<AvlNode<E>>());
    // The header node is linked to itself, i.e. the pred/succ chain starts out empty.
    AvlNode<E> header = new AvlNode<>();
    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    successor(header, header);
    // Must run last: rebuilding the contents presumably goes through this multiset's own
    // mutators, which rely on every field installed above.
    Serialization.populateMultiset(this, stream);
  }
1104
  // Version stamp of the Java serialization form; changing it makes previously serialized
  // instances unreadable.
  @GwtIncompatible // not needed in emulated source
  @J2ktIncompatible
  private static final long serialVersionUID = 1;
1108}