001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.CollectPreconditions.checkRemove;
023
024import com.google.common.annotations.GwtCompatible;
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.base.MoreObjects;
027import com.google.common.primitives.Ints;
028
029import java.io.IOException;
030import java.io.ObjectInputStream;
031import java.io.ObjectOutputStream;
032import java.io.Serializable;
033import java.util.Comparator;
034import java.util.ConcurrentModificationException;
035import java.util.Iterator;
036import java.util.NoSuchElementException;
037
038import javax.annotation.Nullable;
039
040/**
041 * A multiset which maintains the ordering of its elements, according to either their natural order
042 * or an explicit {@link Comparator}. In all cases, this implementation uses
043 * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
044 * determine equivalence of instances.
045 *
046 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
047 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
048 * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
049 *
050 * <p>See the Guava User Guide article on <a href=
051 * "http://code.google.com/p/guava-libraries/wiki/NewCollectionTypesExplained#Multiset">
052 * {@code Multiset}</a>.
053 *
054 * @author Louis Wasserman
055 * @author Jared Levy
056 * @since 2.0 (imported from Google Collections Library)
057 */
058@GwtCompatible(emulated = true)
059public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
060
061  /**
062   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
063   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
064   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
065   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
066   * user attempts to add an element to the multiset that violates this constraint (for example,
067   * the user attempts to add a string element to a set whose elements are integers), the
068   * {@code add(Object)} call will throw a {@code ClassCastException}.
069   *
070   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
071   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
072   */
073  public static <E extends Comparable> TreeMultiset<E> create() {
074    return new TreeMultiset<E>(Ordering.natural());
075  }
076
077  /**
078   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
079   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
080   * {@code comparator.compare(e1,
081   * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in
082   * the multiset. If the user attempts to add an element to the multiset that violates this
083   * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}.
084   *
085   * @param comparator
086   *          the comparator that will be used to sort this multiset. A null value indicates that
087   *          the elements' <i>natural ordering</i> should be used.
088   */
089  @SuppressWarnings("unchecked")
090  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
091    return (comparator == null)
092        ? new TreeMultiset<E>((Comparator) Ordering.natural())
093        : new TreeMultiset<E>(comparator);
094  }
095
096  /**
097   * Creates an empty multiset containing the given initial elements, sorted according to the
098   * elements' natural order.
099   *
100   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
101   *
102   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
103   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
104   */
105  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
106    TreeMultiset<E> multiset = create();
107    Iterables.addAll(multiset, elements);
108    return multiset;
109  }
110
111  private final transient Reference<AvlNode<E>> rootReference;
112  private final transient GeneralRange<E> range;
113  private final transient AvlNode<E> header;
114
115  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
116    super(range.comparator());
117    this.rootReference = rootReference;
118    this.range = range;
119    this.header = endLink;
120  }
121
122  TreeMultiset(Comparator<? super E> comparator) {
123    super(comparator);
124    this.range = GeneralRange.all(comparator);
125    this.header = new AvlNode<E>(null, 1);
126    successor(header, header);
127    this.rootReference = new Reference<AvlNode<E>>();
128  }
129
130  /**
131   * A function which can be summed across a subtree.
132   */
133  private enum Aggregate {
134    SIZE {
135      @Override
136      int nodeAggregate(AvlNode<?> node) {
137        return node.elemCount;
138      }
139
140      @Override
141      long treeAggregate(@Nullable AvlNode<?> root) {
142        return (root == null) ? 0 : root.totalCount;
143      }
144    },
145    DISTINCT {
146      @Override
147      int nodeAggregate(AvlNode<?> node) {
148        return 1;
149      }
150
151      @Override
152      long treeAggregate(@Nullable AvlNode<?> root) {
153        return (root == null) ? 0 : root.distinctElements;
154      }
155    };
156    abstract int nodeAggregate(AvlNode<?> node);
157
158    abstract long treeAggregate(@Nullable AvlNode<?> root);
159  }
160
161  private long aggregateForEntries(Aggregate aggr) {
162    AvlNode<E> root = rootReference.get();
163    long total = aggr.treeAggregate(root);
164    if (range.hasLowerBound()) {
165      total -= aggregateBelowRange(aggr, root);
166    }
167    if (range.hasUpperBound()) {
168      total -= aggregateAboveRange(aggr, root);
169    }
170    return total;
171  }
172
173  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
174    if (node == null) {
175      return 0;
176    }
177    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
178    if (cmp < 0) {
179      return aggregateBelowRange(aggr, node.left);
180    } else if (cmp == 0) {
181      switch (range.getLowerBoundType()) {
182        case OPEN:
183          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
184        case CLOSED:
185          return aggr.treeAggregate(node.left);
186        default:
187          throw new AssertionError();
188      }
189    } else {
190      return aggr.treeAggregate(node.left) + aggr.nodeAggregate(node)
191          + aggregateBelowRange(aggr, node.right);
192    }
193  }
194
195  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
196    if (node == null) {
197      return 0;
198    }
199    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
200    if (cmp > 0) {
201      return aggregateAboveRange(aggr, node.right);
202    } else if (cmp == 0) {
203      switch (range.getUpperBoundType()) {
204        case OPEN:
205          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
206        case CLOSED:
207          return aggr.treeAggregate(node.right);
208        default:
209          throw new AssertionError();
210      }
211    } else {
212      return aggr.treeAggregate(node.right) + aggr.nodeAggregate(node)
213          + aggregateAboveRange(aggr, node.left);
214    }
215  }
216
217  @Override
218  public int size() {
219    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
220  }
221
222  @Override
223  int distinctElements() {
224    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
225  }
226
227  @Override
228  public int count(@Nullable Object element) {
229    try {
230      @SuppressWarnings("unchecked")
231      E e = (E) element;
232      AvlNode<E> root = rootReference.get();
233      if (!range.contains(e) || root == null) {
234        return 0;
235      }
236      return root.count(comparator(), e);
237    } catch (ClassCastException e) {
238      return 0;
239    } catch (NullPointerException e) {
240      return 0;
241    }
242  }
243
244  @Override
245  public int add(@Nullable E element, int occurrences) {
246    checkNonnegative(occurrences, "occurrences");
247    if (occurrences == 0) {
248      return count(element);
249    }
250    checkArgument(range.contains(element));
251    AvlNode<E> root = rootReference.get();
252    if (root == null) {
253      comparator().compare(element, element);
254      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
255      successor(header, newRoot, header);
256      rootReference.checkAndSet(root, newRoot);
257      return 0;
258    }
259    int[] result = new int[1]; // used as a mutable int reference to hold result
260    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
261    rootReference.checkAndSet(root, newRoot);
262    return result[0];
263  }
264
265  @Override
266  public int remove(@Nullable Object element, int occurrences) {
267    checkNonnegative(occurrences, "occurrences");
268    if (occurrences == 0) {
269      return count(element);
270    }
271    AvlNode<E> root = rootReference.get();
272    int[] result = new int[1]; // used as a mutable int reference to hold result
273    AvlNode<E> newRoot;
274    try {
275      @SuppressWarnings("unchecked")
276      E e = (E) element;
277      if (!range.contains(e) || root == null) {
278        return 0;
279      }
280      newRoot = root.remove(comparator(), e, occurrences, result);
281    } catch (ClassCastException e) {
282      return 0;
283    } catch (NullPointerException e) {
284      return 0;
285    }
286    rootReference.checkAndSet(root, newRoot);
287    return result[0];
288  }
289
290  @Override
291  public int setCount(@Nullable E element, int count) {
292    checkNonnegative(count, "count");
293    if (!range.contains(element)) {
294      checkArgument(count == 0);
295      return 0;
296    }
297
298    AvlNode<E> root = rootReference.get();
299    if (root == null) {
300      if (count > 0) {
301        add(element, count);
302      }
303      return 0;
304    }
305    int[] result = new int[1]; // used as a mutable int reference to hold result
306    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
307    rootReference.checkAndSet(root, newRoot);
308    return result[0];
309  }
310
311  @Override
312  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
313    checkNonnegative(newCount, "newCount");
314    checkNonnegative(oldCount, "oldCount");
315    checkArgument(range.contains(element));
316
317    AvlNode<E> root = rootReference.get();
318    if (root == null) {
319      if (oldCount == 0) {
320        if (newCount > 0) {
321          add(element, newCount);
322        }
323        return true;
324      } else {
325        return false;
326      }
327    }
328    int[] result = new int[1]; // used as a mutable int reference to hold result
329    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
330    rootReference.checkAndSet(root, newRoot);
331    return result[0] == oldCount;
332  }
333
334  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
335    return new Multisets.AbstractEntry<E>() {
336      @Override
337      public E getElement() {
338        return baseEntry.getElement();
339      }
340
341      @Override
342      public int getCount() {
343        int result = baseEntry.getCount();
344        if (result == 0) {
345          return count(getElement());
346        } else {
347          return result;
348        }
349      }
350    };
351  }
352
353  /**
354   * Returns the first node in the tree that is in range.
355   */
356  @Nullable private AvlNode<E> firstNode() {
357    AvlNode<E> root = rootReference.get();
358    if (root == null) {
359      return null;
360    }
361    AvlNode<E> node;
362    if (range.hasLowerBound()) {
363      E endpoint = range.getLowerEndpoint();
364      node = rootReference.get().ceiling(comparator(), endpoint);
365      if (node == null) {
366        return null;
367      }
368      if (range.getLowerBoundType() == BoundType.OPEN
369          && comparator().compare(endpoint, node.getElement()) == 0) {
370        node = node.succ;
371      }
372    } else {
373      node = header.succ;
374    }
375    return (node == header || !range.contains(node.getElement())) ? null : node;
376  }
377
378  @Nullable private AvlNode<E> lastNode() {
379    AvlNode<E> root = rootReference.get();
380    if (root == null) {
381      return null;
382    }
383    AvlNode<E> node;
384    if (range.hasUpperBound()) {
385      E endpoint = range.getUpperEndpoint();
386      node = rootReference.get().floor(comparator(), endpoint);
387      if (node == null) {
388        return null;
389      }
390      if (range.getUpperBoundType() == BoundType.OPEN
391          && comparator().compare(endpoint, node.getElement()) == 0) {
392        node = node.pred;
393      }
394    } else {
395      node = header.pred;
396    }
397    return (node == header || !range.contains(node.getElement())) ? null : node;
398  }
399
400  @Override
401  Iterator<Entry<E>> entryIterator() {
402    return new Iterator<Entry<E>>() {
403      AvlNode<E> current = firstNode();
404      Entry<E> prevEntry;
405
406      @Override
407      public boolean hasNext() {
408        if (current == null) {
409          return false;
410        } else if (range.tooHigh(current.getElement())) {
411          current = null;
412          return false;
413        } else {
414          return true;
415        }
416      }
417
418      @Override
419      public Entry<E> next() {
420        if (!hasNext()) {
421          throw new NoSuchElementException();
422        }
423        Entry<E> result = wrapEntry(current);
424        prevEntry = result;
425        if (current.succ == header) {
426          current = null;
427        } else {
428          current = current.succ;
429        }
430        return result;
431      }
432
433      @Override
434      public void remove() {
435        checkRemove(prevEntry != null);
436        setCount(prevEntry.getElement(), 0);
437        prevEntry = null;
438      }
439    };
440  }
441
442  @Override
443  Iterator<Entry<E>> descendingEntryIterator() {
444    return new Iterator<Entry<E>>() {
445      AvlNode<E> current = lastNode();
446      Entry<E> prevEntry = null;
447
448      @Override
449      public boolean hasNext() {
450        if (current == null) {
451          return false;
452        } else if (range.tooLow(current.getElement())) {
453          current = null;
454          return false;
455        } else {
456          return true;
457        }
458      }
459
460      @Override
461      public Entry<E> next() {
462        if (!hasNext()) {
463          throw new NoSuchElementException();
464        }
465        Entry<E> result = wrapEntry(current);
466        prevEntry = result;
467        if (current.pred == header) {
468          current = null;
469        } else {
470          current = current.pred;
471        }
472        return result;
473      }
474
475      @Override
476      public void remove() {
477        checkRemove(prevEntry != null);
478        setCount(prevEntry.getElement(), 0);
479        prevEntry = null;
480      }
481    };
482  }
483
484  @Override
485  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
486    return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.upTo(
487        comparator(),
488        upperBound,
489        boundType)), header);
490  }
491
492  @Override
493  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
494    return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.downTo(
495        comparator(),
496        lowerBound,
497        boundType)), header);
498  }
499
500  static int distinctElements(@Nullable AvlNode<?> node) {
501    return (node == null) ? 0 : node.distinctElements;
502  }
503
504  private static final class Reference<T> {
505    @Nullable private T value;
506
507    @Nullable public T get() {
508      return value;
509    }
510
511    public void checkAndSet(@Nullable T expected, T newValue) {
512      if (value != expected) {
513        throw new ConcurrentModificationException();
514      }
515      value = newValue;
516    }
517  }
518
519  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
520    @Nullable private final E elem;
521
522    // elemCount is 0 iff this node has been deleted.
523    private int elemCount;
524
525    private int distinctElements;
526    private long totalCount;
527    private int height;
528    private AvlNode<E> left;
529    private AvlNode<E> right;
530    private AvlNode<E> pred;
531    private AvlNode<E> succ;
532
533    AvlNode(@Nullable E elem, int elemCount) {
534      checkArgument(elemCount > 0);
535      this.elem = elem;
536      this.elemCount = elemCount;
537      this.totalCount = elemCount;
538      this.distinctElements = 1;
539      this.height = 1;
540      this.left = null;
541      this.right = null;
542    }
543
544    public int count(Comparator<? super E> comparator, E e) {
545      int cmp = comparator.compare(e, elem);
546      if (cmp < 0) {
547        return (left == null) ? 0 : left.count(comparator, e);
548      } else if (cmp > 0) {
549        return (right == null) ? 0 : right.count(comparator, e);
550      } else {
551        return elemCount;
552      }
553    }
554
555    private AvlNode<E> addRightChild(E e, int count) {
556      right = new AvlNode<E>(e, count);
557      successor(this, right, succ);
558      height = Math.max(2, height);
559      distinctElements++;
560      totalCount += count;
561      return this;
562    }
563
564    private AvlNode<E> addLeftChild(E e, int count) {
565      left = new AvlNode<E>(e, count);
566      successor(pred, left, this);
567      height = Math.max(2, height);
568      distinctElements++;
569      totalCount += count;
570      return this;
571    }
572
573    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
574      /*
575       * It speeds things up considerably to unconditionally add count to totalCount here,
576       * but that destroys failure atomicity in the case of count overflow. =(
577       */
578      int cmp = comparator.compare(e, elem);
579      if (cmp < 0) {
580        AvlNode<E> initLeft = left;
581        if (initLeft == null) {
582          result[0] = 0;
583          return addLeftChild(e, count);
584        }
585        int initHeight = initLeft.height;
586
587        left = initLeft.add(comparator, e, count, result);
588        if (result[0] == 0) {
589          distinctElements++;
590        }
591        this.totalCount += count;
592        return (left.height == initHeight) ? this : rebalance();
593      } else if (cmp > 0) {
594        AvlNode<E> initRight = right;
595        if (initRight == null) {
596          result[0] = 0;
597          return addRightChild(e, count);
598        }
599        int initHeight = initRight.height;
600
601        right = initRight.add(comparator, e, count, result);
602        if (result[0] == 0) {
603          distinctElements++;
604        }
605        this.totalCount += count;
606        return (right.height == initHeight) ? this : rebalance();
607      }
608
609      // adding count to me!  No rebalance possible.
610      result[0] = elemCount;
611      long resultCount = (long) elemCount + count;
612      checkArgument(resultCount <= Integer.MAX_VALUE);
613      this.elemCount += count;
614      this.totalCount += count;
615      return this;
616    }
617
618    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
619      int cmp = comparator.compare(e, elem);
620      if (cmp < 0) {
621        AvlNode<E> initLeft = left;
622        if (initLeft == null) {
623          result[0] = 0;
624          return this;
625        }
626
627        left = initLeft.remove(comparator, e, count, result);
628
629        if (result[0] > 0) {
630          if (count >= result[0]) {
631            this.distinctElements--;
632            this.totalCount -= result[0];
633          } else {
634            this.totalCount -= count;
635          }
636        }
637        return (result[0] == 0) ? this : rebalance();
638      } else if (cmp > 0) {
639        AvlNode<E> initRight = right;
640        if (initRight == null) {
641          result[0] = 0;
642          return this;
643        }
644
645        right = initRight.remove(comparator, e, count, result);
646
647        if (result[0] > 0) {
648          if (count >= result[0]) {
649            this.distinctElements--;
650            this.totalCount -= result[0];
651          } else {
652            this.totalCount -= count;
653          }
654        }
655        return rebalance();
656      }
657
658      // removing count from me!
659      result[0] = elemCount;
660      if (count >= elemCount) {
661        return deleteMe();
662      } else {
663        this.elemCount -= count;
664        this.totalCount -= count;
665        return this;
666      }
667    }
668
669    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
670      int cmp = comparator.compare(e, elem);
671      if (cmp < 0) {
672        AvlNode<E> initLeft = left;
673        if (initLeft == null) {
674          result[0] = 0;
675          return (count > 0) ? addLeftChild(e, count) : this;
676        }
677
678        left = initLeft.setCount(comparator, e, count, result);
679
680        if (count == 0 && result[0] != 0) {
681          this.distinctElements--;
682        } else if (count > 0 && result[0] == 0) {
683          this.distinctElements++;
684        }
685
686        this.totalCount += count - result[0];
687        return rebalance();
688      } else if (cmp > 0) {
689        AvlNode<E> initRight = right;
690        if (initRight == null) {
691          result[0] = 0;
692          return (count > 0) ? addRightChild(e, count) : this;
693        }
694
695        right = initRight.setCount(comparator, e, count, result);
696
697        if (count == 0 && result[0] != 0) {
698          this.distinctElements--;
699        } else if (count > 0 && result[0] == 0) {
700          this.distinctElements++;
701        }
702
703        this.totalCount += count - result[0];
704        return rebalance();
705      }
706
707      // setting my count
708      result[0] = elemCount;
709      if (count == 0) {
710        return deleteMe();
711      }
712      this.totalCount += count - elemCount;
713      this.elemCount = count;
714      return this;
715    }
716
717    AvlNode<E> setCount(
718        Comparator<? super E> comparator,
719        @Nullable E e,
720        int expectedCount,
721        int newCount,
722        int[] result) {
723      int cmp = comparator.compare(e, elem);
724      if (cmp < 0) {
725        AvlNode<E> initLeft = left;
726        if (initLeft == null) {
727          result[0] = 0;
728          if (expectedCount == 0 && newCount > 0) {
729            return addLeftChild(e, newCount);
730          }
731          return this;
732        }
733
734        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
735
736        if (result[0] == expectedCount) {
737          if (newCount == 0 && result[0] != 0) {
738            this.distinctElements--;
739          } else if (newCount > 0 && result[0] == 0) {
740            this.distinctElements++;
741          }
742          this.totalCount += newCount - result[0];
743        }
744        return rebalance();
745      } else if (cmp > 0) {
746        AvlNode<E> initRight = right;
747        if (initRight == null) {
748          result[0] = 0;
749          if (expectedCount == 0 && newCount > 0) {
750            return addRightChild(e, newCount);
751          }
752          return this;
753        }
754
755        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
756
757        if (result[0] == expectedCount) {
758          if (newCount == 0 && result[0] != 0) {
759            this.distinctElements--;
760          } else if (newCount > 0 && result[0] == 0) {
761            this.distinctElements++;
762          }
763          this.totalCount += newCount - result[0];
764        }
765        return rebalance();
766      }
767
768      // setting my count
769      result[0] = elemCount;
770      if (expectedCount == elemCount) {
771        if (newCount == 0) {
772          return deleteMe();
773        }
774        this.totalCount += newCount - elemCount;
775        this.elemCount = newCount;
776      }
777      return this;
778    }
779
780    private AvlNode<E> deleteMe() {
781      int oldElemCount = this.elemCount;
782      this.elemCount = 0;
783      successor(pred, succ);
784      if (left == null) {
785        return right;
786      } else if (right == null) {
787        return left;
788      } else if (left.height >= right.height) {
789        AvlNode<E> newTop = pred;
790        // newTop is the maximum node in my left subtree
791        newTop.left = left.removeMax(newTop);
792        newTop.right = right;
793        newTop.distinctElements = distinctElements - 1;
794        newTop.totalCount = totalCount - oldElemCount;
795        return newTop.rebalance();
796      } else {
797        AvlNode<E> newTop = succ;
798        newTop.right = right.removeMin(newTop);
799        newTop.left = left;
800        newTop.distinctElements = distinctElements - 1;
801        newTop.totalCount = totalCount - oldElemCount;
802        return newTop.rebalance();
803      }
804    }
805
806    // Removes the minimum node from this subtree to be reused elsewhere
807    private AvlNode<E> removeMin(AvlNode<E> node) {
808      if (left == null) {
809        return right;
810      } else {
811        left = left.removeMin(node);
812        distinctElements--;
813        totalCount -= node.elemCount;
814        return rebalance();
815      }
816    }
817
818    // Removes the maximum node from this subtree to be reused elsewhere
819    private AvlNode<E> removeMax(AvlNode<E> node) {
820      if (right == null) {
821        return left;
822      } else {
823        right = right.removeMax(node);
824        distinctElements--;
825        totalCount -= node.elemCount;
826        return rebalance();
827      }
828    }
829
830    private void recomputeMultiset() {
831      this.distinctElements = 1 + TreeMultiset.distinctElements(left)
832          + TreeMultiset.distinctElements(right);
833      this.totalCount = elemCount + totalCount(left) + totalCount(right);
834    }
835
836    private void recomputeHeight() {
837      this.height = 1 + Math.max(height(left), height(right));
838    }
839
840    private void recompute() {
841      recomputeMultiset();
842      recomputeHeight();
843    }
844
845    private AvlNode<E> rebalance() {
846      switch (balanceFactor()) {
847        case -2:
848          if (right.balanceFactor() > 0) {
849            right = right.rotateRight();
850          }
851          return rotateLeft();
852        case 2:
853          if (left.balanceFactor() < 0) {
854            left = left.rotateLeft();
855          }
856          return rotateRight();
857        default:
858          recomputeHeight();
859          return this;
860      }
861    }
862
863    private int balanceFactor() {
864      return height(left) - height(right);
865    }
866
867    private AvlNode<E> rotateLeft() {
868      checkState(right != null);
869      AvlNode<E> newTop = right;
870      this.right = newTop.left;
871      newTop.left = this;
872      newTop.totalCount = this.totalCount;
873      newTop.distinctElements = this.distinctElements;
874      this.recompute();
875      newTop.recomputeHeight();
876      return newTop;
877    }
878
879    private AvlNode<E> rotateRight() {
880      checkState(left != null);
881      AvlNode<E> newTop = left;
882      this.left = newTop.right;
883      newTop.right = this;
884      newTop.totalCount = this.totalCount;
885      newTop.distinctElements = this.distinctElements;
886      this.recompute();
887      newTop.recomputeHeight();
888      return newTop;
889    }
890
891    private static long totalCount(@Nullable AvlNode<?> node) {
892      return (node == null) ? 0 : node.totalCount;
893    }
894
895    private static int height(@Nullable AvlNode<?> node) {
896      return (node == null) ? 0 : node.height;
897    }
898
899    @Nullable private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
900      int cmp = comparator.compare(e, elem);
901      if (cmp < 0) {
902        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
903      } else if (cmp == 0) {
904        return this;
905      } else {
906        return (right == null) ? null : right.ceiling(comparator, e);
907      }
908    }
909
910    @Nullable private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
911      int cmp = comparator.compare(e, elem);
912      if (cmp > 0) {
913        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
914      } else if (cmp == 0) {
915        return this;
916      } else {
917        return (left == null) ? null : left.floor(comparator, e);
918      }
919    }
920
921    @Override
922    public E getElement() {
923      return elem;
924    }
925
926    @Override
927    public int getCount() {
928      return elemCount;
929    }
930
931    @Override
932    public String toString() {
933      return Multisets.immutableEntry(getElement(), getCount()).toString();
934    }
935  }
936
937  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
938    a.succ = b;
939    b.pred = a;
940  }
941
942  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
943    successor(a, b);
944    successor(b, c);
945  }
946
947  /*
948   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
949   * calls the comparator to compare the two keys. If that change is made,
950   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
951   */
952
953  /**
954   * @serialData the comparator, the number of distinct elements, the first element, its count, the
955   *             second element, its count, and so on
956   */
957  @GwtIncompatible("java.io.ObjectOutputStream")
958  private void writeObject(ObjectOutputStream stream) throws IOException {
959    stream.defaultWriteObject();
960    stream.writeObject(elementSet().comparator());
961    Serialization.writeMultiset(this, stream);
962  }
963
964  @GwtIncompatible("java.io.ObjectInputStream")
965  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
966    stream.defaultReadObject();
967    @SuppressWarnings("unchecked")
968    // reading data stored by writeObject
969    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
970    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
971    Serialization.getFieldSetter(TreeMultiset.class, "range").set(
972        this,
973        GeneralRange.all(comparator));
974    Serialization.getFieldSetter(TreeMultiset.class, "rootReference").set(
975        this,
976        new Reference<AvlNode<E>>());
977    AvlNode<E> header = new AvlNode<E>(null, 1);
978    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
979    successor(header, header);
980    Serialization.populateMultiset(this, stream);
981  }
982
983  @GwtIncompatible("not needed in emulated source") private static final long serialVersionUID = 1;
984}