/*
 * Copyright (C) 2007 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.CollectPreconditions.checkRemove;
023
024import com.google.common.annotations.GwtCompatible;
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.base.MoreObjects;
027import com.google.common.primitives.Ints;
028import com.google.errorprone.annotations.CanIgnoreReturnValue;
029import java.io.IOException;
030import java.io.ObjectInputStream;
031import java.io.ObjectOutputStream;
032import java.io.Serializable;
033import java.util.Comparator;
034import java.util.ConcurrentModificationException;
035import java.util.Iterator;
036import java.util.NoSuchElementException;
037import org.checkerframework.checker.nullness.compatqual.NullableDecl;
038
039/**
040 * A multiset which maintains the ordering of its elements, according to either their natural order
041 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
042 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
043 * equivalence of instances.
044 *
045 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
046 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
047 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
048 *
049 * <p>See the Guava User Guide article on <a href=
050 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset"> {@code
051 * Multiset}</a>.
052 *
053 * @author Louis Wasserman
054 * @author Jared Levy
055 * @since 2.0
056 */
057@GwtCompatible(emulated = true)
058public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
059
060  /**
061   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
062   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
063   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
064   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
065   * user attempts to add an element to the multiset that violates this constraint (for example, the
066   * user attempts to add a string element to a set whose elements are integers), the {@code
067   * add(Object)} call will throw a {@code ClassCastException}.
068   *
069   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
070   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
071   */
072  public static <E extends Comparable> TreeMultiset<E> create() {
073    return new TreeMultiset<E>(Ordering.natural());
074  }
075
076  /**
077   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
078   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
079   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
080   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
081   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
082   * ClassCastException}.
083   *
084   * @param comparator the comparator that will be used to sort this multiset. A null value
085   *     indicates that the elements' <i>natural ordering</i> should be used.
086   */
087  @SuppressWarnings("unchecked")
088  public static <E> TreeMultiset<E> create(@NullableDecl Comparator<? super E> comparator) {
089    return (comparator == null)
090        ? new TreeMultiset<E>((Comparator) Ordering.natural())
091        : new TreeMultiset<E>(comparator);
092  }
093
094  /**
095   * Creates an empty multiset containing the given initial elements, sorted according to the
096   * elements' natural order.
097   *
098   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
099   *
100   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
101   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
102   */
103  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
104    TreeMultiset<E> multiset = create();
105    Iterables.addAll(multiset, elements);
106    return multiset;
107  }
108
109  private final transient Reference<AvlNode<E>> rootReference;
110  private final transient GeneralRange<E> range;
111  private final transient AvlNode<E> header;
112
113  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
114    super(range.comparator());
115    this.rootReference = rootReference;
116    this.range = range;
117    this.header = endLink;
118  }
119
120  TreeMultiset(Comparator<? super E> comparator) {
121    super(comparator);
122    this.range = GeneralRange.all(comparator);
123    this.header = new AvlNode<E>(null, 1);
124    successor(header, header);
125    this.rootReference = new Reference<>();
126  }
127
128  /** A function which can be summed across a subtree. */
129  private enum Aggregate {
130    SIZE {
131      @Override
132      int nodeAggregate(AvlNode<?> node) {
133        return node.elemCount;
134      }
135
136      @Override
137      long treeAggregate(@NullableDecl AvlNode<?> root) {
138        return (root == null) ? 0 : root.totalCount;
139      }
140    },
141    DISTINCT {
142      @Override
143      int nodeAggregate(AvlNode<?> node) {
144        return 1;
145      }
146
147      @Override
148      long treeAggregate(@NullableDecl AvlNode<?> root) {
149        return (root == null) ? 0 : root.distinctElements;
150      }
151    };
152
153    abstract int nodeAggregate(AvlNode<?> node);
154
155    abstract long treeAggregate(@NullableDecl AvlNode<?> root);
156  }
157
158  private long aggregateForEntries(Aggregate aggr) {
159    AvlNode<E> root = rootReference.get();
160    long total = aggr.treeAggregate(root);
161    if (range.hasLowerBound()) {
162      total -= aggregateBelowRange(aggr, root);
163    }
164    if (range.hasUpperBound()) {
165      total -= aggregateAboveRange(aggr, root);
166    }
167    return total;
168  }
169
170  private long aggregateBelowRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
171    if (node == null) {
172      return 0;
173    }
174    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
175    if (cmp < 0) {
176      return aggregateBelowRange(aggr, node.left);
177    } else if (cmp == 0) {
178      switch (range.getLowerBoundType()) {
179        case OPEN:
180          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
181        case CLOSED:
182          return aggr.treeAggregate(node.left);
183        default:
184          throw new AssertionError();
185      }
186    } else {
187      return aggr.treeAggregate(node.left)
188          + aggr.nodeAggregate(node)
189          + aggregateBelowRange(aggr, node.right);
190    }
191  }
192
193  private long aggregateAboveRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
194    if (node == null) {
195      return 0;
196    }
197    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
198    if (cmp > 0) {
199      return aggregateAboveRange(aggr, node.right);
200    } else if (cmp == 0) {
201      switch (range.getUpperBoundType()) {
202        case OPEN:
203          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
204        case CLOSED:
205          return aggr.treeAggregate(node.right);
206        default:
207          throw new AssertionError();
208      }
209    } else {
210      return aggr.treeAggregate(node.right)
211          + aggr.nodeAggregate(node)
212          + aggregateAboveRange(aggr, node.left);
213    }
214  }
215
216  @Override
217  public int size() {
218    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
219  }
220
221  @Override
222  int distinctElements() {
223    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
224  }
225
226  @Override
227  public int count(@NullableDecl Object element) {
228    try {
229      @SuppressWarnings("unchecked")
230      E e = (E) element;
231      AvlNode<E> root = rootReference.get();
232      if (!range.contains(e) || root == null) {
233        return 0;
234      }
235      return root.count(comparator(), e);
236    } catch (ClassCastException | NullPointerException e) {
237      return 0;
238    }
239  }
240
241  @CanIgnoreReturnValue
242  @Override
243  public int add(@NullableDecl E element, int occurrences) {
244    checkNonnegative(occurrences, "occurrences");
245    if (occurrences == 0) {
246      return count(element);
247    }
248    checkArgument(range.contains(element));
249    AvlNode<E> root = rootReference.get();
250    if (root == null) {
251      comparator().compare(element, element);
252      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
253      successor(header, newRoot, header);
254      rootReference.checkAndSet(root, newRoot);
255      return 0;
256    }
257    int[] result = new int[1]; // used as a mutable int reference to hold result
258    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
259    rootReference.checkAndSet(root, newRoot);
260    return result[0];
261  }
262
263  @CanIgnoreReturnValue
264  @Override
265  public int remove(@NullableDecl Object element, int occurrences) {
266    checkNonnegative(occurrences, "occurrences");
267    if (occurrences == 0) {
268      return count(element);
269    }
270    AvlNode<E> root = rootReference.get();
271    int[] result = new int[1]; // used as a mutable int reference to hold result
272    AvlNode<E> newRoot;
273    try {
274      @SuppressWarnings("unchecked")
275      E e = (E) element;
276      if (!range.contains(e) || root == null) {
277        return 0;
278      }
279      newRoot = root.remove(comparator(), e, occurrences, result);
280    } catch (ClassCastException | NullPointerException e) {
281      return 0;
282    }
283    rootReference.checkAndSet(root, newRoot);
284    return result[0];
285  }
286
287  @CanIgnoreReturnValue
288  @Override
289  public int setCount(@NullableDecl E element, int count) {
290    checkNonnegative(count, "count");
291    if (!range.contains(element)) {
292      checkArgument(count == 0);
293      return 0;
294    }
295
296    AvlNode<E> root = rootReference.get();
297    if (root == null) {
298      if (count > 0) {
299        add(element, count);
300      }
301      return 0;
302    }
303    int[] result = new int[1]; // used as a mutable int reference to hold result
304    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
305    rootReference.checkAndSet(root, newRoot);
306    return result[0];
307  }
308
309  @CanIgnoreReturnValue
310  @Override
311  public boolean setCount(@NullableDecl E element, int oldCount, int newCount) {
312    checkNonnegative(newCount, "newCount");
313    checkNonnegative(oldCount, "oldCount");
314    checkArgument(range.contains(element));
315
316    AvlNode<E> root = rootReference.get();
317    if (root == null) {
318      if (oldCount == 0) {
319        if (newCount > 0) {
320          add(element, newCount);
321        }
322        return true;
323      } else {
324        return false;
325      }
326    }
327    int[] result = new int[1]; // used as a mutable int reference to hold result
328    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
329    rootReference.checkAndSet(root, newRoot);
330    return result[0] == oldCount;
331  }
332
333  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
334    return new Multisets.AbstractEntry<E>() {
335      @Override
336      public E getElement() {
337        return baseEntry.getElement();
338      }
339
340      @Override
341      public int getCount() {
342        int result = baseEntry.getCount();
343        if (result == 0) {
344          return count(getElement());
345        } else {
346          return result;
347        }
348      }
349    };
350  }
351
352  /** Returns the first node in the tree that is in range. */
353  @NullableDecl
354  private AvlNode<E> firstNode() {
355    AvlNode<E> root = rootReference.get();
356    if (root == null) {
357      return null;
358    }
359    AvlNode<E> node;
360    if (range.hasLowerBound()) {
361      E endpoint = range.getLowerEndpoint();
362      node = rootReference.get().ceiling(comparator(), endpoint);
363      if (node == null) {
364        return null;
365      }
366      if (range.getLowerBoundType() == BoundType.OPEN
367          && comparator().compare(endpoint, node.getElement()) == 0) {
368        node = node.succ;
369      }
370    } else {
371      node = header.succ;
372    }
373    return (node == header || !range.contains(node.getElement())) ? null : node;
374  }
375
376  @NullableDecl
377  private AvlNode<E> lastNode() {
378    AvlNode<E> root = rootReference.get();
379    if (root == null) {
380      return null;
381    }
382    AvlNode<E> node;
383    if (range.hasUpperBound()) {
384      E endpoint = range.getUpperEndpoint();
385      node = rootReference.get().floor(comparator(), endpoint);
386      if (node == null) {
387        return null;
388      }
389      if (range.getUpperBoundType() == BoundType.OPEN
390          && comparator().compare(endpoint, node.getElement()) == 0) {
391        node = node.pred;
392      }
393    } else {
394      node = header.pred;
395    }
396    return (node == header || !range.contains(node.getElement())) ? null : node;
397  }
398
399  @Override
400  Iterator<Entry<E>> entryIterator() {
401    return new Iterator<Entry<E>>() {
402      AvlNode<E> current = firstNode();
403      Entry<E> prevEntry;
404
405      @Override
406      public boolean hasNext() {
407        if (current == null) {
408          return false;
409        } else if (range.tooHigh(current.getElement())) {
410          current = null;
411          return false;
412        } else {
413          return true;
414        }
415      }
416
417      @Override
418      public Entry<E> next() {
419        if (!hasNext()) {
420          throw new NoSuchElementException();
421        }
422        Entry<E> result = wrapEntry(current);
423        prevEntry = result;
424        if (current.succ == header) {
425          current = null;
426        } else {
427          current = current.succ;
428        }
429        return result;
430      }
431
432      @Override
433      public void remove() {
434        checkRemove(prevEntry != null);
435        setCount(prevEntry.getElement(), 0);
436        prevEntry = null;
437      }
438    };
439  }
440
441  @Override
442  Iterator<Entry<E>> descendingEntryIterator() {
443    return new Iterator<Entry<E>>() {
444      AvlNode<E> current = lastNode();
445      Entry<E> prevEntry = null;
446
447      @Override
448      public boolean hasNext() {
449        if (current == null) {
450          return false;
451        } else if (range.tooLow(current.getElement())) {
452          current = null;
453          return false;
454        } else {
455          return true;
456        }
457      }
458
459      @Override
460      public Entry<E> next() {
461        if (!hasNext()) {
462          throw new NoSuchElementException();
463        }
464        Entry<E> result = wrapEntry(current);
465        prevEntry = result;
466        if (current.pred == header) {
467          current = null;
468        } else {
469          current = current.pred;
470        }
471        return result;
472      }
473
474      @Override
475      public void remove() {
476        checkRemove(prevEntry != null);
477        setCount(prevEntry.getElement(), 0);
478        prevEntry = null;
479      }
480    };
481  }
482
483  @Override
484  public SortedMultiset<E> headMultiset(@NullableDecl E upperBound, BoundType boundType) {
485    return new TreeMultiset<E>(
486        rootReference,
487        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
488        header);
489  }
490
491  @Override
492  public SortedMultiset<E> tailMultiset(@NullableDecl E lowerBound, BoundType boundType) {
493    return new TreeMultiset<E>(
494        rootReference,
495        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
496        header);
497  }
498
499  static int distinctElements(@NullableDecl AvlNode<?> node) {
500    return (node == null) ? 0 : node.distinctElements;
501  }
502
503  private static final class Reference<T> {
504    @NullableDecl private T value;
505
506    @NullableDecl
507    public T get() {
508      return value;
509    }
510
511    public void checkAndSet(@NullableDecl T expected, T newValue) {
512      if (value != expected) {
513        throw new ConcurrentModificationException();
514      }
515      value = newValue;
516    }
517  }
518
519  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
520    @NullableDecl private final E elem;
521
522    // elemCount is 0 iff this node has been deleted.
523    private int elemCount;
524
525    private int distinctElements;
526    private long totalCount;
527    private int height;
528    private AvlNode<E> left;
529    private AvlNode<E> right;
530    private AvlNode<E> pred;
531    private AvlNode<E> succ;
532
533    AvlNode(@NullableDecl E elem, int elemCount) {
534      checkArgument(elemCount > 0);
535      this.elem = elem;
536      this.elemCount = elemCount;
537      this.totalCount = elemCount;
538      this.distinctElements = 1;
539      this.height = 1;
540      this.left = null;
541      this.right = null;
542    }
543
544    public int count(Comparator<? super E> comparator, E e) {
545      int cmp = comparator.compare(e, elem);
546      if (cmp < 0) {
547        return (left == null) ? 0 : left.count(comparator, e);
548      } else if (cmp > 0) {
549        return (right == null) ? 0 : right.count(comparator, e);
550      } else {
551        return elemCount;
552      }
553    }
554
555    private AvlNode<E> addRightChild(E e, int count) {
556      right = new AvlNode<E>(e, count);
557      successor(this, right, succ);
558      height = Math.max(2, height);
559      distinctElements++;
560      totalCount += count;
561      return this;
562    }
563
564    private AvlNode<E> addLeftChild(E e, int count) {
565      left = new AvlNode<E>(e, count);
566      successor(pred, left, this);
567      height = Math.max(2, height);
568      distinctElements++;
569      totalCount += count;
570      return this;
571    }
572
573    AvlNode<E> add(Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
574      /*
575       * It speeds things up considerably to unconditionally add count to totalCount here,
576       * but that destroys failure atomicity in the case of count overflow. =(
577       */
578      int cmp = comparator.compare(e, elem);
579      if (cmp < 0) {
580        AvlNode<E> initLeft = left;
581        if (initLeft == null) {
582          result[0] = 0;
583          return addLeftChild(e, count);
584        }
585        int initHeight = initLeft.height;
586
587        left = initLeft.add(comparator, e, count, result);
588        if (result[0] == 0) {
589          distinctElements++;
590        }
591        this.totalCount += count;
592        return (left.height == initHeight) ? this : rebalance();
593      } else if (cmp > 0) {
594        AvlNode<E> initRight = right;
595        if (initRight == null) {
596          result[0] = 0;
597          return addRightChild(e, count);
598        }
599        int initHeight = initRight.height;
600
601        right = initRight.add(comparator, e, count, result);
602        if (result[0] == 0) {
603          distinctElements++;
604        }
605        this.totalCount += count;
606        return (right.height == initHeight) ? this : rebalance();
607      }
608
609      // adding count to me!  No rebalance possible.
610      result[0] = elemCount;
611      long resultCount = (long) elemCount + count;
612      checkArgument(resultCount <= Integer.MAX_VALUE);
613      this.elemCount += count;
614      this.totalCount += count;
615      return this;
616    }
617
618    AvlNode<E> remove(
619        Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
620      int cmp = comparator.compare(e, elem);
621      if (cmp < 0) {
622        AvlNode<E> initLeft = left;
623        if (initLeft == null) {
624          result[0] = 0;
625          return this;
626        }
627
628        left = initLeft.remove(comparator, e, count, result);
629
630        if (result[0] > 0) {
631          if (count >= result[0]) {
632            this.distinctElements--;
633            this.totalCount -= result[0];
634          } else {
635            this.totalCount -= count;
636          }
637        }
638        return (result[0] == 0) ? this : rebalance();
639      } else if (cmp > 0) {
640        AvlNode<E> initRight = right;
641        if (initRight == null) {
642          result[0] = 0;
643          return this;
644        }
645
646        right = initRight.remove(comparator, e, count, result);
647
648        if (result[0] > 0) {
649          if (count >= result[0]) {
650            this.distinctElements--;
651            this.totalCount -= result[0];
652          } else {
653            this.totalCount -= count;
654          }
655        }
656        return rebalance();
657      }
658
659      // removing count from me!
660      result[0] = elemCount;
661      if (count >= elemCount) {
662        return deleteMe();
663      } else {
664        this.elemCount -= count;
665        this.totalCount -= count;
666        return this;
667      }
668    }
669
670    AvlNode<E> setCount(
671        Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
672      int cmp = comparator.compare(e, elem);
673      if (cmp < 0) {
674        AvlNode<E> initLeft = left;
675        if (initLeft == null) {
676          result[0] = 0;
677          return (count > 0) ? addLeftChild(e, count) : this;
678        }
679
680        left = initLeft.setCount(comparator, e, count, result);
681
682        if (count == 0 && result[0] != 0) {
683          this.distinctElements--;
684        } else if (count > 0 && result[0] == 0) {
685          this.distinctElements++;
686        }
687
688        this.totalCount += count - result[0];
689        return rebalance();
690      } else if (cmp > 0) {
691        AvlNode<E> initRight = right;
692        if (initRight == null) {
693          result[0] = 0;
694          return (count > 0) ? addRightChild(e, count) : this;
695        }
696
697        right = initRight.setCount(comparator, e, count, result);
698
699        if (count == 0 && result[0] != 0) {
700          this.distinctElements--;
701        } else if (count > 0 && result[0] == 0) {
702          this.distinctElements++;
703        }
704
705        this.totalCount += count - result[0];
706        return rebalance();
707      }
708
709      // setting my count
710      result[0] = elemCount;
711      if (count == 0) {
712        return deleteMe();
713      }
714      this.totalCount += count - elemCount;
715      this.elemCount = count;
716      return this;
717    }
718
719    AvlNode<E> setCount(
720        Comparator<? super E> comparator,
721        @NullableDecl E e,
722        int expectedCount,
723        int newCount,
724        int[] result) {
725      int cmp = comparator.compare(e, elem);
726      if (cmp < 0) {
727        AvlNode<E> initLeft = left;
728        if (initLeft == null) {
729          result[0] = 0;
730          if (expectedCount == 0 && newCount > 0) {
731            return addLeftChild(e, newCount);
732          }
733          return this;
734        }
735
736        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
737
738        if (result[0] == expectedCount) {
739          if (newCount == 0 && result[0] != 0) {
740            this.distinctElements--;
741          } else if (newCount > 0 && result[0] == 0) {
742            this.distinctElements++;
743          }
744          this.totalCount += newCount - result[0];
745        }
746        return rebalance();
747      } else if (cmp > 0) {
748        AvlNode<E> initRight = right;
749        if (initRight == null) {
750          result[0] = 0;
751          if (expectedCount == 0 && newCount > 0) {
752            return addRightChild(e, newCount);
753          }
754          return this;
755        }
756
757        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
758
759        if (result[0] == expectedCount) {
760          if (newCount == 0 && result[0] != 0) {
761            this.distinctElements--;
762          } else if (newCount > 0 && result[0] == 0) {
763            this.distinctElements++;
764          }
765          this.totalCount += newCount - result[0];
766        }
767        return rebalance();
768      }
769
770      // setting my count
771      result[0] = elemCount;
772      if (expectedCount == elemCount) {
773        if (newCount == 0) {
774          return deleteMe();
775        }
776        this.totalCount += newCount - elemCount;
777        this.elemCount = newCount;
778      }
779      return this;
780    }
781
782    private AvlNode<E> deleteMe() {
783      int oldElemCount = this.elemCount;
784      this.elemCount = 0;
785      successor(pred, succ);
786      if (left == null) {
787        return right;
788      } else if (right == null) {
789        return left;
790      } else if (left.height >= right.height) {
791        AvlNode<E> newTop = pred;
792        // newTop is the maximum node in my left subtree
793        newTop.left = left.removeMax(newTop);
794        newTop.right = right;
795        newTop.distinctElements = distinctElements - 1;
796        newTop.totalCount = totalCount - oldElemCount;
797        return newTop.rebalance();
798      } else {
799        AvlNode<E> newTop = succ;
800        newTop.right = right.removeMin(newTop);
801        newTop.left = left;
802        newTop.distinctElements = distinctElements - 1;
803        newTop.totalCount = totalCount - oldElemCount;
804        return newTop.rebalance();
805      }
806    }
807
808    // Removes the minimum node from this subtree to be reused elsewhere
809    private AvlNode<E> removeMin(AvlNode<E> node) {
810      if (left == null) {
811        return right;
812      } else {
813        left = left.removeMin(node);
814        distinctElements--;
815        totalCount -= node.elemCount;
816        return rebalance();
817      }
818    }
819
820    // Removes the maximum node from this subtree to be reused elsewhere
821    private AvlNode<E> removeMax(AvlNode<E> node) {
822      if (right == null) {
823        return left;
824      } else {
825        right = right.removeMax(node);
826        distinctElements--;
827        totalCount -= node.elemCount;
828        return rebalance();
829      }
830    }
831
832    private void recomputeMultiset() {
833      this.distinctElements =
834          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
835      this.totalCount = elemCount + totalCount(left) + totalCount(right);
836    }
837
838    private void recomputeHeight() {
839      this.height = 1 + Math.max(height(left), height(right));
840    }
841
842    private void recompute() {
843      recomputeMultiset();
844      recomputeHeight();
845    }
846
847    private AvlNode<E> rebalance() {
848      switch (balanceFactor()) {
849        case -2:
850          if (right.balanceFactor() > 0) {
851            right = right.rotateRight();
852          }
853          return rotateLeft();
854        case 2:
855          if (left.balanceFactor() < 0) {
856            left = left.rotateLeft();
857          }
858          return rotateRight();
859        default:
860          recomputeHeight();
861          return this;
862      }
863    }
864
865    private int balanceFactor() {
866      return height(left) - height(right);
867    }
868
869    private AvlNode<E> rotateLeft() {
870      checkState(right != null);
871      AvlNode<E> newTop = right;
872      this.right = newTop.left;
873      newTop.left = this;
874      newTop.totalCount = this.totalCount;
875      newTop.distinctElements = this.distinctElements;
876      this.recompute();
877      newTop.recomputeHeight();
878      return newTop;
879    }
880
881    private AvlNode<E> rotateRight() {
882      checkState(left != null);
883      AvlNode<E> newTop = left;
884      this.left = newTop.right;
885      newTop.right = this;
886      newTop.totalCount = this.totalCount;
887      newTop.distinctElements = this.distinctElements;
888      this.recompute();
889      newTop.recomputeHeight();
890      return newTop;
891    }
892
893    private static long totalCount(@NullableDecl AvlNode<?> node) {
894      return (node == null) ? 0 : node.totalCount;
895    }
896
897    private static int height(@NullableDecl AvlNode<?> node) {
898      return (node == null) ? 0 : node.height;
899    }
900
901    @NullableDecl
902    private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
903      int cmp = comparator.compare(e, elem);
904      if (cmp < 0) {
905        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
906      } else if (cmp == 0) {
907        return this;
908      } else {
909        return (right == null) ? null : right.ceiling(comparator, e);
910      }
911    }
912
913    @NullableDecl
914    private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
915      int cmp = comparator.compare(e, elem);
916      if (cmp > 0) {
917        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
918      } else if (cmp == 0) {
919        return this;
920      } else {
921        return (left == null) ? null : left.floor(comparator, e);
922      }
923    }
924
925    @Override
926    public E getElement() {
927      return elem;
928    }
929
930    @Override
931    public int getCount() {
932      return elemCount;
933    }
934
935    @Override
936    public String toString() {
937      return Multisets.immutableEntry(getElement(), getCount()).toString();
938    }
939  }
940
941  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
942    a.succ = b;
943    b.pred = a;
944  }
945
946  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
947    successor(a, b);
948    successor(b, c);
949  }
950
951  /*
952   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
953   * calls the comparator to compare the two keys. If that change is made,
954   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
955   */
956
957  /**
958   * @serialData the comparator, the number of distinct elements, the first element, its count, the
959   *     second element, its count, and so on
960   */
961  @GwtIncompatible // java.io.ObjectOutputStream
962  private void writeObject(ObjectOutputStream stream) throws IOException {
963    stream.defaultWriteObject();
964    stream.writeObject(elementSet().comparator());
965    Serialization.writeMultiset(this, stream);
966  }
967
968  @GwtIncompatible // java.io.ObjectInputStream
969  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
970    stream.defaultReadObject();
971    @SuppressWarnings("unchecked")
972    // reading data stored by writeObject
973    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
974    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
975    Serialization.getFieldSetter(TreeMultiset.class, "range")
976        .set(this, GeneralRange.all(comparator));
977    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
978        .set(this, new Reference<AvlNode<E>>());
979    AvlNode<E> header = new AvlNode<E>(null, 1);
980    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
981    successor(header, header);
982    Serialization.populateMultiset(this, stream);
983  }
984
985  @GwtIncompatible // not needed in emulated source
986  private static final long serialVersionUID = 1;
987}