001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Preconditions.checkState;
022import static com.google.common.collect.CollectPreconditions.checkNonnegative;
023import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
024import static java.util.Objects.requireNonNull;
025
026import com.google.common.annotations.GwtCompatible;
027import com.google.common.annotations.GwtIncompatible;
028import com.google.common.base.MoreObjects;
029import com.google.common.primitives.Ints;
030import com.google.errorprone.annotations.CanIgnoreReturnValue;
031import java.io.IOException;
032import java.io.ObjectInputStream;
033import java.io.ObjectOutputStream;
034import java.io.Serializable;
035import java.util.Comparator;
036import java.util.ConcurrentModificationException;
037import java.util.Iterator;
038import java.util.NoSuchElementException;
039import java.util.function.ObjIntConsumer;
040import javax.annotation.CheckForNull;
041import org.checkerframework.checker.nullness.qual.Nullable;
042
043/**
044 * A multiset which maintains the ordering of its elements, according to either their natural order
045 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
046 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
047 * equivalence of instances.
048 *
049 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
050 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
051 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
052 *
053 * <p>See the Guava User Guide article on <a href=
054 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>.
055 *
056 * @author Louis Wasserman
057 * @author Jared Levy
058 * @since 2.0
059 */
060@GwtCompatible(emulated = true)
061@ElementTypesAreNonnullByDefault
062public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E>
063    implements Serializable {
064
065  /**
066   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
067   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
068   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
069   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
070   * user attempts to add an element to the multiset that violates this constraint (for example, the
071   * user attempts to add a string element to a set whose elements are integers), the {@code
072   * add(Object)} call will throw a {@code ClassCastException}.
073   *
074   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
075   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
076   */
077  public static <E extends Comparable> TreeMultiset<E> create() {
078    return new TreeMultiset<E>(Ordering.natural());
079  }
080
081  /**
082   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
083   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
084   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
085   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
086   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
087   * ClassCastException}.
088   *
089   * @param comparator the comparator that will be used to sort this multiset. A null value
090   *     indicates that the elements' <i>natural ordering</i> should be used.
091   */
092  @SuppressWarnings("unchecked")
093  public static <E extends @Nullable Object> TreeMultiset<E> create(
094      @CheckForNull Comparator<? super E> comparator) {
095    return (comparator == null)
096        ? new TreeMultiset<E>((Comparator) Ordering.natural())
097        : new TreeMultiset<E>(comparator);
098  }
099
100  /**
101   * Creates an empty multiset containing the given initial elements, sorted according to the
102   * elements' natural order.
103   *
104   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
105   *
106   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
107   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
108   */
109  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
110    TreeMultiset<E> multiset = create();
111    Iterables.addAll(multiset, elements);
112    return multiset;
113  }
114
115  private final transient Reference<AvlNode<E>> rootReference;
116  private final transient GeneralRange<E> range;
117  private final transient AvlNode<E> header;
118
119  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
120    super(range.comparator());
121    this.rootReference = rootReference;
122    this.range = range;
123    this.header = endLink;
124  }
125
126  TreeMultiset(Comparator<? super E> comparator) {
127    super(comparator);
128    this.range = GeneralRange.all(comparator);
129    this.header = new AvlNode<>();
130    successor(header, header);
131    this.rootReference = new Reference<>();
132  }
133
134  /** A function which can be summed across a subtree. */
135  private enum Aggregate {
136    SIZE {
137      @Override
138      int nodeAggregate(AvlNode<?> node) {
139        return node.elemCount;
140      }
141
142      @Override
143      long treeAggregate(@CheckForNull AvlNode<?> root) {
144        return (root == null) ? 0 : root.totalCount;
145      }
146    },
147    DISTINCT {
148      @Override
149      int nodeAggregate(AvlNode<?> node) {
150        return 1;
151      }
152
153      @Override
154      long treeAggregate(@CheckForNull AvlNode<?> root) {
155        return (root == null) ? 0 : root.distinctElements;
156      }
157    };
158
159    abstract int nodeAggregate(AvlNode<?> node);
160
161    abstract long treeAggregate(@CheckForNull AvlNode<?> root);
162  }
163
164  private long aggregateForEntries(Aggregate aggr) {
165    AvlNode<E> root = rootReference.get();
166    long total = aggr.treeAggregate(root);
167    if (range.hasLowerBound()) {
168      total -= aggregateBelowRange(aggr, root);
169    }
170    if (range.hasUpperBound()) {
171      total -= aggregateAboveRange(aggr, root);
172    }
173    return total;
174  }
175
176  private long aggregateBelowRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
177    if (node == null) {
178      return 0;
179    }
180    // The cast is safe because we call this method only if hasLowerBound().
181    int cmp =
182        comparator()
183            .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement());
184    if (cmp < 0) {
185      return aggregateBelowRange(aggr, node.left);
186    } else if (cmp == 0) {
187      switch (range.getLowerBoundType()) {
188        case OPEN:
189          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
190        case CLOSED:
191          return aggr.treeAggregate(node.left);
192        default:
193          throw new AssertionError();
194      }
195    } else {
196      return aggr.treeAggregate(node.left)
197          + aggr.nodeAggregate(node)
198          + aggregateBelowRange(aggr, node.right);
199    }
200  }
201
202  private long aggregateAboveRange(Aggregate aggr, @CheckForNull AvlNode<E> node) {
203    if (node == null) {
204      return 0;
205    }
206    // The cast is safe because we call this method only if hasUpperBound().
207    int cmp =
208        comparator()
209            .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement());
210    if (cmp > 0) {
211      return aggregateAboveRange(aggr, node.right);
212    } else if (cmp == 0) {
213      switch (range.getUpperBoundType()) {
214        case OPEN:
215          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
216        case CLOSED:
217          return aggr.treeAggregate(node.right);
218        default:
219          throw new AssertionError();
220      }
221    } else {
222      return aggr.treeAggregate(node.right)
223          + aggr.nodeAggregate(node)
224          + aggregateAboveRange(aggr, node.left);
225    }
226  }
227
228  @Override
229  public int size() {
230    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
231  }
232
233  @Override
234  int distinctElements() {
235    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
236  }
237
238  static int distinctElements(@CheckForNull AvlNode<?> node) {
239    return (node == null) ? 0 : node.distinctElements;
240  }
241
242  @Override
243  public int count(@CheckForNull Object element) {
244    try {
245      @SuppressWarnings("unchecked")
246      E e = (E) element;
247      AvlNode<E> root = rootReference.get();
248      if (!range.contains(e) || root == null) {
249        return 0;
250      }
251      return root.count(comparator(), e);
252    } catch (ClassCastException | NullPointerException e) {
253      return 0;
254    }
255  }
256
257  @CanIgnoreReturnValue
258  @Override
259  public int add(@ParametricNullness E element, int occurrences) {
260    checkNonnegative(occurrences, "occurrences");
261    if (occurrences == 0) {
262      return count(element);
263    }
264    checkArgument(range.contains(element));
265    AvlNode<E> root = rootReference.get();
266    if (root == null) {
267      comparator().compare(element, element);
268      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
269      successor(header, newRoot, header);
270      rootReference.checkAndSet(root, newRoot);
271      return 0;
272    }
273    int[] result = new int[1]; // used as a mutable int reference to hold result
274    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
275    rootReference.checkAndSet(root, newRoot);
276    return result[0];
277  }
278
279  @CanIgnoreReturnValue
280  @Override
281  public int remove(@CheckForNull Object element, int occurrences) {
282    checkNonnegative(occurrences, "occurrences");
283    if (occurrences == 0) {
284      return count(element);
285    }
286    AvlNode<E> root = rootReference.get();
287    int[] result = new int[1]; // used as a mutable int reference to hold result
288    AvlNode<E> newRoot;
289    try {
290      @SuppressWarnings("unchecked")
291      E e = (E) element;
292      if (!range.contains(e) || root == null) {
293        return 0;
294      }
295      newRoot = root.remove(comparator(), e, occurrences, result);
296    } catch (ClassCastException | NullPointerException e) {
297      return 0;
298    }
299    rootReference.checkAndSet(root, newRoot);
300    return result[0];
301  }
302
303  @CanIgnoreReturnValue
304  @Override
305  public int setCount(@ParametricNullness E element, int count) {
306    checkNonnegative(count, "count");
307    if (!range.contains(element)) {
308      checkArgument(count == 0);
309      return 0;
310    }
311
312    AvlNode<E> root = rootReference.get();
313    if (root == null) {
314      if (count > 0) {
315        add(element, count);
316      }
317      return 0;
318    }
319    int[] result = new int[1]; // used as a mutable int reference to hold result
320    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
321    rootReference.checkAndSet(root, newRoot);
322    return result[0];
323  }
324
325  @CanIgnoreReturnValue
326  @Override
327  public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) {
328    checkNonnegative(newCount, "newCount");
329    checkNonnegative(oldCount, "oldCount");
330    checkArgument(range.contains(element));
331
332    AvlNode<E> root = rootReference.get();
333    if (root == null) {
334      if (oldCount == 0) {
335        if (newCount > 0) {
336          add(element, newCount);
337        }
338        return true;
339      } else {
340        return false;
341      }
342    }
343    int[] result = new int[1]; // used as a mutable int reference to hold result
344    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
345    rootReference.checkAndSet(root, newRoot);
346    return result[0] == oldCount;
347  }
348
349  @Override
350  public void clear() {
351    if (!range.hasLowerBound() && !range.hasUpperBound()) {
352      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
353      for (AvlNode<E> current = header.succ(); current != header; ) {
354        AvlNode<E> next = current.succ();
355
356        current.elemCount = 0;
357        // Also clear these fields so that one deleted Entry doesn't retain all elements.
358        current.left = null;
359        current.right = null;
360        current.pred = null;
361        current.succ = null;
362
363        current = next;
364      }
365      successor(header, header);
366      rootReference.clear();
367    } else {
368      // TODO(cpovirk): Perhaps we can optimize in this case, too?
369      Iterators.clear(entryIterator());
370    }
371  }
372
373  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
374    return new Multisets.AbstractEntry<E>() {
375      @Override
376      @ParametricNullness
377      public E getElement() {
378        return baseEntry.getElement();
379      }
380
381      @Override
382      public int getCount() {
383        int result = baseEntry.getCount();
384        if (result == 0) {
385          return count(getElement());
386        } else {
387          return result;
388        }
389      }
390    };
391  }
392
393  /** Returns the first node in the tree that is in range. */
394  @CheckForNull
395  private AvlNode<E> firstNode() {
396    AvlNode<E> root = rootReference.get();
397    if (root == null) {
398      return null;
399    }
400    AvlNode<E> node;
401    if (range.hasLowerBound()) {
402      // The cast is safe because of the hasLowerBound check.
403      E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint());
404      node = root.ceiling(comparator(), endpoint);
405      if (node == null) {
406        return null;
407      }
408      if (range.getLowerBoundType() == BoundType.OPEN
409          && comparator().compare(endpoint, node.getElement()) == 0) {
410        node = node.succ();
411      }
412    } else {
413      node = header.succ();
414    }
415    return (node == header || !range.contains(node.getElement())) ? null : node;
416  }
417
418  @CheckForNull
419  private AvlNode<E> lastNode() {
420    AvlNode<E> root = rootReference.get();
421    if (root == null) {
422      return null;
423    }
424    AvlNode<E> node;
425    if (range.hasUpperBound()) {
426      // The cast is safe because of the hasUpperBound check.
427      E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint());
428      node = root.floor(comparator(), endpoint);
429      if (node == null) {
430        return null;
431      }
432      if (range.getUpperBoundType() == BoundType.OPEN
433          && comparator().compare(endpoint, node.getElement()) == 0) {
434        node = node.pred();
435      }
436    } else {
437      node = header.pred();
438    }
439    return (node == header || !range.contains(node.getElement())) ? null : node;
440  }
441
442  @Override
443  Iterator<E> elementIterator() {
444    return Multisets.elementIterator(entryIterator());
445  }
446
447  @Override
448  Iterator<Entry<E>> entryIterator() {
449    return new Iterator<Entry<E>>() {
450      @CheckForNull AvlNode<E> current = firstNode();
451      @CheckForNull Entry<E> prevEntry;
452
453      @Override
454      public boolean hasNext() {
455        if (current == null) {
456          return false;
457        } else if (range.tooHigh(current.getElement())) {
458          current = null;
459          return false;
460        } else {
461          return true;
462        }
463      }
464
465      @Override
466      public Entry<E> next() {
467        if (!hasNext()) {
468          throw new NoSuchElementException();
469        }
470        // requireNonNull is safe because current is only nulled out after iteration is complete.
471        Entry<E> result = wrapEntry(requireNonNull(current));
472        prevEntry = result;
473        if (current.succ() == header) {
474          current = null;
475        } else {
476          current = current.succ();
477        }
478        return result;
479      }
480
481      @Override
482      public void remove() {
483        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
484        setCount(prevEntry.getElement(), 0);
485        prevEntry = null;
486      }
487    };
488  }
489
490  @Override
491  Iterator<Entry<E>> descendingEntryIterator() {
492    return new Iterator<Entry<E>>() {
493      @CheckForNull AvlNode<E> current = lastNode();
494      @CheckForNull Entry<E> prevEntry = null;
495
496      @Override
497      public boolean hasNext() {
498        if (current == null) {
499          return false;
500        } else if (range.tooLow(current.getElement())) {
501          current = null;
502          return false;
503        } else {
504          return true;
505        }
506      }
507
508      @Override
509      public Entry<E> next() {
510        if (!hasNext()) {
511          throw new NoSuchElementException();
512        }
513        // requireNonNull is safe because current is only nulled out after iteration is complete.
514        requireNonNull(current);
515        Entry<E> result = wrapEntry(current);
516        prevEntry = result;
517        if (current.pred() == header) {
518          current = null;
519        } else {
520          current = current.pred();
521        }
522        return result;
523      }
524
525      @Override
526      public void remove() {
527        checkState(prevEntry != null, "no calls to next() since the last call to remove()");
528        setCount(prevEntry.getElement(), 0);
529        prevEntry = null;
530      }
531    };
532  }
533
534  @Override
535  public void forEachEntry(ObjIntConsumer<? super E> action) {
536    checkNotNull(action);
537    for (AvlNode<E> node = firstNode();
538        node != header && node != null && !range.tooHigh(node.getElement());
539        node = node.succ()) {
540      action.accept(node.getElement(), node.getCount());
541    }
542  }
543
544  @Override
545  public Iterator<E> iterator() {
546    return Multisets.iteratorImpl(this);
547  }
548
549  @Override
550  public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) {
551    return new TreeMultiset<E>(
552        rootReference,
553        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
554        header);
555  }
556
557  @Override
558  public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) {
559    return new TreeMultiset<E>(
560        rootReference,
561        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
562        header);
563  }
564
565  private static final class Reference<T> {
566    @CheckForNull private T value;
567
568    @CheckForNull
569    public T get() {
570      return value;
571    }
572
573    public void checkAndSet(@CheckForNull T expected, @CheckForNull T newValue) {
574      if (value != expected) {
575        throw new ConcurrentModificationException();
576      }
577      value = newValue;
578    }
579
580    void clear() {
581      value = null;
582    }
583  }
584
585  private static final class AvlNode<E extends @Nullable Object> {
586    /*
587     * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a
588     * type that can include null, as in a TreeMultiset<@Nullable String>).
589     *
590     * For the header node, though, this field contains `null`, regardless of the type of the
591     * multiset.
592     *
593     * Most code that operates on an AvlNode never operates on the header node. Such code can access
594     * the elem field without a null check by calling getElement().
595     */
596    @CheckForNull private final E elem;
597
598    // elemCount is 0 iff this node has been deleted.
599    private int elemCount;
600
601    private int distinctElements;
602    private long totalCount;
603    private int height;
604    @CheckForNull private AvlNode<E> left;
605    @CheckForNull private AvlNode<E> right;
606    /*
607     * pred and succ are nullable after construction, but we always call successor() to initialize
608     * them immediately thereafter.
609     *
610     * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that
611     * we can reference a node whose fields have been cleared is inside the iterator (and presumably
612     * only under concurrent modification).
613     *
614     * To access these fields when you know that they are not null, call the pred() and succ()
615     * methods, which perform null checks before returning the fields.
616     */
617    @CheckForNull private AvlNode<E> pred;
618    @CheckForNull private AvlNode<E> succ;
619
620    AvlNode(@ParametricNullness E elem, int elemCount) {
621      checkArgument(elemCount > 0);
622      this.elem = elem;
623      this.elemCount = elemCount;
624      this.totalCount = elemCount;
625      this.distinctElements = 1;
626      this.height = 1;
627      this.left = null;
628      this.right = null;
629    }
630
631    /** Constructor for the header node. */
632    AvlNode() {
633      this.elem = null;
634      this.elemCount = 1;
635    }
636
637    // For discussion of pred() and succ(), see the comment on the pred and succ fields.
638
639    private AvlNode<E> pred() {
640      return requireNonNull(pred);
641    }
642
643    private AvlNode<E> succ() {
644      return requireNonNull(succ);
645    }
646
647    int count(Comparator<? super E> comparator, @ParametricNullness E e) {
648      int cmp = comparator.compare(e, getElement());
649      if (cmp < 0) {
650        return (left == null) ? 0 : left.count(comparator, e);
651      } else if (cmp > 0) {
652        return (right == null) ? 0 : right.count(comparator, e);
653      } else {
654        return elemCount;
655      }
656    }
657
658    private AvlNode<E> addRightChild(@ParametricNullness E e, int count) {
659      right = new AvlNode<E>(e, count);
660      successor(this, right, succ());
661      height = Math.max(2, height);
662      distinctElements++;
663      totalCount += count;
664      return this;
665    }
666
667    private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) {
668      left = new AvlNode<E>(e, count);
669      successor(pred(), left, this);
670      height = Math.max(2, height);
671      distinctElements++;
672      totalCount += count;
673      return this;
674    }
675
676    AvlNode<E> add(
677        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
678      /*
679       * It speeds things up considerably to unconditionally add count to totalCount here,
680       * but that destroys failure atomicity in the case of count overflow. =(
681       */
682      int cmp = comparator.compare(e, getElement());
683      if (cmp < 0) {
684        AvlNode<E> initLeft = left;
685        if (initLeft == null) {
686          result[0] = 0;
687          return addLeftChild(e, count);
688        }
689        int initHeight = initLeft.height;
690
691        left = initLeft.add(comparator, e, count, result);
692        if (result[0] == 0) {
693          distinctElements++;
694        }
695        this.totalCount += count;
696        return (left.height == initHeight) ? this : rebalance();
697      } else if (cmp > 0) {
698        AvlNode<E> initRight = right;
699        if (initRight == null) {
700          result[0] = 0;
701          return addRightChild(e, count);
702        }
703        int initHeight = initRight.height;
704
705        right = initRight.add(comparator, e, count, result);
706        if (result[0] == 0) {
707          distinctElements++;
708        }
709        this.totalCount += count;
710        return (right.height == initHeight) ? this : rebalance();
711      }
712
713      // adding count to me!  No rebalance possible.
714      result[0] = elemCount;
715      long resultCount = (long) elemCount + count;
716      checkArgument(resultCount <= Integer.MAX_VALUE);
717      this.elemCount += count;
718      this.totalCount += count;
719      return this;
720    }
721
722    @CheckForNull
723    AvlNode<E> remove(
724        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
725      int cmp = comparator.compare(e, getElement());
726      if (cmp < 0) {
727        AvlNode<E> initLeft = left;
728        if (initLeft == null) {
729          result[0] = 0;
730          return this;
731        }
732
733        left = initLeft.remove(comparator, e, count, result);
734
735        if (result[0] > 0) {
736          if (count >= result[0]) {
737            this.distinctElements--;
738            this.totalCount -= result[0];
739          } else {
740            this.totalCount -= count;
741          }
742        }
743        return (result[0] == 0) ? this : rebalance();
744      } else if (cmp > 0) {
745        AvlNode<E> initRight = right;
746        if (initRight == null) {
747          result[0] = 0;
748          return this;
749        }
750
751        right = initRight.remove(comparator, e, count, result);
752
753        if (result[0] > 0) {
754          if (count >= result[0]) {
755            this.distinctElements--;
756            this.totalCount -= result[0];
757          } else {
758            this.totalCount -= count;
759          }
760        }
761        return rebalance();
762      }
763
764      // removing count from me!
765      result[0] = elemCount;
766      if (count >= elemCount) {
767        return deleteMe();
768      } else {
769        this.elemCount -= count;
770        this.totalCount -= count;
771        return this;
772      }
773    }
774
775    @CheckForNull
776    AvlNode<E> setCount(
777        Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) {
778      int cmp = comparator.compare(e, getElement());
779      if (cmp < 0) {
780        AvlNode<E> initLeft = left;
781        if (initLeft == null) {
782          result[0] = 0;
783          return (count > 0) ? addLeftChild(e, count) : this;
784        }
785
786        left = initLeft.setCount(comparator, e, count, result);
787
788        if (count == 0 && result[0] != 0) {
789          this.distinctElements--;
790        } else if (count > 0 && result[0] == 0) {
791          this.distinctElements++;
792        }
793
794        this.totalCount += count - result[0];
795        return rebalance();
796      } else if (cmp > 0) {
797        AvlNode<E> initRight = right;
798        if (initRight == null) {
799          result[0] = 0;
800          return (count > 0) ? addRightChild(e, count) : this;
801        }
802
803        right = initRight.setCount(comparator, e, count, result);
804
805        if (count == 0 && result[0] != 0) {
806          this.distinctElements--;
807        } else if (count > 0 && result[0] == 0) {
808          this.distinctElements++;
809        }
810
811        this.totalCount += count - result[0];
812        return rebalance();
813      }
814
815      // setting my count
816      result[0] = elemCount;
817      if (count == 0) {
818        return deleteMe();
819      }
820      this.totalCount += count - elemCount;
821      this.elemCount = count;
822      return this;
823    }
824
825    @CheckForNull
826    AvlNode<E> setCount(
827        Comparator<? super E> comparator,
828        @ParametricNullness E e,
829        int expectedCount,
830        int newCount,
831        int[] result) {
832      int cmp = comparator.compare(e, getElement());
833      if (cmp < 0) {
834        AvlNode<E> initLeft = left;
835        if (initLeft == null) {
836          result[0] = 0;
837          if (expectedCount == 0 && newCount > 0) {
838            return addLeftChild(e, newCount);
839          }
840          return this;
841        }
842
843        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
844
845        if (result[0] == expectedCount) {
846          if (newCount == 0 && result[0] != 0) {
847            this.distinctElements--;
848          } else if (newCount > 0 && result[0] == 0) {
849            this.distinctElements++;
850          }
851          this.totalCount += newCount - result[0];
852        }
853        return rebalance();
854      } else if (cmp > 0) {
855        AvlNode<E> initRight = right;
856        if (initRight == null) {
857          result[0] = 0;
858          if (expectedCount == 0 && newCount > 0) {
859            return addRightChild(e, newCount);
860          }
861          return this;
862        }
863
864        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
865
866        if (result[0] == expectedCount) {
867          if (newCount == 0 && result[0] != 0) {
868            this.distinctElements--;
869          } else if (newCount > 0 && result[0] == 0) {
870            this.distinctElements++;
871          }
872          this.totalCount += newCount - result[0];
873        }
874        return rebalance();
875      }
876
877      // setting my count
878      result[0] = elemCount;
879      if (expectedCount == elemCount) {
880        if (newCount == 0) {
881          return deleteMe();
882        }
883        this.totalCount += newCount - elemCount;
884        this.elemCount = newCount;
885      }
886      return this;
887    }
888
889    @CheckForNull
890    private AvlNode<E> deleteMe() {
891      int oldElemCount = this.elemCount;
892      this.elemCount = 0;
893      successor(pred(), succ());
894      if (left == null) {
895        return right;
896      } else if (right == null) {
897        return left;
898      } else if (left.height >= right.height) {
899        AvlNode<E> newTop = pred();
900        // newTop is the maximum node in my left subtree
901        newTop.left = left.removeMax(newTop);
902        newTop.right = right;
903        newTop.distinctElements = distinctElements - 1;
904        newTop.totalCount = totalCount - oldElemCount;
905        return newTop.rebalance();
906      } else {
907        AvlNode<E> newTop = succ();
908        newTop.right = right.removeMin(newTop);
909        newTop.left = left;
910        newTop.distinctElements = distinctElements - 1;
911        newTop.totalCount = totalCount - oldElemCount;
912        return newTop.rebalance();
913      }
914    }
915
916    // Removes the minimum node from this subtree to be reused elsewhere
917    @CheckForNull
918    private AvlNode<E> removeMin(AvlNode<E> node) {
919      if (left == null) {
920        return right;
921      } else {
922        left = left.removeMin(node);
923        distinctElements--;
924        totalCount -= node.elemCount;
925        return rebalance();
926      }
927    }
928
929    // Removes the maximum node from this subtree to be reused elsewhere
930    @CheckForNull
931    private AvlNode<E> removeMax(AvlNode<E> node) {
932      if (right == null) {
933        return left;
934      } else {
935        right = right.removeMax(node);
936        distinctElements--;
937        totalCount -= node.elemCount;
938        return rebalance();
939      }
940    }
941
942    private void recomputeMultiset() {
943      this.distinctElements =
944          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
945      this.totalCount = elemCount + totalCount(left) + totalCount(right);
946    }
947
948    private void recomputeHeight() {
949      this.height = 1 + Math.max(height(left), height(right));
950    }
951
952    private void recompute() {
953      recomputeMultiset();
954      recomputeHeight();
955    }
956
957    private AvlNode<E> rebalance() {
958      switch (balanceFactor()) {
959        case -2:
960          // requireNonNull is safe because right must exist in order to get a negative factor.
961          requireNonNull(right);
962          if (right.balanceFactor() > 0) {
963            right = right.rotateRight();
964          }
965          return rotateLeft();
966        case 2:
967          // requireNonNull is safe because left must exist in order to get a positive factor.
968          requireNonNull(left);
969          if (left.balanceFactor() < 0) {
970            left = left.rotateLeft();
971          }
972          return rotateRight();
973        default:
974          recomputeHeight();
975          return this;
976      }
977    }
978
979    private int balanceFactor() {
980      return height(left) - height(right);
981    }
982
983    private AvlNode<E> rotateLeft() {
984      checkState(right != null);
985      AvlNode<E> newTop = right;
986      this.right = newTop.left;
987      newTop.left = this;
988      newTop.totalCount = this.totalCount;
989      newTop.distinctElements = this.distinctElements;
990      this.recompute();
991      newTop.recomputeHeight();
992      return newTop;
993    }
994
995    private AvlNode<E> rotateRight() {
996      checkState(left != null);
997      AvlNode<E> newTop = left;
998      this.left = newTop.right;
999      newTop.right = this;
1000      newTop.totalCount = this.totalCount;
1001      newTop.distinctElements = this.distinctElements;
1002      this.recompute();
1003      newTop.recomputeHeight();
1004      return newTop;
1005    }
1006
1007    private static long totalCount(@CheckForNull AvlNode<?> node) {
1008      return (node == null) ? 0 : node.totalCount;
1009    }
1010
1011    private static int height(@CheckForNull AvlNode<?> node) {
1012      return (node == null) ? 0 : node.height;
1013    }
1014
1015    @CheckForNull
1016    private AvlNode<E> ceiling(Comparator<? super E> comparator, @ParametricNullness E e) {
1017      int cmp = comparator.compare(e, getElement());
1018      if (cmp < 0) {
1019        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
1020      } else if (cmp == 0) {
1021        return this;
1022      } else {
1023        return (right == null) ? null : right.ceiling(comparator, e);
1024      }
1025    }
1026
1027    @CheckForNull
1028    private AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) {
1029      int cmp = comparator.compare(e, getElement());
1030      if (cmp > 0) {
1031        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
1032      } else if (cmp == 0) {
1033        return this;
1034      } else {
1035        return (left == null) ? null : left.floor(comparator, e);
1036      }
1037    }
1038
1039    @ParametricNullness
1040    E getElement() {
1041      // For discussion of this cast, see the comment on the elem field.
1042      return uncheckedCastNullableTToT(elem);
1043    }
1044
1045    int getCount() {
1046      return elemCount;
1047    }
1048
1049    @Override
1050    public String toString() {
1051      return Multisets.immutableEntry(getElement(), getCount()).toString();
1052    }
1053  }
1054
1055  private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) {
1056    a.succ = b;
1057    b.pred = a;
1058  }
1059
1060  private static <T extends @Nullable Object> void successor(
1061      AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
1062    successor(a, b);
1063    successor(b, c);
1064  }
1065
1066  /*
1067   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1068   * calls the comparator to compare the two keys. If that change is made,
1069   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1070   */
1071
1072  /**
1073   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1074   *     second element, its count, and so on
1075   */
1076  @GwtIncompatible // java.io.ObjectOutputStream
1077  private void writeObject(ObjectOutputStream stream) throws IOException {
1078    stream.defaultWriteObject();
1079    stream.writeObject(elementSet().comparator());
1080    Serialization.writeMultiset(this, stream);
1081  }
1082
1083  @GwtIncompatible // java.io.ObjectInputStream
1084  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
1085    stream.defaultReadObject();
1086    @SuppressWarnings("unchecked")
1087    // reading data stored by writeObject
1088    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
1089    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
1090    Serialization.getFieldSetter(TreeMultiset.class, "range")
1091        .set(this, GeneralRange.all(comparator));
1092    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
1093        .set(this, new Reference<AvlNode<E>>());
1094    AvlNode<E> header = new AvlNode<>();
1095    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
1096    successor(header, header);
1097    Serialization.populateMultiset(this, stream);
1098  }
1099
  // Version identifier for this class's serialized form.
  @GwtIncompatible // not needed in emulated source
  private static final long serialVersionUID = 1;
1102}