001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Preconditions.checkState;
022import static com.google.common.collect.CollectPreconditions.checkNonnegative;
023import static com.google.common.collect.CollectPreconditions.checkRemove;
024
025import com.google.common.annotations.GwtCompatible;
026import com.google.common.annotations.GwtIncompatible;
027import com.google.common.base.MoreObjects;
028import com.google.common.primitives.Ints;
029import com.google.errorprone.annotations.CanIgnoreReturnValue;
030import java.io.IOException;
031import java.io.ObjectInputStream;
032import java.io.ObjectOutputStream;
033import java.io.Serializable;
034import java.util.Comparator;
035import java.util.ConcurrentModificationException;
036import java.util.Iterator;
037import java.util.NoSuchElementException;
038import java.util.function.ObjIntConsumer;
039import org.checkerframework.checker.nullness.qual.Nullable;
040
041/**
042 * A multiset which maintains the ordering of its elements, according to either their natural order
043 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
044 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
045 * equivalence of instances.
046 *
047 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
048 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
049 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
050 *
051 * <p>See the Guava User Guide article on <a href=
052 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset"> {@code
053 * Multiset}</a>.
054 *
055 * @author Louis Wasserman
056 * @author Jared Levy
057 * @since 2.0
058 */
059@GwtCompatible(emulated = true)
060public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
061
062  /**
063   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
064   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
065   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
066   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
067   * user attempts to add an element to the multiset that violates this constraint (for example, the
068   * user attempts to add a string element to a set whose elements are integers), the {@code
069   * add(Object)} call will throw a {@code ClassCastException}.
070   *
071   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
072   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
073   */
074  public static <E extends Comparable> TreeMultiset<E> create() {
075    return new TreeMultiset<E>(Ordering.natural());
076  }
077
078  /**
079   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
080   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
081   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
082   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
083   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
084   * ClassCastException}.
085   *
086   * @param comparator the comparator that will be used to sort this multiset. A null value
087   *     indicates that the elements' <i>natural ordering</i> should be used.
088   */
089  @SuppressWarnings("unchecked")
090  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
091    return (comparator == null)
092        ? new TreeMultiset<E>((Comparator) Ordering.natural())
093        : new TreeMultiset<E>(comparator);
094  }
095
096  /**
097   * Creates an empty multiset containing the given initial elements, sorted according to the
098   * elements' natural order.
099   *
100   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
101   *
102   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
103   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
104   */
105  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
106    TreeMultiset<E> multiset = create();
107    Iterables.addAll(multiset, elements);
108    return multiset;
109  }
110
111  private final transient Reference<AvlNode<E>> rootReference;
112  private final transient GeneralRange<E> range;
113  private final transient AvlNode<E> header;
114
115  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
116    super(range.comparator());
117    this.rootReference = rootReference;
118    this.range = range;
119    this.header = endLink;
120  }
121
122  TreeMultiset(Comparator<? super E> comparator) {
123    super(comparator);
124    this.range = GeneralRange.all(comparator);
125    this.header = new AvlNode<E>(null, 1);
126    successor(header, header);
127    this.rootReference = new Reference<>();
128  }
129
130  /** A function which can be summed across a subtree. */
131  private enum Aggregate {
132    SIZE {
133      @Override
134      int nodeAggregate(AvlNode<?> node) {
135        return node.elemCount;
136      }
137
138      @Override
139      long treeAggregate(@Nullable AvlNode<?> root) {
140        return (root == null) ? 0 : root.totalCount;
141      }
142    },
143    DISTINCT {
144      @Override
145      int nodeAggregate(AvlNode<?> node) {
146        return 1;
147      }
148
149      @Override
150      long treeAggregate(@Nullable AvlNode<?> root) {
151        return (root == null) ? 0 : root.distinctElements;
152      }
153    };
154
155    abstract int nodeAggregate(AvlNode<?> node);
156
157    abstract long treeAggregate(@Nullable AvlNode<?> root);
158  }
159
160  private long aggregateForEntries(Aggregate aggr) {
161    AvlNode<E> root = rootReference.get();
162    long total = aggr.treeAggregate(root);
163    if (range.hasLowerBound()) {
164      total -= aggregateBelowRange(aggr, root);
165    }
166    if (range.hasUpperBound()) {
167      total -= aggregateAboveRange(aggr, root);
168    }
169    return total;
170  }
171
172  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
173    if (node == null) {
174      return 0;
175    }
176    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
177    if (cmp < 0) {
178      return aggregateBelowRange(aggr, node.left);
179    } else if (cmp == 0) {
180      switch (range.getLowerBoundType()) {
181        case OPEN:
182          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
183        case CLOSED:
184          return aggr.treeAggregate(node.left);
185        default:
186          throw new AssertionError();
187      }
188    } else {
189      return aggr.treeAggregate(node.left)
190          + aggr.nodeAggregate(node)
191          + aggregateBelowRange(aggr, node.right);
192    }
193  }
194
195  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
196    if (node == null) {
197      return 0;
198    }
199    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
200    if (cmp > 0) {
201      return aggregateAboveRange(aggr, node.right);
202    } else if (cmp == 0) {
203      switch (range.getUpperBoundType()) {
204        case OPEN:
205          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
206        case CLOSED:
207          return aggr.treeAggregate(node.right);
208        default:
209          throw new AssertionError();
210      }
211    } else {
212      return aggr.treeAggregate(node.right)
213          + aggr.nodeAggregate(node)
214          + aggregateAboveRange(aggr, node.left);
215    }
216  }
217
218  @Override
219  public int size() {
220    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
221  }
222
223  @Override
224  int distinctElements() {
225    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
226  }
227
228  static int distinctElements(@Nullable AvlNode<?> node) {
229    return (node == null) ? 0 : node.distinctElements;
230  }
231
232  @Override
233  public int count(@Nullable Object element) {
234    try {
235      @SuppressWarnings("unchecked")
236      E e = (E) element;
237      AvlNode<E> root = rootReference.get();
238      if (!range.contains(e) || root == null) {
239        return 0;
240      }
241      return root.count(comparator(), e);
242    } catch (ClassCastException | NullPointerException e) {
243      return 0;
244    }
245  }
246
247  @CanIgnoreReturnValue
248  @Override
249  public int add(@Nullable E element, int occurrences) {
250    checkNonnegative(occurrences, "occurrences");
251    if (occurrences == 0) {
252      return count(element);
253    }
254    checkArgument(range.contains(element));
255    AvlNode<E> root = rootReference.get();
256    if (root == null) {
257      comparator().compare(element, element);
258      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
259      successor(header, newRoot, header);
260      rootReference.checkAndSet(root, newRoot);
261      return 0;
262    }
263    int[] result = new int[1]; // used as a mutable int reference to hold result
264    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
265    rootReference.checkAndSet(root, newRoot);
266    return result[0];
267  }
268
269  @CanIgnoreReturnValue
270  @Override
271  public int remove(@Nullable Object element, int occurrences) {
272    checkNonnegative(occurrences, "occurrences");
273    if (occurrences == 0) {
274      return count(element);
275    }
276    AvlNode<E> root = rootReference.get();
277    int[] result = new int[1]; // used as a mutable int reference to hold result
278    AvlNode<E> newRoot;
279    try {
280      @SuppressWarnings("unchecked")
281      E e = (E) element;
282      if (!range.contains(e) || root == null) {
283        return 0;
284      }
285      newRoot = root.remove(comparator(), e, occurrences, result);
286    } catch (ClassCastException | NullPointerException e) {
287      return 0;
288    }
289    rootReference.checkAndSet(root, newRoot);
290    return result[0];
291  }
292
293  @CanIgnoreReturnValue
294  @Override
295  public int setCount(@Nullable E element, int count) {
296    checkNonnegative(count, "count");
297    if (!range.contains(element)) {
298      checkArgument(count == 0);
299      return 0;
300    }
301
302    AvlNode<E> root = rootReference.get();
303    if (root == null) {
304      if (count > 0) {
305        add(element, count);
306      }
307      return 0;
308    }
309    int[] result = new int[1]; // used as a mutable int reference to hold result
310    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
311    rootReference.checkAndSet(root, newRoot);
312    return result[0];
313  }
314
315  @CanIgnoreReturnValue
316  @Override
317  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
318    checkNonnegative(newCount, "newCount");
319    checkNonnegative(oldCount, "oldCount");
320    checkArgument(range.contains(element));
321
322    AvlNode<E> root = rootReference.get();
323    if (root == null) {
324      if (oldCount == 0) {
325        if (newCount > 0) {
326          add(element, newCount);
327        }
328        return true;
329      } else {
330        return false;
331      }
332    }
333    int[] result = new int[1]; // used as a mutable int reference to hold result
334    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
335    rootReference.checkAndSet(root, newRoot);
336    return result[0] == oldCount;
337  }
338
339  @Override
340  public void clear() {
341    if (!range.hasLowerBound() && !range.hasUpperBound()) {
342      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
343      for (AvlNode<E> current = header.succ; current != header; ) {
344        AvlNode<E> next = current.succ;
345
346        current.elemCount = 0;
347        // Also clear these fields so that one deleted Entry doesn't retain all elements.
348        current.left = null;
349        current.right = null;
350        current.pred = null;
351        current.succ = null;
352
353        current = next;
354      }
355      successor(header, header);
356      rootReference.clear();
357    } else {
358      // TODO(cpovirk): Perhaps we can optimize in this case, too?
359      Iterators.clear(entryIterator());
360    }
361  }
362
363  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
364    return new Multisets.AbstractEntry<E>() {
365      @Override
366      public E getElement() {
367        return baseEntry.getElement();
368      }
369
370      @Override
371      public int getCount() {
372        int result = baseEntry.getCount();
373        if (result == 0) {
374          return count(getElement());
375        } else {
376          return result;
377        }
378      }
379    };
380  }
381
382  /** Returns the first node in the tree that is in range. */
383  private @Nullable AvlNode<E> firstNode() {
384    AvlNode<E> root = rootReference.get();
385    if (root == null) {
386      return null;
387    }
388    AvlNode<E> node;
389    if (range.hasLowerBound()) {
390      E endpoint = range.getLowerEndpoint();
391      node = rootReference.get().ceiling(comparator(), endpoint);
392      if (node == null) {
393        return null;
394      }
395      if (range.getLowerBoundType() == BoundType.OPEN
396          && comparator().compare(endpoint, node.getElement()) == 0) {
397        node = node.succ;
398      }
399    } else {
400      node = header.succ;
401    }
402    return (node == header || !range.contains(node.getElement())) ? null : node;
403  }
404
405  private @Nullable AvlNode<E> lastNode() {
406    AvlNode<E> root = rootReference.get();
407    if (root == null) {
408      return null;
409    }
410    AvlNode<E> node;
411    if (range.hasUpperBound()) {
412      E endpoint = range.getUpperEndpoint();
413      node = rootReference.get().floor(comparator(), endpoint);
414      if (node == null) {
415        return null;
416      }
417      if (range.getUpperBoundType() == BoundType.OPEN
418          && comparator().compare(endpoint, node.getElement()) == 0) {
419        node = node.pred;
420      }
421    } else {
422      node = header.pred;
423    }
424    return (node == header || !range.contains(node.getElement())) ? null : node;
425  }
426
427  @Override
428  Iterator<E> elementIterator() {
429    return Multisets.elementIterator(entryIterator());
430  }
431
432  @Override
433  Iterator<Entry<E>> entryIterator() {
434    return new Iterator<Entry<E>>() {
435      AvlNode<E> current = firstNode();
436      @Nullable Entry<E> prevEntry;
437
438      @Override
439      public boolean hasNext() {
440        if (current == null) {
441          return false;
442        } else if (range.tooHigh(current.getElement())) {
443          current = null;
444          return false;
445        } else {
446          return true;
447        }
448      }
449
450      @Override
451      public Entry<E> next() {
452        if (!hasNext()) {
453          throw new NoSuchElementException();
454        }
455        Entry<E> result = wrapEntry(current);
456        prevEntry = result;
457        if (current.succ == header) {
458          current = null;
459        } else {
460          current = current.succ;
461        }
462        return result;
463      }
464
465      @Override
466      public void remove() {
467        checkRemove(prevEntry != null);
468        setCount(prevEntry.getElement(), 0);
469        prevEntry = null;
470      }
471    };
472  }
473
474  @Override
475  Iterator<Entry<E>> descendingEntryIterator() {
476    return new Iterator<Entry<E>>() {
477      AvlNode<E> current = lastNode();
478      Entry<E> prevEntry = null;
479
480      @Override
481      public boolean hasNext() {
482        if (current == null) {
483          return false;
484        } else if (range.tooLow(current.getElement())) {
485          current = null;
486          return false;
487        } else {
488          return true;
489        }
490      }
491
492      @Override
493      public Entry<E> next() {
494        if (!hasNext()) {
495          throw new NoSuchElementException();
496        }
497        Entry<E> result = wrapEntry(current);
498        prevEntry = result;
499        if (current.pred == header) {
500          current = null;
501        } else {
502          current = current.pred;
503        }
504        return result;
505      }
506
507      @Override
508      public void remove() {
509        checkRemove(prevEntry != null);
510        setCount(prevEntry.getElement(), 0);
511        prevEntry = null;
512      }
513    };
514  }
515
516  @Override
517  public void forEachEntry(ObjIntConsumer<? super E> action) {
518    checkNotNull(action);
519    for (AvlNode<E> node = firstNode();
520        node != header && node != null && !range.tooHigh(node.getElement());
521        node = node.succ) {
522      action.accept(node.getElement(), node.getCount());
523    }
524  }
525
  /** Returns an iterator over every occurrence, built by {@code Multisets.iteratorImpl}. */
  @Override
  public Iterator<E> iterator() {
    return Multisets.iteratorImpl(this);
  }
530
531  @Override
532  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
533    return new TreeMultiset<E>(
534        rootReference,
535        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
536        header);
537  }
538
539  @Override
540  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
541    return new TreeMultiset<E>(
542        rootReference,
543        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
544        header);
545  }
546
547  private static final class Reference<T> {
548    private @Nullable T value;
549
550    public @Nullable T get() {
551      return value;
552    }
553
554    public void checkAndSet(@Nullable T expected, T newValue) {
555      if (value != expected) {
556        throw new ConcurrentModificationException();
557      }
558      value = newValue;
559    }
560
561    void clear() {
562      value = null;
563    }
564  }
565
566  private static final class AvlNode<E> {
567    private final @Nullable E elem;
568
569    // elemCount is 0 iff this node has been deleted.
570    private int elemCount;
571
572    private int distinctElements;
573    private long totalCount;
574    private int height;
575    private @Nullable AvlNode<E> left;
576    private @Nullable AvlNode<E> right;
577    private @Nullable AvlNode<E> pred;
578    private @Nullable AvlNode<E> succ;
579
580    AvlNode(@Nullable E elem, int elemCount) {
581      checkArgument(elemCount > 0);
582      this.elem = elem;
583      this.elemCount = elemCount;
584      this.totalCount = elemCount;
585      this.distinctElements = 1;
586      this.height = 1;
587      this.left = null;
588      this.right = null;
589    }
590
591    public int count(Comparator<? super E> comparator, E e) {
592      int cmp = comparator.compare(e, elem);
593      if (cmp < 0) {
594        return (left == null) ? 0 : left.count(comparator, e);
595      } else if (cmp > 0) {
596        return (right == null) ? 0 : right.count(comparator, e);
597      } else {
598        return elemCount;
599      }
600    }
601
602    private AvlNode<E> addRightChild(E e, int count) {
603      right = new AvlNode<E>(e, count);
604      successor(this, right, succ);
605      height = Math.max(2, height);
606      distinctElements++;
607      totalCount += count;
608      return this;
609    }
610
611    private AvlNode<E> addLeftChild(E e, int count) {
612      left = new AvlNode<E>(e, count);
613      successor(pred, left, this);
614      height = Math.max(2, height);
615      distinctElements++;
616      totalCount += count;
617      return this;
618    }
619
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node, inserting a
     * new leaf if {@code e} is absent.
     *
     * @param comparator the ordering used to locate {@code e}
     * @param e the element to add
     * @param count the number of occurrences to add (callers in this file pass a positive value)
     * @param result single-slot out-parameter; {@code result[0]} receives the count {@code e} had
     *     before this call (0 if it was absent)
     * @return the root of this subtree afterwards, which may differ from {@code this} after a
     *     rebalance
     * @throws IllegalArgumentException if the existing node's count would exceed {@code
     *     Integer.MAX_VALUE}
     */
    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, elem);
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent: it becomes this node's new left child.
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // A previous count of 0 means a new distinct node was inserted somewhere below.
        if (result[0] == 0) {
          distinctElements++;
        }
        this.totalCount += count;
        // Rebalancing is only needed if the left subtree grew taller.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent: it becomes this node's new right child.
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        if (result[0] == 0) {
          distinctElements++;
        }
        this.totalCount += count;
        // Rebalancing is only needed if the right subtree grew taller.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      long resultCount = (long) elemCount + count;
      // Validate before mutating so an overflow leaves the node unchanged.
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
664
665    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
666      int cmp = comparator.compare(e, elem);
667      if (cmp < 0) {
668        AvlNode<E> initLeft = left;
669        if (initLeft == null) {
670          result[0] = 0;
671          return this;
672        }
673
674        left = initLeft.remove(comparator, e, count, result);
675
676        if (result[0] > 0) {
677          if (count >= result[0]) {
678            this.distinctElements--;
679            this.totalCount -= result[0];
680          } else {
681            this.totalCount -= count;
682          }
683        }
684        return (result[0] == 0) ? this : rebalance();
685      } else if (cmp > 0) {
686        AvlNode<E> initRight = right;
687        if (initRight == null) {
688          result[0] = 0;
689          return this;
690        }
691
692        right = initRight.remove(comparator, e, count, result);
693
694        if (result[0] > 0) {
695          if (count >= result[0]) {
696            this.distinctElements--;
697            this.totalCount -= result[0];
698          } else {
699            this.totalCount -= count;
700          }
701        }
702        return rebalance();
703      }
704
705      // removing count from me!
706      result[0] = elemCount;
707      if (count >= elemCount) {
708        return deleteMe();
709      } else {
710        this.elemCount -= count;
711        this.totalCount -= count;
712        return this;
713      }
714    }
715
    /**
     * Sets the count of {@code e} in the subtree rooted at this node to {@code count}, inserting a
     * leaf or deleting the node as needed.
     *
     * @param comparator the ordering used to locate {@code e}
     * @param e the element whose count is set
     * @param count the new count; 0 deletes the element
     * @param result single-slot out-parameter; {@code result[0]} receives the previous count (0 if
     *     {@code e} was absent)
     * @return the root of this subtree afterwards, possibly {@code null} if the subtree became
     *     empty
     */
    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
      int cmp = comparator.compare(e, elem);
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent; only a positive count creates it.
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // Adjust the distinct-node tally when e appeared or disappeared below.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent; only a positive count creates it.
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // Adjust the distinct-node tally when e appeared or disappeared below.
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
763
    /**
     * Conditionally sets the count of {@code e} in the subtree rooted at this node to {@code
     * newCount}. The change happens only if the current count equals {@code expectedCount}.
     *
     * @param comparator the ordering used to locate {@code e}
     * @param e the element whose count is set
     * @param expectedCount the count {@code e} must currently have for the update to apply
     * @param newCount the new count; 0 deletes the element
     * @param result single-slot out-parameter; {@code result[0]} receives the count {@code e} had
     *     (0 if absent), so the caller can tell whether the update applied
     * @return the root of this subtree afterwards, possibly {@code null} if the subtree became
     *     empty
     */
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @Nullable E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, elem);
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent (count 0); insert it only if 0 was expected and the new count is positive.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the update below actually applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent (count 0); insert it only if 0 was expected and the new count is positive.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Aggregates change only if the update below actually applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
826
    /**
     * Deletes this node from its subtree and returns the subtree's new root, possibly {@code null}.
     * The node is marked deleted (elemCount set to 0) and unlinked from the ordered chain. If it
     * has two children, its in-order neighbour from the taller side (pred from the left subtree,
     * succ from the right) is detached and takes this node's place.
     */
    private AvlNode<E> deleteMe() {
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Bypass this node in the pred/succ chain.
      successor(pred, succ);
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        AvlNode<E> newTop = pred;
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        // newTop now covers everything this subtree held except this node.
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // newTop is the minimum node in my right subtree
        AvlNode<E> newTop = succ;
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
852
    // Removes the minimum node from this subtree to be reused elsewhere.
    // `node` is that minimum node (deleteMe passes its succ); it is read only for its elemCount.
    // Each node on the path loses one distinct element and node.elemCount occurrences.
    // Returns the new root of this subtree.
    private AvlNode<E> removeMin(AvlNode<E> node) {
      if (left == null) {
        // This node is the minimum: its right subtree takes its place.
        return right;
      } else {
        left = left.removeMin(node);
        distinctElements--;
        totalCount -= node.elemCount;
        return rebalance();
      }
    }
864
    // Removes the maximum node from this subtree to be reused elsewhere.
    // `node` is that maximum node (deleteMe passes its pred); it is read only for its elemCount.
    // Each node on the path loses one distinct element and node.elemCount occurrences.
    // Returns the new root of this subtree.
    private AvlNode<E> removeMax(AvlNode<E> node) {
      if (right == null) {
        // This node is the maximum: its left subtree takes its place.
        return left;
      } else {
        right = right.removeMax(node);
        distinctElements--;
        totalCount -= node.elemCount;
        return rebalance();
      }
    }
876
877    private void recomputeMultiset() {
878      this.distinctElements =
879          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
880      this.totalCount = elemCount + totalCount(left) + totalCount(right);
881    }
882
883    private void recomputeHeight() {
884      this.height = 1 + Math.max(height(left), height(right));
885    }
886
887    private void recompute() {
888      recomputeMultiset();
889      recomputeHeight();
890    }
891
892    private AvlNode<E> rebalance() {
893      switch (balanceFactor()) {
894        case -2:
895          if (right.balanceFactor() > 0) {
896            right = right.rotateRight();
897          }
898          return rotateLeft();
899        case 2:
900          if (left.balanceFactor() < 0) {
901            left = left.rotateLeft();
902          }
903          return rotateRight();
904        default:
905          recomputeHeight();
906          return this;
907      }
908    }
909
910    private int balanceFactor() {
911      return height(left) - height(right);
912    }
913
    /**
     * Single left rotation: the right child becomes this subtree's root and this node becomes its
     * left child. Returns the new root.
     */
    private AvlNode<E> rotateLeft() {
      checkState(right != null);
      AvlNode<E> newTop = right;
      this.right = newTop.left;
      newTop.left = this;
      // The subtree's contents are unchanged, so the new root inherits this node's aggregates...
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // ...while this node, now lower down, recomputes its own, followed by the new root's height.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
925
    /**
     * Single right rotation: the left child becomes this subtree's root and this node becomes its
     * right child. Returns the new root.
     */
    private AvlNode<E> rotateRight() {
      checkState(left != null);
      AvlNode<E> newTop = left;
      this.left = newTop.right;
      newTop.right = this;
      // The subtree's contents are unchanged, so the new root inherits this node's aggregates...
      newTop.totalCount = this.totalCount;
      newTop.distinctElements = this.distinctElements;
      // ...while this node, now lower down, recomputes its own, followed by the new root's height.
      this.recompute();
      newTop.recomputeHeight();
      return newTop;
    }
937
938    private static long totalCount(@Nullable AvlNode<?> node) {
939      return (node == null) ? 0 : node.totalCount;
940    }
941
942    private static int height(@Nullable AvlNode<?> node) {
943      return (node == null) ? 0 : node.height;
944    }
945
946    private @Nullable AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
947      int cmp = comparator.compare(e, elem);
948      if (cmp < 0) {
949        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
950      } else if (cmp == 0) {
951        return this;
952      } else {
953        return (right == null) ? null : right.ceiling(comparator, e);
954      }
955    }
956
957    private @Nullable AvlNode<E> floor(Comparator<? super E> comparator, E e) {
958      int cmp = comparator.compare(e, elem);
959      if (cmp > 0) {
960        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
961      } else if (cmp == 0) {
962        return this;
963      } else {
964        return (left == null) ? null : left.floor(comparator, e);
965      }
966    }
967
968    E getElement() {
969      return elem;
970    }
971
972    int getCount() {
973      return elemCount;
974    }
975
976    @Override
977    public String toString() {
978      return Multisets.immutableEntry(getElement(), getCount()).toString();
979    }
980  }
981
982  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
983    a.succ = b;
984    b.pred = a;
985  }
986
987  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
988    successor(a, b);
989    successor(b, c);
990  }
991
992  /*
993   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
994   * calls the comparator to compare the two keys. If that change is made,
995   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
996   */
997
998  /**
999   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1000   *     second element, its count, and so on
1001   */
1002  @GwtIncompatible // java.io.ObjectOutputStream
1003  private void writeObject(ObjectOutputStream stream) throws IOException {
1004    stream.defaultWriteObject();
1005    stream.writeObject(elementSet().comparator());
1006    Serialization.writeMultiset(this, stream);
1007  }
1008
  /**
   * Restores a multiset written by {@code writeObject}. It reads the comparator, rebuilds the
   * final transient fields as an empty, unrestricted tree, and then re-adds the serialized
   * entries. The fields are final, so they are assigned reflectively through {@code
   * Serialization.getFieldSetter}.
   *
   * <p>NOTE(review): as with any Java deserialization, this should only be fed trusted streams.
   */
  @GwtIncompatible // java.io.ObjectInputStream
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    @SuppressWarnings("unchecked")
    // reading data stored by writeObject
    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    // A deserialized multiset is always the full, unrestricted range.
    Serialization.getFieldSetter(TreeMultiset.class, "range")
        .set(this, GeneralRange.all(comparator));
    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
        .set(this, new Reference<AvlNode<E>>());
    // Recreate the header sentinel as an empty, self-linked chain, as the constructor does.
    AvlNode<E> header = new AvlNode<E>(null, 1);
    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    successor(header, header);
    Serialization.populateMultiset(this, stream);
  }
1025
  // Serialized-form version; the form itself is defined by writeObject/readObject above.
  @GwtIncompatible // not needed in emulated source
  private static final long serialVersionUID = 1;
1028}