001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.CollectPreconditions.checkRemove;
023
024import com.google.common.annotations.GwtCompatible;
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.base.MoreObjects;
027import com.google.common.primitives.Ints;
028import com.google.errorprone.annotations.CanIgnoreReturnValue;
029import java.io.IOException;
030import java.io.ObjectInputStream;
031import java.io.ObjectOutputStream;
032import java.io.Serializable;
033import java.util.Comparator;
034import java.util.ConcurrentModificationException;
035import java.util.Iterator;
036import java.util.NoSuchElementException;
037import javax.annotation.Nullable;
038
039/**
040 * A multiset which maintains the ordering of its elements, according to either their natural order
041 * or an explicit {@link Comparator}. In all cases, this implementation uses
042 * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
043 * determine equivalence of instances.
044 *
045 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
046 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
047 * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
048 *
049 * <p>See the Guava User Guide article on <a href=
050 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">
051 * {@code Multiset}</a>.
052 *
053 * @author Louis Wasserman
054 * @author Jared Levy
055 * @since 2.0
056 */
057@GwtCompatible(emulated = true)
058public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
059
060  /**
061   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
062   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
063   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
064   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
065   * user attempts to add an element to the multiset that violates this constraint (for example,
066   * the user attempts to add a string element to a set whose elements are integers), the
067   * {@code add(Object)} call will throw a {@code ClassCastException}.
068   *
069   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
070   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
071   */
072  public static <E extends Comparable> TreeMultiset<E> create() {
073    return new TreeMultiset<E>(Ordering.natural());
074  }
075
076  /**
077   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
078   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
079   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
080   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
081   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
082   * ClassCastException}.
083   *
084   * @param comparator the comparator that will be used to sort this multiset. A null value
085   *     indicates that the elements' <i>natural ordering</i> should be used.
086   */
087  @SuppressWarnings("unchecked")
088  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
089    return (comparator == null)
090        ? new TreeMultiset<E>((Comparator) Ordering.natural())
091        : new TreeMultiset<E>(comparator);
092  }
093
094  /**
095   * Creates an empty multiset containing the given initial elements, sorted according to the
096   * elements' natural order.
097   *
098   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
099   *
100   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
101   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
102   */
103  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
104    TreeMultiset<E> multiset = create();
105    Iterables.addAll(multiset, elements);
106    return multiset;
107  }
108
109  private final transient Reference<AvlNode<E>> rootReference;
110  private final transient GeneralRange<E> range;
111  private final transient AvlNode<E> header;
112
113  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
114    super(range.comparator());
115    this.rootReference = rootReference;
116    this.range = range;
117    this.header = endLink;
118  }
119
120  TreeMultiset(Comparator<? super E> comparator) {
121    super(comparator);
122    this.range = GeneralRange.all(comparator);
123    this.header = new AvlNode<E>(null, 1);
124    successor(header, header);
125    this.rootReference = new Reference<AvlNode<E>>();
126  }
127
128  /**
129   * A function which can be summed across a subtree.
130   */
131  private enum Aggregate {
132    SIZE {
133      @Override
134      int nodeAggregate(AvlNode<?> node) {
135        return node.elemCount;
136      }
137
138      @Override
139      long treeAggregate(@Nullable AvlNode<?> root) {
140        return (root == null) ? 0 : root.totalCount;
141      }
142    },
143    DISTINCT {
144      @Override
145      int nodeAggregate(AvlNode<?> node) {
146        return 1;
147      }
148
149      @Override
150      long treeAggregate(@Nullable AvlNode<?> root) {
151        return (root == null) ? 0 : root.distinctElements;
152      }
153    };
154
155    abstract int nodeAggregate(AvlNode<?> node);
156
157    abstract long treeAggregate(@Nullable AvlNode<?> root);
158  }
159
160  private long aggregateForEntries(Aggregate aggr) {
161    AvlNode<E> root = rootReference.get();
162    long total = aggr.treeAggregate(root);
163    if (range.hasLowerBound()) {
164      total -= aggregateBelowRange(aggr, root);
165    }
166    if (range.hasUpperBound()) {
167      total -= aggregateAboveRange(aggr, root);
168    }
169    return total;
170  }
171
172  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
173    if (node == null) {
174      return 0;
175    }
176    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
177    if (cmp < 0) {
178      return aggregateBelowRange(aggr, node.left);
179    } else if (cmp == 0) {
180      switch (range.getLowerBoundType()) {
181        case OPEN:
182          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
183        case CLOSED:
184          return aggr.treeAggregate(node.left);
185        default:
186          throw new AssertionError();
187      }
188    } else {
189      return aggr.treeAggregate(node.left)
190          + aggr.nodeAggregate(node)
191          + aggregateBelowRange(aggr, node.right);
192    }
193  }
194
195  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
196    if (node == null) {
197      return 0;
198    }
199    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
200    if (cmp > 0) {
201      return aggregateAboveRange(aggr, node.right);
202    } else if (cmp == 0) {
203      switch (range.getUpperBoundType()) {
204        case OPEN:
205          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
206        case CLOSED:
207          return aggr.treeAggregate(node.right);
208        default:
209          throw new AssertionError();
210      }
211    } else {
212      return aggr.treeAggregate(node.right)
213          + aggr.nodeAggregate(node)
214          + aggregateAboveRange(aggr, node.left);
215    }
216  }
217
218  @Override
219  public int size() {
220    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
221  }
222
223  @Override
224  int distinctElements() {
225    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
226  }
227
228  @Override
229  public int count(@Nullable Object element) {
230    try {
231      @SuppressWarnings("unchecked")
232      E e = (E) element;
233      AvlNode<E> root = rootReference.get();
234      if (!range.contains(e) || root == null) {
235        return 0;
236      }
237      return root.count(comparator(), e);
238    } catch (ClassCastException e) {
239      return 0;
240    } catch (NullPointerException e) {
241      return 0;
242    }
243  }
244
245  @CanIgnoreReturnValue
246  @Override
247  public int add(@Nullable E element, int occurrences) {
248    checkNonnegative(occurrences, "occurrences");
249    if (occurrences == 0) {
250      return count(element);
251    }
252    checkArgument(range.contains(element));
253    AvlNode<E> root = rootReference.get();
254    if (root == null) {
255      comparator().compare(element, element);
256      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
257      successor(header, newRoot, header);
258      rootReference.checkAndSet(root, newRoot);
259      return 0;
260    }
261    int[] result = new int[1]; // used as a mutable int reference to hold result
262    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
263    rootReference.checkAndSet(root, newRoot);
264    return result[0];
265  }
266
267  @CanIgnoreReturnValue
268  @Override
269  public int remove(@Nullable Object element, int occurrences) {
270    checkNonnegative(occurrences, "occurrences");
271    if (occurrences == 0) {
272      return count(element);
273    }
274    AvlNode<E> root = rootReference.get();
275    int[] result = new int[1]; // used as a mutable int reference to hold result
276    AvlNode<E> newRoot;
277    try {
278      @SuppressWarnings("unchecked")
279      E e = (E) element;
280      if (!range.contains(e) || root == null) {
281        return 0;
282      }
283      newRoot = root.remove(comparator(), e, occurrences, result);
284    } catch (ClassCastException e) {
285      return 0;
286    } catch (NullPointerException e) {
287      return 0;
288    }
289    rootReference.checkAndSet(root, newRoot);
290    return result[0];
291  }
292
293  @CanIgnoreReturnValue
294  @Override
295  public int setCount(@Nullable E element, int count) {
296    checkNonnegative(count, "count");
297    if (!range.contains(element)) {
298      checkArgument(count == 0);
299      return 0;
300    }
301
302    AvlNode<E> root = rootReference.get();
303    if (root == null) {
304      if (count > 0) {
305        add(element, count);
306      }
307      return 0;
308    }
309    int[] result = new int[1]; // used as a mutable int reference to hold result
310    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
311    rootReference.checkAndSet(root, newRoot);
312    return result[0];
313  }
314
315  @CanIgnoreReturnValue
316  @Override
317  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
318    checkNonnegative(newCount, "newCount");
319    checkNonnegative(oldCount, "oldCount");
320    checkArgument(range.contains(element));
321
322    AvlNode<E> root = rootReference.get();
323    if (root == null) {
324      if (oldCount == 0) {
325        if (newCount > 0) {
326          add(element, newCount);
327        }
328        return true;
329      } else {
330        return false;
331      }
332    }
333    int[] result = new int[1]; // used as a mutable int reference to hold result
334    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
335    rootReference.checkAndSet(root, newRoot);
336    return result[0] == oldCount;
337  }
338
339  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
340    return new Multisets.AbstractEntry<E>() {
341      @Override
342      public E getElement() {
343        return baseEntry.getElement();
344      }
345
346      @Override
347      public int getCount() {
348        int result = baseEntry.getCount();
349        if (result == 0) {
350          return count(getElement());
351        } else {
352          return result;
353        }
354      }
355    };
356  }
357
358  /**
359   * Returns the first node in the tree that is in range.
360   */
361  @Nullable
362  private AvlNode<E> firstNode() {
363    AvlNode<E> root = rootReference.get();
364    if (root == null) {
365      return null;
366    }
367    AvlNode<E> node;
368    if (range.hasLowerBound()) {
369      E endpoint = range.getLowerEndpoint();
370      node = rootReference.get().ceiling(comparator(), endpoint);
371      if (node == null) {
372        return null;
373      }
374      if (range.getLowerBoundType() == BoundType.OPEN
375          && comparator().compare(endpoint, node.getElement()) == 0) {
376        node = node.succ;
377      }
378    } else {
379      node = header.succ;
380    }
381    return (node == header || !range.contains(node.getElement())) ? null : node;
382  }
383
384  @Nullable
385  private AvlNode<E> lastNode() {
386    AvlNode<E> root = rootReference.get();
387    if (root == null) {
388      return null;
389    }
390    AvlNode<E> node;
391    if (range.hasUpperBound()) {
392      E endpoint = range.getUpperEndpoint();
393      node = rootReference.get().floor(comparator(), endpoint);
394      if (node == null) {
395        return null;
396      }
397      if (range.getUpperBoundType() == BoundType.OPEN
398          && comparator().compare(endpoint, node.getElement()) == 0) {
399        node = node.pred;
400      }
401    } else {
402      node = header.pred;
403    }
404    return (node == header || !range.contains(node.getElement())) ? null : node;
405  }
406
407  @Override
408  Iterator<Entry<E>> entryIterator() {
409    return new Iterator<Entry<E>>() {
410      AvlNode<E> current = firstNode();
411      Entry<E> prevEntry;
412
413      @Override
414      public boolean hasNext() {
415        if (current == null) {
416          return false;
417        } else if (range.tooHigh(current.getElement())) {
418          current = null;
419          return false;
420        } else {
421          return true;
422        }
423      }
424
425      @Override
426      public Entry<E> next() {
427        if (!hasNext()) {
428          throw new NoSuchElementException();
429        }
430        Entry<E> result = wrapEntry(current);
431        prevEntry = result;
432        if (current.succ == header) {
433          current = null;
434        } else {
435          current = current.succ;
436        }
437        return result;
438      }
439
440      @Override
441      public void remove() {
442        checkRemove(prevEntry != null);
443        setCount(prevEntry.getElement(), 0);
444        prevEntry = null;
445      }
446    };
447  }
448
449  @Override
450  Iterator<Entry<E>> descendingEntryIterator() {
451    return new Iterator<Entry<E>>() {
452      AvlNode<E> current = lastNode();
453      Entry<E> prevEntry = null;
454
455      @Override
456      public boolean hasNext() {
457        if (current == null) {
458          return false;
459        } else if (range.tooLow(current.getElement())) {
460          current = null;
461          return false;
462        } else {
463          return true;
464        }
465      }
466
467      @Override
468      public Entry<E> next() {
469        if (!hasNext()) {
470          throw new NoSuchElementException();
471        }
472        Entry<E> result = wrapEntry(current);
473        prevEntry = result;
474        if (current.pred == header) {
475          current = null;
476        } else {
477          current = current.pred;
478        }
479        return result;
480      }
481
482      @Override
483      public void remove() {
484        checkRemove(prevEntry != null);
485        setCount(prevEntry.getElement(), 0);
486        prevEntry = null;
487      }
488    };
489  }
490
491  @Override
492  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
493    return new TreeMultiset<E>(
494        rootReference,
495        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
496        header);
497  }
498
499  @Override
500  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
501    return new TreeMultiset<E>(
502        rootReference,
503        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
504        header);
505  }
506
507  static int distinctElements(@Nullable AvlNode<?> node) {
508    return (node == null) ? 0 : node.distinctElements;
509  }
510
511  private static final class Reference<T> {
512    @Nullable private T value;
513
514    @Nullable
515    public T get() {
516      return value;
517    }
518
519    public void checkAndSet(@Nullable T expected, T newValue) {
520      if (value != expected) {
521        throw new ConcurrentModificationException();
522      }
523      value = newValue;
524    }
525  }
526
527  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
528    @Nullable private final E elem;
529
530    // elemCount is 0 iff this node has been deleted.
531    private int elemCount;
532
533    private int distinctElements;
534    private long totalCount;
535    private int height;
536    private AvlNode<E> left;
537    private AvlNode<E> right;
538    private AvlNode<E> pred;
539    private AvlNode<E> succ;
540
541    AvlNode(@Nullable E elem, int elemCount) {
542      checkArgument(elemCount > 0);
543      this.elem = elem;
544      this.elemCount = elemCount;
545      this.totalCount = elemCount;
546      this.distinctElements = 1;
547      this.height = 1;
548      this.left = null;
549      this.right = null;
550    }
551
552    public int count(Comparator<? super E> comparator, E e) {
553      int cmp = comparator.compare(e, elem);
554      if (cmp < 0) {
555        return (left == null) ? 0 : left.count(comparator, e);
556      } else if (cmp > 0) {
557        return (right == null) ? 0 : right.count(comparator, e);
558      } else {
559        return elemCount;
560      }
561    }
562
563    private AvlNode<E> addRightChild(E e, int count) {
564      right = new AvlNode<E>(e, count);
565      successor(this, right, succ);
566      height = Math.max(2, height);
567      distinctElements++;
568      totalCount += count;
569      return this;
570    }
571
572    private AvlNode<E> addLeftChild(E e, int count) {
573      left = new AvlNode<E>(e, count);
574      successor(pred, left, this);
575      height = Math.max(2, height);
576      distinctElements++;
577      totalCount += count;
578      return this;
579    }
580
581    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
582      /*
583       * It speeds things up considerably to unconditionally add count to totalCount here,
584       * but that destroys failure atomicity in the case of count overflow. =(
585       */
586      int cmp = comparator.compare(e, elem);
587      if (cmp < 0) {
588        AvlNode<E> initLeft = left;
589        if (initLeft == null) {
590          result[0] = 0;
591          return addLeftChild(e, count);
592        }
593        int initHeight = initLeft.height;
594
595        left = initLeft.add(comparator, e, count, result);
596        if (result[0] == 0) {
597          distinctElements++;
598        }
599        this.totalCount += count;
600        return (left.height == initHeight) ? this : rebalance();
601      } else if (cmp > 0) {
602        AvlNode<E> initRight = right;
603        if (initRight == null) {
604          result[0] = 0;
605          return addRightChild(e, count);
606        }
607        int initHeight = initRight.height;
608
609        right = initRight.add(comparator, e, count, result);
610        if (result[0] == 0) {
611          distinctElements++;
612        }
613        this.totalCount += count;
614        return (right.height == initHeight) ? this : rebalance();
615      }
616
617      // adding count to me!  No rebalance possible.
618      result[0] = elemCount;
619      long resultCount = (long) elemCount + count;
620      checkArgument(resultCount <= Integer.MAX_VALUE);
621      this.elemCount += count;
622      this.totalCount += count;
623      return this;
624    }
625
626    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
627      int cmp = comparator.compare(e, elem);
628      if (cmp < 0) {
629        AvlNode<E> initLeft = left;
630        if (initLeft == null) {
631          result[0] = 0;
632          return this;
633        }
634
635        left = initLeft.remove(comparator, e, count, result);
636
637        if (result[0] > 0) {
638          if (count >= result[0]) {
639            this.distinctElements--;
640            this.totalCount -= result[0];
641          } else {
642            this.totalCount -= count;
643          }
644        }
645        return (result[0] == 0) ? this : rebalance();
646      } else if (cmp > 0) {
647        AvlNode<E> initRight = right;
648        if (initRight == null) {
649          result[0] = 0;
650          return this;
651        }
652
653        right = initRight.remove(comparator, e, count, result);
654
655        if (result[0] > 0) {
656          if (count >= result[0]) {
657            this.distinctElements--;
658            this.totalCount -= result[0];
659          } else {
660            this.totalCount -= count;
661          }
662        }
663        return rebalance();
664      }
665
666      // removing count from me!
667      result[0] = elemCount;
668      if (count >= elemCount) {
669        return deleteMe();
670      } else {
671        this.elemCount -= count;
672        this.totalCount -= count;
673        return this;
674      }
675    }
676
677    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
678      int cmp = comparator.compare(e, elem);
679      if (cmp < 0) {
680        AvlNode<E> initLeft = left;
681        if (initLeft == null) {
682          result[0] = 0;
683          return (count > 0) ? addLeftChild(e, count) : this;
684        }
685
686        left = initLeft.setCount(comparator, e, count, result);
687
688        if (count == 0 && result[0] != 0) {
689          this.distinctElements--;
690        } else if (count > 0 && result[0] == 0) {
691          this.distinctElements++;
692        }
693
694        this.totalCount += count - result[0];
695        return rebalance();
696      } else if (cmp > 0) {
697        AvlNode<E> initRight = right;
698        if (initRight == null) {
699          result[0] = 0;
700          return (count > 0) ? addRightChild(e, count) : this;
701        }
702
703        right = initRight.setCount(comparator, e, count, result);
704
705        if (count == 0 && result[0] != 0) {
706          this.distinctElements--;
707        } else if (count > 0 && result[0] == 0) {
708          this.distinctElements++;
709        }
710
711        this.totalCount += count - result[0];
712        return rebalance();
713      }
714
715      // setting my count
716      result[0] = elemCount;
717      if (count == 0) {
718        return deleteMe();
719      }
720      this.totalCount += count - elemCount;
721      this.elemCount = count;
722      return this;
723    }
724
725    AvlNode<E> setCount(
726        Comparator<? super E> comparator,
727        @Nullable E e,
728        int expectedCount,
729        int newCount,
730        int[] result) {
731      int cmp = comparator.compare(e, elem);
732      if (cmp < 0) {
733        AvlNode<E> initLeft = left;
734        if (initLeft == null) {
735          result[0] = 0;
736          if (expectedCount == 0 && newCount > 0) {
737            return addLeftChild(e, newCount);
738          }
739          return this;
740        }
741
742        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
743
744        if (result[0] == expectedCount) {
745          if (newCount == 0 && result[0] != 0) {
746            this.distinctElements--;
747          } else if (newCount > 0 && result[0] == 0) {
748            this.distinctElements++;
749          }
750          this.totalCount += newCount - result[0];
751        }
752        return rebalance();
753      } else if (cmp > 0) {
754        AvlNode<E> initRight = right;
755        if (initRight == null) {
756          result[0] = 0;
757          if (expectedCount == 0 && newCount > 0) {
758            return addRightChild(e, newCount);
759          }
760          return this;
761        }
762
763        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
764
765        if (result[0] == expectedCount) {
766          if (newCount == 0 && result[0] != 0) {
767            this.distinctElements--;
768          } else if (newCount > 0 && result[0] == 0) {
769            this.distinctElements++;
770          }
771          this.totalCount += newCount - result[0];
772        }
773        return rebalance();
774      }
775
776      // setting my count
777      result[0] = elemCount;
778      if (expectedCount == elemCount) {
779        if (newCount == 0) {
780          return deleteMe();
781        }
782        this.totalCount += newCount - elemCount;
783        this.elemCount = newCount;
784      }
785      return this;
786    }
787
788    private AvlNode<E> deleteMe() {
789      int oldElemCount = this.elemCount;
790      this.elemCount = 0;
791      successor(pred, succ);
792      if (left == null) {
793        return right;
794      } else if (right == null) {
795        return left;
796      } else if (left.height >= right.height) {
797        AvlNode<E> newTop = pred;
798        // newTop is the maximum node in my left subtree
799        newTop.left = left.removeMax(newTop);
800        newTop.right = right;
801        newTop.distinctElements = distinctElements - 1;
802        newTop.totalCount = totalCount - oldElemCount;
803        return newTop.rebalance();
804      } else {
805        AvlNode<E> newTop = succ;
806        newTop.right = right.removeMin(newTop);
807        newTop.left = left;
808        newTop.distinctElements = distinctElements - 1;
809        newTop.totalCount = totalCount - oldElemCount;
810        return newTop.rebalance();
811      }
812    }
813
814    // Removes the minimum node from this subtree to be reused elsewhere
815    private AvlNode<E> removeMin(AvlNode<E> node) {
816      if (left == null) {
817        return right;
818      } else {
819        left = left.removeMin(node);
820        distinctElements--;
821        totalCount -= node.elemCount;
822        return rebalance();
823      }
824    }
825
826    // Removes the maximum node from this subtree to be reused elsewhere
827    private AvlNode<E> removeMax(AvlNode<E> node) {
828      if (right == null) {
829        return left;
830      } else {
831        right = right.removeMax(node);
832        distinctElements--;
833        totalCount -= node.elemCount;
834        return rebalance();
835      }
836    }
837
838    private void recomputeMultiset() {
839      this.distinctElements =
840          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
841      this.totalCount = elemCount + totalCount(left) + totalCount(right);
842    }
843
844    private void recomputeHeight() {
845      this.height = 1 + Math.max(height(left), height(right));
846    }
847
848    private void recompute() {
849      recomputeMultiset();
850      recomputeHeight();
851    }
852
853    private AvlNode<E> rebalance() {
854      switch (balanceFactor()) {
855        case -2:
856          if (right.balanceFactor() > 0) {
857            right = right.rotateRight();
858          }
859          return rotateLeft();
860        case 2:
861          if (left.balanceFactor() < 0) {
862            left = left.rotateLeft();
863          }
864          return rotateRight();
865        default:
866          recomputeHeight();
867          return this;
868      }
869    }
870
871    private int balanceFactor() {
872      return height(left) - height(right);
873    }
874
875    private AvlNode<E> rotateLeft() {
876      checkState(right != null);
877      AvlNode<E> newTop = right;
878      this.right = newTop.left;
879      newTop.left = this;
880      newTop.totalCount = this.totalCount;
881      newTop.distinctElements = this.distinctElements;
882      this.recompute();
883      newTop.recomputeHeight();
884      return newTop;
885    }
886
887    private AvlNode<E> rotateRight() {
888      checkState(left != null);
889      AvlNode<E> newTop = left;
890      this.left = newTop.right;
891      newTop.right = this;
892      newTop.totalCount = this.totalCount;
893      newTop.distinctElements = this.distinctElements;
894      this.recompute();
895      newTop.recomputeHeight();
896      return newTop;
897    }
898
899    private static long totalCount(@Nullable AvlNode<?> node) {
900      return (node == null) ? 0 : node.totalCount;
901    }
902
903    private static int height(@Nullable AvlNode<?> node) {
904      return (node == null) ? 0 : node.height;
905    }
906
907    @Nullable
908    private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
909      int cmp = comparator.compare(e, elem);
910      if (cmp < 0) {
911        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
912      } else if (cmp == 0) {
913        return this;
914      } else {
915        return (right == null) ? null : right.ceiling(comparator, e);
916      }
917    }
918
919    @Nullable
920    private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
921      int cmp = comparator.compare(e, elem);
922      if (cmp > 0) {
923        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
924      } else if (cmp == 0) {
925        return this;
926      } else {
927        return (left == null) ? null : left.floor(comparator, e);
928      }
929    }
930
931    @Override
932    public E getElement() {
933      return elem;
934    }
935
936    @Override
937    public int getCount() {
938      return elemCount;
939    }
940
941    @Override
942    public String toString() {
943      return Multisets.immutableEntry(getElement(), getCount()).toString();
944    }
945  }
946
947  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
948    a.succ = b;
949    b.pred = a;
950  }
951
952  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
953    successor(a, b);
954    successor(b, c);
955  }
956
957  /*
958   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
959   * calls the comparator to compare the two keys. If that change is made,
960   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
961   */
962
963  /**
964   * @serialData the comparator, the number of distinct elements, the first element, its count, the
965   *             second element, its count, and so on
966   */
967  @GwtIncompatible // java.io.ObjectOutputStream
968  private void writeObject(ObjectOutputStream stream) throws IOException {
969    stream.defaultWriteObject();
970    stream.writeObject(elementSet().comparator());
971    Serialization.writeMultiset(this, stream);
972  }
973
974  @GwtIncompatible // java.io.ObjectInputStream
975  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
976    stream.defaultReadObject();
977    @SuppressWarnings("unchecked")
978    // reading data stored by writeObject
979    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
980    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
981    Serialization.getFieldSetter(TreeMultiset.class, "range")
982        .set(this, GeneralRange.all(comparator));
983    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
984        .set(this, new Reference<AvlNode<E>>());
985    AvlNode<E> header = new AvlNode<E>(null, 1);
986    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
987    successor(header, header);
988    Serialization.populateMultiset(this, stream);
989  }
990
991  @GwtIncompatible // not needed in emulated source
992  private static final long serialVersionUID = 1;
993}