/*
 * Copyright (C) 2007 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Preconditions.checkState;
022import static com.google.common.collect.CollectPreconditions.checkNonnegative;
023import static com.google.common.collect.CollectPreconditions.checkRemove;
024
025import com.google.common.annotations.GwtCompatible;
026import com.google.common.annotations.GwtIncompatible;
027import com.google.common.base.MoreObjects;
028import com.google.common.primitives.Ints;
029import com.google.errorprone.annotations.CanIgnoreReturnValue;
030import java.io.IOException;
031import java.io.ObjectInputStream;
032import java.io.ObjectOutputStream;
033import java.io.Serializable;
034import java.util.Comparator;
035import java.util.ConcurrentModificationException;
036import java.util.Iterator;
037import java.util.NoSuchElementException;
038import java.util.function.ObjIntConsumer;
039import org.checkerframework.checker.nullness.compatqual.NullableDecl;
040
041/**
042 * A multiset which maintains the ordering of its elements, according to either their natural order
043 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
044 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
045 * equivalence of instances.
046 *
047 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
048 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
049 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
050 *
051 * <p>See the Guava User Guide article on <a href=
052 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset"> {@code
053 * Multiset}</a>.
054 *
055 * @author Louis Wasserman
056 * @author Jared Levy
057 * @since 2.0
058 */
059@GwtCompatible(emulated = true)
060public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
061
062  /**
063   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
064   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
065   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
066   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
067   * user attempts to add an element to the multiset that violates this constraint (for example, the
068   * user attempts to add a string element to a set whose elements are integers), the {@code
069   * add(Object)} call will throw a {@code ClassCastException}.
070   *
071   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
072   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
073   */
074  public static <E extends Comparable> TreeMultiset<E> create() {
075    return new TreeMultiset<E>(Ordering.natural());
076  }
077
078  /**
079   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
080   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
081   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
082   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
083   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
084   * ClassCastException}.
085   *
086   * @param comparator the comparator that will be used to sort this multiset. A null value
087   *     indicates that the elements' <i>natural ordering</i> should be used.
088   */
089  @SuppressWarnings("unchecked")
090  public static <E> TreeMultiset<E> create(@NullableDecl Comparator<? super E> comparator) {
091    return (comparator == null)
092        ? new TreeMultiset<E>((Comparator) Ordering.natural())
093        : new TreeMultiset<E>(comparator);
094  }
095
096  /**
097   * Creates an empty multiset containing the given initial elements, sorted according to the
098   * elements' natural order.
099   *
100   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
101   *
102   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
103   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
104   */
105  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
106    TreeMultiset<E> multiset = create();
107    Iterables.addAll(multiset, elements);
108    return multiset;
109  }
110
111  private final transient Reference<AvlNode<E>> rootReference;
112  private final transient GeneralRange<E> range;
113  private final transient AvlNode<E> header;
114
115  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
116    super(range.comparator());
117    this.rootReference = rootReference;
118    this.range = range;
119    this.header = endLink;
120  }
121
122  TreeMultiset(Comparator<? super E> comparator) {
123    super(comparator);
124    this.range = GeneralRange.all(comparator);
125    this.header = new AvlNode<E>(null, 1);
126    successor(header, header);
127    this.rootReference = new Reference<>();
128  }
129
130  /** A function which can be summed across a subtree. */
131  private enum Aggregate {
132    SIZE {
133      @Override
134      int nodeAggregate(AvlNode<?> node) {
135        return node.elemCount;
136      }
137
138      @Override
139      long treeAggregate(@NullableDecl AvlNode<?> root) {
140        return (root == null) ? 0 : root.totalCount;
141      }
142    },
143    DISTINCT {
144      @Override
145      int nodeAggregate(AvlNode<?> node) {
146        return 1;
147      }
148
149      @Override
150      long treeAggregate(@NullableDecl AvlNode<?> root) {
151        return (root == null) ? 0 : root.distinctElements;
152      }
153    };
154
155    abstract int nodeAggregate(AvlNode<?> node);
156
157    abstract long treeAggregate(@NullableDecl AvlNode<?> root);
158  }
159
160  private long aggregateForEntries(Aggregate aggr) {
161    AvlNode<E> root = rootReference.get();
162    long total = aggr.treeAggregate(root);
163    if (range.hasLowerBound()) {
164      total -= aggregateBelowRange(aggr, root);
165    }
166    if (range.hasUpperBound()) {
167      total -= aggregateAboveRange(aggr, root);
168    }
169    return total;
170  }
171
172  private long aggregateBelowRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
173    if (node == null) {
174      return 0;
175    }
176    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
177    if (cmp < 0) {
178      return aggregateBelowRange(aggr, node.left);
179    } else if (cmp == 0) {
180      switch (range.getLowerBoundType()) {
181        case OPEN:
182          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
183        case CLOSED:
184          return aggr.treeAggregate(node.left);
185        default:
186          throw new AssertionError();
187      }
188    } else {
189      return aggr.treeAggregate(node.left)
190          + aggr.nodeAggregate(node)
191          + aggregateBelowRange(aggr, node.right);
192    }
193  }
194
195  private long aggregateAboveRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
196    if (node == null) {
197      return 0;
198    }
199    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
200    if (cmp > 0) {
201      return aggregateAboveRange(aggr, node.right);
202    } else if (cmp == 0) {
203      switch (range.getUpperBoundType()) {
204        case OPEN:
205          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
206        case CLOSED:
207          return aggr.treeAggregate(node.right);
208        default:
209          throw new AssertionError();
210      }
211    } else {
212      return aggr.treeAggregate(node.right)
213          + aggr.nodeAggregate(node)
214          + aggregateAboveRange(aggr, node.left);
215    }
216  }
217
218  @Override
219  public int size() {
220    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
221  }
222
223  @Override
224  int distinctElements() {
225    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
226  }
227
228  @Override
229  public int count(@NullableDecl Object element) {
230    try {
231      @SuppressWarnings("unchecked")
232      E e = (E) element;
233      AvlNode<E> root = rootReference.get();
234      if (!range.contains(e) || root == null) {
235        return 0;
236      }
237      return root.count(comparator(), e);
238    } catch (ClassCastException | NullPointerException e) {
239      return 0;
240    }
241  }
242
243  @CanIgnoreReturnValue
244  @Override
245  public int add(@NullableDecl E element, int occurrences) {
246    checkNonnegative(occurrences, "occurrences");
247    if (occurrences == 0) {
248      return count(element);
249    }
250    checkArgument(range.contains(element));
251    AvlNode<E> root = rootReference.get();
252    if (root == null) {
253      comparator().compare(element, element);
254      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
255      successor(header, newRoot, header);
256      rootReference.checkAndSet(root, newRoot);
257      return 0;
258    }
259    int[] result = new int[1]; // used as a mutable int reference to hold result
260    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
261    rootReference.checkAndSet(root, newRoot);
262    return result[0];
263  }
264
265  @CanIgnoreReturnValue
266  @Override
267  public int remove(@NullableDecl Object element, int occurrences) {
268    checkNonnegative(occurrences, "occurrences");
269    if (occurrences == 0) {
270      return count(element);
271    }
272    AvlNode<E> root = rootReference.get();
273    int[] result = new int[1]; // used as a mutable int reference to hold result
274    AvlNode<E> newRoot;
275    try {
276      @SuppressWarnings("unchecked")
277      E e = (E) element;
278      if (!range.contains(e) || root == null) {
279        return 0;
280      }
281      newRoot = root.remove(comparator(), e, occurrences, result);
282    } catch (ClassCastException | NullPointerException e) {
283      return 0;
284    }
285    rootReference.checkAndSet(root, newRoot);
286    return result[0];
287  }
288
289  @CanIgnoreReturnValue
290  @Override
291  public int setCount(@NullableDecl E element, int count) {
292    checkNonnegative(count, "count");
293    if (!range.contains(element)) {
294      checkArgument(count == 0);
295      return 0;
296    }
297
298    AvlNode<E> root = rootReference.get();
299    if (root == null) {
300      if (count > 0) {
301        add(element, count);
302      }
303      return 0;
304    }
305    int[] result = new int[1]; // used as a mutable int reference to hold result
306    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
307    rootReference.checkAndSet(root, newRoot);
308    return result[0];
309  }
310
311  @CanIgnoreReturnValue
312  @Override
313  public boolean setCount(@NullableDecl E element, int oldCount, int newCount) {
314    checkNonnegative(newCount, "newCount");
315    checkNonnegative(oldCount, "oldCount");
316    checkArgument(range.contains(element));
317
318    AvlNode<E> root = rootReference.get();
319    if (root == null) {
320      if (oldCount == 0) {
321        if (newCount > 0) {
322          add(element, newCount);
323        }
324        return true;
325      } else {
326        return false;
327      }
328    }
329    int[] result = new int[1]; // used as a mutable int reference to hold result
330    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
331    rootReference.checkAndSet(root, newRoot);
332    return result[0] == oldCount;
333  }
334
335  @Override
336  public void clear() {
337    if (!range.hasLowerBound() && !range.hasUpperBound()) {
338      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
339      for (AvlNode<E> current = header.succ; current != header; ) {
340        AvlNode<E> next = current.succ;
341
342        current.elemCount = 0;
343        // Also clear these fields so that one deleted Entry doesn't retain all elements.
344        current.left = null;
345        current.right = null;
346        current.pred = null;
347        current.succ = null;
348
349        current = next;
350      }
351      successor(header, header);
352      rootReference.clear();
353    } else {
354      // TODO(cpovirk): Perhaps we can optimize in this case, too?
355      Iterators.clear(entryIterator());
356    }
357  }
358
359  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
360    return new Multisets.AbstractEntry<E>() {
361      @Override
362      public E getElement() {
363        return baseEntry.getElement();
364      }
365
366      @Override
367      public int getCount() {
368        int result = baseEntry.getCount();
369        if (result == 0) {
370          return count(getElement());
371        } else {
372          return result;
373        }
374      }
375    };
376  }
377
378  /** Returns the first node in the tree that is in range. */
379  @NullableDecl
380  private AvlNode<E> firstNode() {
381    AvlNode<E> root = rootReference.get();
382    if (root == null) {
383      return null;
384    }
385    AvlNode<E> node;
386    if (range.hasLowerBound()) {
387      E endpoint = range.getLowerEndpoint();
388      node = rootReference.get().ceiling(comparator(), endpoint);
389      if (node == null) {
390        return null;
391      }
392      if (range.getLowerBoundType() == BoundType.OPEN
393          && comparator().compare(endpoint, node.getElement()) == 0) {
394        node = node.succ;
395      }
396    } else {
397      node = header.succ;
398    }
399    return (node == header || !range.contains(node.getElement())) ? null : node;
400  }
401
402  @NullableDecl
403  private AvlNode<E> lastNode() {
404    AvlNode<E> root = rootReference.get();
405    if (root == null) {
406      return null;
407    }
408    AvlNode<E> node;
409    if (range.hasUpperBound()) {
410      E endpoint = range.getUpperEndpoint();
411      node = rootReference.get().floor(comparator(), endpoint);
412      if (node == null) {
413        return null;
414      }
415      if (range.getUpperBoundType() == BoundType.OPEN
416          && comparator().compare(endpoint, node.getElement()) == 0) {
417        node = node.pred;
418      }
419    } else {
420      node = header.pred;
421    }
422    return (node == header || !range.contains(node.getElement())) ? null : node;
423  }
424
425  @Override
426  Iterator<E> elementIterator() {
427    return Multisets.elementIterator(entryIterator());
428  }
429
430  @Override
431  Iterator<Entry<E>> entryIterator() {
432    return new Iterator<Entry<E>>() {
433      AvlNode<E> current = firstNode();
434      @NullableDecl Entry<E> prevEntry;
435
436      @Override
437      public boolean hasNext() {
438        if (current == null) {
439          return false;
440        } else if (range.tooHigh(current.getElement())) {
441          current = null;
442          return false;
443        } else {
444          return true;
445        }
446      }
447
448      @Override
449      public Entry<E> next() {
450        if (!hasNext()) {
451          throw new NoSuchElementException();
452        }
453        Entry<E> result = wrapEntry(current);
454        prevEntry = result;
455        if (current.succ == header) {
456          current = null;
457        } else {
458          current = current.succ;
459        }
460        return result;
461      }
462
463      @Override
464      public void remove() {
465        checkRemove(prevEntry != null);
466        setCount(prevEntry.getElement(), 0);
467        prevEntry = null;
468      }
469    };
470  }
471
472  @Override
473  Iterator<Entry<E>> descendingEntryIterator() {
474    return new Iterator<Entry<E>>() {
475      AvlNode<E> current = lastNode();
476      Entry<E> prevEntry = null;
477
478      @Override
479      public boolean hasNext() {
480        if (current == null) {
481          return false;
482        } else if (range.tooLow(current.getElement())) {
483          current = null;
484          return false;
485        } else {
486          return true;
487        }
488      }
489
490      @Override
491      public Entry<E> next() {
492        if (!hasNext()) {
493          throw new NoSuchElementException();
494        }
495        Entry<E> result = wrapEntry(current);
496        prevEntry = result;
497        if (current.pred == header) {
498          current = null;
499        } else {
500          current = current.pred;
501        }
502        return result;
503      }
504
505      @Override
506      public void remove() {
507        checkRemove(prevEntry != null);
508        setCount(prevEntry.getElement(), 0);
509        prevEntry = null;
510      }
511    };
512  }
513
514  @Override
515  public void forEachEntry(ObjIntConsumer<? super E> action) {
516    checkNotNull(action);
517    for (AvlNode<E> node = firstNode();
518        node != header && node != null && !range.tooHigh(node.getElement());
519        node = node.succ) {
520      action.accept(node.getElement(), node.getCount());
521    }
522  }
523
524  @Override
525  public Iterator<E> iterator() {
526    return Multisets.iteratorImpl(this);
527  }
528
529  @Override
530  public SortedMultiset<E> headMultiset(@NullableDecl E upperBound, BoundType boundType) {
531    return new TreeMultiset<E>(
532        rootReference,
533        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
534        header);
535  }
536
537  @Override
538  public SortedMultiset<E> tailMultiset(@NullableDecl E lowerBound, BoundType boundType) {
539    return new TreeMultiset<E>(
540        rootReference,
541        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
542        header);
543  }
544
545  static int distinctElements(@NullableDecl AvlNode<?> node) {
546    return (node == null) ? 0 : node.distinctElements;
547  }
548
549  private static final class Reference<T> {
550    @NullableDecl private T value;
551
552    @NullableDecl
553    public T get() {
554      return value;
555    }
556
557    public void checkAndSet(@NullableDecl T expected, T newValue) {
558      if (value != expected) {
559        throw new ConcurrentModificationException();
560      }
561      value = newValue;
562    }
563
564    void clear() {
565      value = null;
566    }
567  }
568
569  private static final class AvlNode<E> {
570    @NullableDecl private final E elem;
571
572    // elemCount is 0 iff this node has been deleted.
573    private int elemCount;
574
575    private int distinctElements;
576    private long totalCount;
577    private int height;
578    @NullableDecl private AvlNode<E> left;
579    @NullableDecl private AvlNode<E> right;
580    @NullableDecl private AvlNode<E> pred;
581    @NullableDecl private AvlNode<E> succ;
582
583    AvlNode(@NullableDecl E elem, int elemCount) {
584      checkArgument(elemCount > 0);
585      this.elem = elem;
586      this.elemCount = elemCount;
587      this.totalCount = elemCount;
588      this.distinctElements = 1;
589      this.height = 1;
590      this.left = null;
591      this.right = null;
592    }
593
594    public int count(Comparator<? super E> comparator, E e) {
595      int cmp = comparator.compare(e, elem);
596      if (cmp < 0) {
597        return (left == null) ? 0 : left.count(comparator, e);
598      } else if (cmp > 0) {
599        return (right == null) ? 0 : right.count(comparator, e);
600      } else {
601        return elemCount;
602      }
603    }
604
605    private AvlNode<E> addRightChild(E e, int count) {
606      right = new AvlNode<E>(e, count);
607      successor(this, right, succ);
608      height = Math.max(2, height);
609      distinctElements++;
610      totalCount += count;
611      return this;
612    }
613
614    private AvlNode<E> addLeftChild(E e, int count) {
615      left = new AvlNode<E>(e, count);
616      successor(pred, left, this);
617      height = Math.max(2, height);
618      distinctElements++;
619      totalCount += count;
620      return this;
621    }
622
623    AvlNode<E> add(Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
624      /*
625       * It speeds things up considerably to unconditionally add count to totalCount here,
626       * but that destroys failure atomicity in the case of count overflow. =(
627       */
628      int cmp = comparator.compare(e, elem);
629      if (cmp < 0) {
630        AvlNode<E> initLeft = left;
631        if (initLeft == null) {
632          result[0] = 0;
633          return addLeftChild(e, count);
634        }
635        int initHeight = initLeft.height;
636
637        left = initLeft.add(comparator, e, count, result);
638        if (result[0] == 0) {
639          distinctElements++;
640        }
641        this.totalCount += count;
642        return (left.height == initHeight) ? this : rebalance();
643      } else if (cmp > 0) {
644        AvlNode<E> initRight = right;
645        if (initRight == null) {
646          result[0] = 0;
647          return addRightChild(e, count);
648        }
649        int initHeight = initRight.height;
650
651        right = initRight.add(comparator, e, count, result);
652        if (result[0] == 0) {
653          distinctElements++;
654        }
655        this.totalCount += count;
656        return (right.height == initHeight) ? this : rebalance();
657      }
658
659      // adding count to me!  No rebalance possible.
660      result[0] = elemCount;
661      long resultCount = (long) elemCount + count;
662      checkArgument(resultCount <= Integer.MAX_VALUE);
663      this.elemCount += count;
664      this.totalCount += count;
665      return this;
666    }
667
668    AvlNode<E> remove(
669        Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
670      int cmp = comparator.compare(e, elem);
671      if (cmp < 0) {
672        AvlNode<E> initLeft = left;
673        if (initLeft == null) {
674          result[0] = 0;
675          return this;
676        }
677
678        left = initLeft.remove(comparator, e, count, result);
679
680        if (result[0] > 0) {
681          if (count >= result[0]) {
682            this.distinctElements--;
683            this.totalCount -= result[0];
684          } else {
685            this.totalCount -= count;
686          }
687        }
688        return (result[0] == 0) ? this : rebalance();
689      } else if (cmp > 0) {
690        AvlNode<E> initRight = right;
691        if (initRight == null) {
692          result[0] = 0;
693          return this;
694        }
695
696        right = initRight.remove(comparator, e, count, result);
697
698        if (result[0] > 0) {
699          if (count >= result[0]) {
700            this.distinctElements--;
701            this.totalCount -= result[0];
702          } else {
703            this.totalCount -= count;
704          }
705        }
706        return rebalance();
707      }
708
709      // removing count from me!
710      result[0] = elemCount;
711      if (count >= elemCount) {
712        return deleteMe();
713      } else {
714        this.elemCount -= count;
715        this.totalCount -= count;
716        return this;
717      }
718    }
719
720    AvlNode<E> setCount(
721        Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
722      int cmp = comparator.compare(e, elem);
723      if (cmp < 0) {
724        AvlNode<E> initLeft = left;
725        if (initLeft == null) {
726          result[0] = 0;
727          return (count > 0) ? addLeftChild(e, count) : this;
728        }
729
730        left = initLeft.setCount(comparator, e, count, result);
731
732        if (count == 0 && result[0] != 0) {
733          this.distinctElements--;
734        } else if (count > 0 && result[0] == 0) {
735          this.distinctElements++;
736        }
737
738        this.totalCount += count - result[0];
739        return rebalance();
740      } else if (cmp > 0) {
741        AvlNode<E> initRight = right;
742        if (initRight == null) {
743          result[0] = 0;
744          return (count > 0) ? addRightChild(e, count) : this;
745        }
746
747        right = initRight.setCount(comparator, e, count, result);
748
749        if (count == 0 && result[0] != 0) {
750          this.distinctElements--;
751        } else if (count > 0 && result[0] == 0) {
752          this.distinctElements++;
753        }
754
755        this.totalCount += count - result[0];
756        return rebalance();
757      }
758
759      // setting my count
760      result[0] = elemCount;
761      if (count == 0) {
762        return deleteMe();
763      }
764      this.totalCount += count - elemCount;
765      this.elemCount = count;
766      return this;
767    }
768
769    AvlNode<E> setCount(
770        Comparator<? super E> comparator,
771        @NullableDecl E e,
772        int expectedCount,
773        int newCount,
774        int[] result) {
775      int cmp = comparator.compare(e, elem);
776      if (cmp < 0) {
777        AvlNode<E> initLeft = left;
778        if (initLeft == null) {
779          result[0] = 0;
780          if (expectedCount == 0 && newCount > 0) {
781            return addLeftChild(e, newCount);
782          }
783          return this;
784        }
785
786        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
787
788        if (result[0] == expectedCount) {
789          if (newCount == 0 && result[0] != 0) {
790            this.distinctElements--;
791          } else if (newCount > 0 && result[0] == 0) {
792            this.distinctElements++;
793          }
794          this.totalCount += newCount - result[0];
795        }
796        return rebalance();
797      } else if (cmp > 0) {
798        AvlNode<E> initRight = right;
799        if (initRight == null) {
800          result[0] = 0;
801          if (expectedCount == 0 && newCount > 0) {
802            return addRightChild(e, newCount);
803          }
804          return this;
805        }
806
807        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
808
809        if (result[0] == expectedCount) {
810          if (newCount == 0 && result[0] != 0) {
811            this.distinctElements--;
812          } else if (newCount > 0 && result[0] == 0) {
813            this.distinctElements++;
814          }
815          this.totalCount += newCount - result[0];
816        }
817        return rebalance();
818      }
819
820      // setting my count
821      result[0] = elemCount;
822      if (expectedCount == elemCount) {
823        if (newCount == 0) {
824          return deleteMe();
825        }
826        this.totalCount += newCount - elemCount;
827        this.elemCount = newCount;
828      }
829      return this;
830    }
831
832    private AvlNode<E> deleteMe() {
833      int oldElemCount = this.elemCount;
834      this.elemCount = 0;
835      successor(pred, succ);
836      if (left == null) {
837        return right;
838      } else if (right == null) {
839        return left;
840      } else if (left.height >= right.height) {
841        AvlNode<E> newTop = pred;
842        // newTop is the maximum node in my left subtree
843        newTop.left = left.removeMax(newTop);
844        newTop.right = right;
845        newTop.distinctElements = distinctElements - 1;
846        newTop.totalCount = totalCount - oldElemCount;
847        return newTop.rebalance();
848      } else {
849        AvlNode<E> newTop = succ;
850        newTop.right = right.removeMin(newTop);
851        newTop.left = left;
852        newTop.distinctElements = distinctElements - 1;
853        newTop.totalCount = totalCount - oldElemCount;
854        return newTop.rebalance();
855      }
856    }
857
858    // Removes the minimum node from this subtree to be reused elsewhere
859    private AvlNode<E> removeMin(AvlNode<E> node) {
860      if (left == null) {
861        return right;
862      } else {
863        left = left.removeMin(node);
864        distinctElements--;
865        totalCount -= node.elemCount;
866        return rebalance();
867      }
868    }
869
870    // Removes the maximum node from this subtree to be reused elsewhere
871    private AvlNode<E> removeMax(AvlNode<E> node) {
872      if (right == null) {
873        return left;
874      } else {
875        right = right.removeMax(node);
876        distinctElements--;
877        totalCount -= node.elemCount;
878        return rebalance();
879      }
880    }
881
882    private void recomputeMultiset() {
883      this.distinctElements =
884          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
885      this.totalCount = elemCount + totalCount(left) + totalCount(right);
886    }
887
888    private void recomputeHeight() {
889      this.height = 1 + Math.max(height(left), height(right));
890    }
891
892    private void recompute() {
893      recomputeMultiset();
894      recomputeHeight();
895    }
896
897    private AvlNode<E> rebalance() {
898      switch (balanceFactor()) {
899        case -2:
900          if (right.balanceFactor() > 0) {
901            right = right.rotateRight();
902          }
903          return rotateLeft();
904        case 2:
905          if (left.balanceFactor() < 0) {
906            left = left.rotateLeft();
907          }
908          return rotateRight();
909        default:
910          recomputeHeight();
911          return this;
912      }
913    }
914
915    private int balanceFactor() {
916      return height(left) - height(right);
917    }
918
919    private AvlNode<E> rotateLeft() {
920      checkState(right != null);
921      AvlNode<E> newTop = right;
922      this.right = newTop.left;
923      newTop.left = this;
924      newTop.totalCount = this.totalCount;
925      newTop.distinctElements = this.distinctElements;
926      this.recompute();
927      newTop.recomputeHeight();
928      return newTop;
929    }
930
931    private AvlNode<E> rotateRight() {
932      checkState(left != null);
933      AvlNode<E> newTop = left;
934      this.left = newTop.right;
935      newTop.right = this;
936      newTop.totalCount = this.totalCount;
937      newTop.distinctElements = this.distinctElements;
938      this.recompute();
939      newTop.recomputeHeight();
940      return newTop;
941    }
942
943    private static long totalCount(@NullableDecl AvlNode<?> node) {
944      return (node == null) ? 0 : node.totalCount;
945    }
946
947    private static int height(@NullableDecl AvlNode<?> node) {
948      return (node == null) ? 0 : node.height;
949    }
950
951    @NullableDecl
952    private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
953      int cmp = comparator.compare(e, elem);
954      if (cmp < 0) {
955        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
956      } else if (cmp == 0) {
957        return this;
958      } else {
959        return (right == null) ? null : right.ceiling(comparator, e);
960      }
961    }
962
963    @NullableDecl
964    private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
965      int cmp = comparator.compare(e, elem);
966      if (cmp > 0) {
967        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
968      } else if (cmp == 0) {
969        return this;
970      } else {
971        return (left == null) ? null : left.floor(comparator, e);
972      }
973    }
974
975    E getElement() {
976      return elem;
977    }
978
979    int getCount() {
980      return elemCount;
981    }
982
983    @Override
984    public String toString() {
985      return Multisets.immutableEntry(getElement(), getCount()).toString();
986    }
987  }
988
989  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
990    a.succ = b;
991    b.pred = a;
992  }
993
994  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
995    successor(a, b);
996    successor(b, c);
997  }
998
999  /*
1000   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
1001   * calls the comparator to compare the two keys. If that change is made,
1002   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
1003   */
1004
1005  /**
1006   * @serialData the comparator, the number of distinct elements, the first element, its count, the
1007   *     second element, its count, and so on
1008   */
1009  @GwtIncompatible // java.io.ObjectOutputStream
1010  private void writeObject(ObjectOutputStream stream) throws IOException {
1011    stream.defaultWriteObject();
1012    stream.writeObject(elementSet().comparator());
1013    Serialization.writeMultiset(this, stream);
1014  }
1015
1016  @GwtIncompatible // java.io.ObjectInputStream
1017  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
1018    stream.defaultReadObject();
1019    @SuppressWarnings("unchecked")
1020    // reading data stored by writeObject
1021    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
1022    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
1023    Serialization.getFieldSetter(TreeMultiset.class, "range")
1024        .set(this, GeneralRange.all(comparator));
1025    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
1026        .set(this, new Reference<AvlNode<E>>());
1027    AvlNode<E> header = new AvlNode<E>(null, 1);
1028    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
1029    successor(header, header);
1030    Serialization.populateMultiset(this, stream);
1031  }
1032
1033  @GwtIncompatible // not needed in emulated source
1034  private static final long serialVersionUID = 1;
1035}