001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.CollectPreconditions.checkRemove;
023
024import com.google.common.annotations.GwtCompatible;
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.base.MoreObjects;
027import com.google.common.primitives.Ints;
028
029import java.io.IOException;
030import java.io.ObjectInputStream;
031import java.io.ObjectOutputStream;
032import java.io.Serializable;
033import java.util.Comparator;
034import java.util.ConcurrentModificationException;
035import java.util.Iterator;
036import java.util.NoSuchElementException;
037
038import javax.annotation.Nullable;
039
040/**
041 * A multiset which maintains the ordering of its elements, according to either their natural order
042 * or an explicit {@link Comparator}. In all cases, this implementation uses
043 * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
044 * determine equivalence of instances.
045 *
046 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
047 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
048 * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
049 *
050 * <p>See the Guava User Guide article on <a href=
051 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">
052 * {@code Multiset}</a>.
053 *
054 * @author Louis Wasserman
055 * @author Jared Levy
056 * @since 2.0
057 */
058@GwtCompatible(emulated = true)
059public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
060
061  /**
062   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
063   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
064   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
065   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
066   * user attempts to add an element to the multiset that violates this constraint (for example,
067   * the user attempts to add a string element to a set whose elements are integers), the
068   * {@code add(Object)} call will throw a {@code ClassCastException}.
069   *
070   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
071   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
072   */
073  public static <E extends Comparable> TreeMultiset<E> create() {
074    return new TreeMultiset<E>(Ordering.natural());
075  }
076
077  /**
078   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
079   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
080   * {@code comparator.compare(e1,
081   * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in
082   * the multiset. If the user attempts to add an element to the multiset that violates this
083   * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}.
084   *
085   * @param comparator
086   *          the comparator that will be used to sort this multiset. A null value indicates that
087   *          the elements' <i>natural ordering</i> should be used.
088   */
089  @SuppressWarnings("unchecked")
090  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
091    return (comparator == null)
092        ? new TreeMultiset<E>((Comparator) Ordering.natural())
093        : new TreeMultiset<E>(comparator);
094  }
095
096  /**
097   * Creates an empty multiset containing the given initial elements, sorted according to the
098   * elements' natural order.
099   *
100   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
101   *
102   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
103   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
104   */
105  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
106    TreeMultiset<E> multiset = create();
107    Iterables.addAll(multiset, elements);
108    return multiset;
109  }
110
111  private final transient Reference<AvlNode<E>> rootReference;
112  private final transient GeneralRange<E> range;
113  private final transient AvlNode<E> header;
114
115  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
116    super(range.comparator());
117    this.rootReference = rootReference;
118    this.range = range;
119    this.header = endLink;
120  }
121
122  TreeMultiset(Comparator<? super E> comparator) {
123    super(comparator);
124    this.range = GeneralRange.all(comparator);
125    this.header = new AvlNode<E>(null, 1);
126    successor(header, header);
127    this.rootReference = new Reference<AvlNode<E>>();
128  }
129
130  /**
131   * A function which can be summed across a subtree.
132   */
133  private enum Aggregate {
134    SIZE {
135      @Override
136      int nodeAggregate(AvlNode<?> node) {
137        return node.elemCount;
138      }
139
140      @Override
141      long treeAggregate(@Nullable AvlNode<?> root) {
142        return (root == null) ? 0 : root.totalCount;
143      }
144    },
145    DISTINCT {
146      @Override
147      int nodeAggregate(AvlNode<?> node) {
148        return 1;
149      }
150
151      @Override
152      long treeAggregate(@Nullable AvlNode<?> root) {
153        return (root == null) ? 0 : root.distinctElements;
154      }
155    };
156
157    abstract int nodeAggregate(AvlNode<?> node);
158
159    abstract long treeAggregate(@Nullable AvlNode<?> root);
160  }
161
162  private long aggregateForEntries(Aggregate aggr) {
163    AvlNode<E> root = rootReference.get();
164    long total = aggr.treeAggregate(root);
165    if (range.hasLowerBound()) {
166      total -= aggregateBelowRange(aggr, root);
167    }
168    if (range.hasUpperBound()) {
169      total -= aggregateAboveRange(aggr, root);
170    }
171    return total;
172  }
173
174  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
175    if (node == null) {
176      return 0;
177    }
178    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
179    if (cmp < 0) {
180      return aggregateBelowRange(aggr, node.left);
181    } else if (cmp == 0) {
182      switch (range.getLowerBoundType()) {
183        case OPEN:
184          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
185        case CLOSED:
186          return aggr.treeAggregate(node.left);
187        default:
188          throw new AssertionError();
189      }
190    } else {
191      return aggr.treeAggregate(node.left)
192          + aggr.nodeAggregate(node)
193          + aggregateBelowRange(aggr, node.right);
194    }
195  }
196
197  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
198    if (node == null) {
199      return 0;
200    }
201    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
202    if (cmp > 0) {
203      return aggregateAboveRange(aggr, node.right);
204    } else if (cmp == 0) {
205      switch (range.getUpperBoundType()) {
206        case OPEN:
207          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
208        case CLOSED:
209          return aggr.treeAggregate(node.right);
210        default:
211          throw new AssertionError();
212      }
213    } else {
214      return aggr.treeAggregate(node.right)
215          + aggr.nodeAggregate(node)
216          + aggregateAboveRange(aggr, node.left);
217    }
218  }
219
220  @Override
221  public int size() {
222    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
223  }
224
225  @Override
226  int distinctElements() {
227    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
228  }
229
230  @Override
231  public int count(@Nullable Object element) {
232    try {
233      @SuppressWarnings("unchecked")
234      E e = (E) element;
235      AvlNode<E> root = rootReference.get();
236      if (!range.contains(e) || root == null) {
237        return 0;
238      }
239      return root.count(comparator(), e);
240    } catch (ClassCastException e) {
241      return 0;
242    } catch (NullPointerException e) {
243      return 0;
244    }
245  }
246
247  @Override
248  public int add(@Nullable E element, int occurrences) {
249    checkNonnegative(occurrences, "occurrences");
250    if (occurrences == 0) {
251      return count(element);
252    }
253    checkArgument(range.contains(element));
254    AvlNode<E> root = rootReference.get();
255    if (root == null) {
256      comparator().compare(element, element);
257      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
258      successor(header, newRoot, header);
259      rootReference.checkAndSet(root, newRoot);
260      return 0;
261    }
262    int[] result = new int[1]; // used as a mutable int reference to hold result
263    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
264    rootReference.checkAndSet(root, newRoot);
265    return result[0];
266  }
267
268  @Override
269  public int remove(@Nullable Object element, int occurrences) {
270    checkNonnegative(occurrences, "occurrences");
271    if (occurrences == 0) {
272      return count(element);
273    }
274    AvlNode<E> root = rootReference.get();
275    int[] result = new int[1]; // used as a mutable int reference to hold result
276    AvlNode<E> newRoot;
277    try {
278      @SuppressWarnings("unchecked")
279      E e = (E) element;
280      if (!range.contains(e) || root == null) {
281        return 0;
282      }
283      newRoot = root.remove(comparator(), e, occurrences, result);
284    } catch (ClassCastException e) {
285      return 0;
286    } catch (NullPointerException e) {
287      return 0;
288    }
289    rootReference.checkAndSet(root, newRoot);
290    return result[0];
291  }
292
293  @Override
294  public int setCount(@Nullable E element, int count) {
295    checkNonnegative(count, "count");
296    if (!range.contains(element)) {
297      checkArgument(count == 0);
298      return 0;
299    }
300
301    AvlNode<E> root = rootReference.get();
302    if (root == null) {
303      if (count > 0) {
304        add(element, count);
305      }
306      return 0;
307    }
308    int[] result = new int[1]; // used as a mutable int reference to hold result
309    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
310    rootReference.checkAndSet(root, newRoot);
311    return result[0];
312  }
313
314  @Override
315  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
316    checkNonnegative(newCount, "newCount");
317    checkNonnegative(oldCount, "oldCount");
318    checkArgument(range.contains(element));
319
320    AvlNode<E> root = rootReference.get();
321    if (root == null) {
322      if (oldCount == 0) {
323        if (newCount > 0) {
324          add(element, newCount);
325        }
326        return true;
327      } else {
328        return false;
329      }
330    }
331    int[] result = new int[1]; // used as a mutable int reference to hold result
332    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
333    rootReference.checkAndSet(root, newRoot);
334    return result[0] == oldCount;
335  }
336
337  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
338    return new Multisets.AbstractEntry<E>() {
339      @Override
340      public E getElement() {
341        return baseEntry.getElement();
342      }
343
344      @Override
345      public int getCount() {
346        int result = baseEntry.getCount();
347        if (result == 0) {
348          return count(getElement());
349        } else {
350          return result;
351        }
352      }
353    };
354  }
355
356  /**
357   * Returns the first node in the tree that is in range.
358   */
359  @Nullable
360  private AvlNode<E> firstNode() {
361    AvlNode<E> root = rootReference.get();
362    if (root == null) {
363      return null;
364    }
365    AvlNode<E> node;
366    if (range.hasLowerBound()) {
367      E endpoint = range.getLowerEndpoint();
368      node = rootReference.get().ceiling(comparator(), endpoint);
369      if (node == null) {
370        return null;
371      }
372      if (range.getLowerBoundType() == BoundType.OPEN
373          && comparator().compare(endpoint, node.getElement()) == 0) {
374        node = node.succ;
375      }
376    } else {
377      node = header.succ;
378    }
379    return (node == header || !range.contains(node.getElement())) ? null : node;
380  }
381
382  @Nullable
383  private AvlNode<E> lastNode() {
384    AvlNode<E> root = rootReference.get();
385    if (root == null) {
386      return null;
387    }
388    AvlNode<E> node;
389    if (range.hasUpperBound()) {
390      E endpoint = range.getUpperEndpoint();
391      node = rootReference.get().floor(comparator(), endpoint);
392      if (node == null) {
393        return null;
394      }
395      if (range.getUpperBoundType() == BoundType.OPEN
396          && comparator().compare(endpoint, node.getElement()) == 0) {
397        node = node.pred;
398      }
399    } else {
400      node = header.pred;
401    }
402    return (node == header || !range.contains(node.getElement())) ? null : node;
403  }
404
405  @Override
406  Iterator<Entry<E>> entryIterator() {
407    return new Iterator<Entry<E>>() {
408      AvlNode<E> current = firstNode();
409      Entry<E> prevEntry;
410
411      @Override
412      public boolean hasNext() {
413        if (current == null) {
414          return false;
415        } else if (range.tooHigh(current.getElement())) {
416          current = null;
417          return false;
418        } else {
419          return true;
420        }
421      }
422
423      @Override
424      public Entry<E> next() {
425        if (!hasNext()) {
426          throw new NoSuchElementException();
427        }
428        Entry<E> result = wrapEntry(current);
429        prevEntry = result;
430        if (current.succ == header) {
431          current = null;
432        } else {
433          current = current.succ;
434        }
435        return result;
436      }
437
438      @Override
439      public void remove() {
440        checkRemove(prevEntry != null);
441        setCount(prevEntry.getElement(), 0);
442        prevEntry = null;
443      }
444    };
445  }
446
447  @Override
448  Iterator<Entry<E>> descendingEntryIterator() {
449    return new Iterator<Entry<E>>() {
450      AvlNode<E> current = lastNode();
451      Entry<E> prevEntry = null;
452
453      @Override
454      public boolean hasNext() {
455        if (current == null) {
456          return false;
457        } else if (range.tooLow(current.getElement())) {
458          current = null;
459          return false;
460        } else {
461          return true;
462        }
463      }
464
465      @Override
466      public Entry<E> next() {
467        if (!hasNext()) {
468          throw new NoSuchElementException();
469        }
470        Entry<E> result = wrapEntry(current);
471        prevEntry = result;
472        if (current.pred == header) {
473          current = null;
474        } else {
475          current = current.pred;
476        }
477        return result;
478      }
479
480      @Override
481      public void remove() {
482        checkRemove(prevEntry != null);
483        setCount(prevEntry.getElement(), 0);
484        prevEntry = null;
485      }
486    };
487  }
488
489  @Override
490  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
491    return new TreeMultiset<E>(
492        rootReference,
493        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
494        header);
495  }
496
497  @Override
498  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
499    return new TreeMultiset<E>(
500        rootReference,
501        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
502        header);
503  }
504
505  static int distinctElements(@Nullable AvlNode<?> node) {
506    return (node == null) ? 0 : node.distinctElements;
507  }
508
509  private static final class Reference<T> {
510    @Nullable private T value;
511
512    @Nullable
513    public T get() {
514      return value;
515    }
516
517    public void checkAndSet(@Nullable T expected, T newValue) {
518      if (value != expected) {
519        throw new ConcurrentModificationException();
520      }
521      value = newValue;
522    }
523  }
524
525  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
526    @Nullable private final E elem;
527
528    // elemCount is 0 iff this node has been deleted.
529    private int elemCount;
530
531    private int distinctElements;
532    private long totalCount;
533    private int height;
534    private AvlNode<E> left;
535    private AvlNode<E> right;
536    private AvlNode<E> pred;
537    private AvlNode<E> succ;
538
539    AvlNode(@Nullable E elem, int elemCount) {
540      checkArgument(elemCount > 0);
541      this.elem = elem;
542      this.elemCount = elemCount;
543      this.totalCount = elemCount;
544      this.distinctElements = 1;
545      this.height = 1;
546      this.left = null;
547      this.right = null;
548    }
549
550    public int count(Comparator<? super E> comparator, E e) {
551      int cmp = comparator.compare(e, elem);
552      if (cmp < 0) {
553        return (left == null) ? 0 : left.count(comparator, e);
554      } else if (cmp > 0) {
555        return (right == null) ? 0 : right.count(comparator, e);
556      } else {
557        return elemCount;
558      }
559    }
560
561    private AvlNode<E> addRightChild(E e, int count) {
562      right = new AvlNode<E>(e, count);
563      successor(this, right, succ);
564      height = Math.max(2, height);
565      distinctElements++;
566      totalCount += count;
567      return this;
568    }
569
570    private AvlNode<E> addLeftChild(E e, int count) {
571      left = new AvlNode<E>(e, count);
572      successor(pred, left, this);
573      height = Math.max(2, height);
574      distinctElements++;
575      totalCount += count;
576      return this;
577    }
578
579    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
580      /*
581       * It speeds things up considerably to unconditionally add count to totalCount here,
582       * but that destroys failure atomicity in the case of count overflow. =(
583       */
584      int cmp = comparator.compare(e, elem);
585      if (cmp < 0) {
586        AvlNode<E> initLeft = left;
587        if (initLeft == null) {
588          result[0] = 0;
589          return addLeftChild(e, count);
590        }
591        int initHeight = initLeft.height;
592
593        left = initLeft.add(comparator, e, count, result);
594        if (result[0] == 0) {
595          distinctElements++;
596        }
597        this.totalCount += count;
598        return (left.height == initHeight) ? this : rebalance();
599      } else if (cmp > 0) {
600        AvlNode<E> initRight = right;
601        if (initRight == null) {
602          result[0] = 0;
603          return addRightChild(e, count);
604        }
605        int initHeight = initRight.height;
606
607        right = initRight.add(comparator, e, count, result);
608        if (result[0] == 0) {
609          distinctElements++;
610        }
611        this.totalCount += count;
612        return (right.height == initHeight) ? this : rebalance();
613      }
614
615      // adding count to me!  No rebalance possible.
616      result[0] = elemCount;
617      long resultCount = (long) elemCount + count;
618      checkArgument(resultCount <= Integer.MAX_VALUE);
619      this.elemCount += count;
620      this.totalCount += count;
621      return this;
622    }
623
624    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
625      int cmp = comparator.compare(e, elem);
626      if (cmp < 0) {
627        AvlNode<E> initLeft = left;
628        if (initLeft == null) {
629          result[0] = 0;
630          return this;
631        }
632
633        left = initLeft.remove(comparator, e, count, result);
634
635        if (result[0] > 0) {
636          if (count >= result[0]) {
637            this.distinctElements--;
638            this.totalCount -= result[0];
639          } else {
640            this.totalCount -= count;
641          }
642        }
643        return (result[0] == 0) ? this : rebalance();
644      } else if (cmp > 0) {
645        AvlNode<E> initRight = right;
646        if (initRight == null) {
647          result[0] = 0;
648          return this;
649        }
650
651        right = initRight.remove(comparator, e, count, result);
652
653        if (result[0] > 0) {
654          if (count >= result[0]) {
655            this.distinctElements--;
656            this.totalCount -= result[0];
657          } else {
658            this.totalCount -= count;
659          }
660        }
661        return rebalance();
662      }
663
664      // removing count from me!
665      result[0] = elemCount;
666      if (count >= elemCount) {
667        return deleteMe();
668      } else {
669        this.elemCount -= count;
670        this.totalCount -= count;
671        return this;
672      }
673    }
674
675    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
676      int cmp = comparator.compare(e, elem);
677      if (cmp < 0) {
678        AvlNode<E> initLeft = left;
679        if (initLeft == null) {
680          result[0] = 0;
681          return (count > 0) ? addLeftChild(e, count) : this;
682        }
683
684        left = initLeft.setCount(comparator, e, count, result);
685
686        if (count == 0 && result[0] != 0) {
687          this.distinctElements--;
688        } else if (count > 0 && result[0] == 0) {
689          this.distinctElements++;
690        }
691
692        this.totalCount += count - result[0];
693        return rebalance();
694      } else if (cmp > 0) {
695        AvlNode<E> initRight = right;
696        if (initRight == null) {
697          result[0] = 0;
698          return (count > 0) ? addRightChild(e, count) : this;
699        }
700
701        right = initRight.setCount(comparator, e, count, result);
702
703        if (count == 0 && result[0] != 0) {
704          this.distinctElements--;
705        } else if (count > 0 && result[0] == 0) {
706          this.distinctElements++;
707        }
708
709        this.totalCount += count - result[0];
710        return rebalance();
711      }
712
713      // setting my count
714      result[0] = elemCount;
715      if (count == 0) {
716        return deleteMe();
717      }
718      this.totalCount += count - elemCount;
719      this.elemCount = count;
720      return this;
721    }
722
723    AvlNode<E> setCount(
724        Comparator<? super E> comparator,
725        @Nullable E e,
726        int expectedCount,
727        int newCount,
728        int[] result) {
729      int cmp = comparator.compare(e, elem);
730      if (cmp < 0) {
731        AvlNode<E> initLeft = left;
732        if (initLeft == null) {
733          result[0] = 0;
734          if (expectedCount == 0 && newCount > 0) {
735            return addLeftChild(e, newCount);
736          }
737          return this;
738        }
739
740        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
741
742        if (result[0] == expectedCount) {
743          if (newCount == 0 && result[0] != 0) {
744            this.distinctElements--;
745          } else if (newCount > 0 && result[0] == 0) {
746            this.distinctElements++;
747          }
748          this.totalCount += newCount - result[0];
749        }
750        return rebalance();
751      } else if (cmp > 0) {
752        AvlNode<E> initRight = right;
753        if (initRight == null) {
754          result[0] = 0;
755          if (expectedCount == 0 && newCount > 0) {
756            return addRightChild(e, newCount);
757          }
758          return this;
759        }
760
761        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
762
763        if (result[0] == expectedCount) {
764          if (newCount == 0 && result[0] != 0) {
765            this.distinctElements--;
766          } else if (newCount > 0 && result[0] == 0) {
767            this.distinctElements++;
768          }
769          this.totalCount += newCount - result[0];
770        }
771        return rebalance();
772      }
773
774      // setting my count
775      result[0] = elemCount;
776      if (expectedCount == elemCount) {
777        if (newCount == 0) {
778          return deleteMe();
779        }
780        this.totalCount += newCount - elemCount;
781        this.elemCount = newCount;
782      }
783      return this;
784    }
785
786    private AvlNode<E> deleteMe() {
787      int oldElemCount = this.elemCount;
788      this.elemCount = 0;
789      successor(pred, succ);
790      if (left == null) {
791        return right;
792      } else if (right == null) {
793        return left;
794      } else if (left.height >= right.height) {
795        AvlNode<E> newTop = pred;
796        // newTop is the maximum node in my left subtree
797        newTop.left = left.removeMax(newTop);
798        newTop.right = right;
799        newTop.distinctElements = distinctElements - 1;
800        newTop.totalCount = totalCount - oldElemCount;
801        return newTop.rebalance();
802      } else {
803        AvlNode<E> newTop = succ;
804        newTop.right = right.removeMin(newTop);
805        newTop.left = left;
806        newTop.distinctElements = distinctElements - 1;
807        newTop.totalCount = totalCount - oldElemCount;
808        return newTop.rebalance();
809      }
810    }
811
812    // Removes the minimum node from this subtree to be reused elsewhere
813    private AvlNode<E> removeMin(AvlNode<E> node) {
814      if (left == null) {
815        return right;
816      } else {
817        left = left.removeMin(node);
818        distinctElements--;
819        totalCount -= node.elemCount;
820        return rebalance();
821      }
822    }
823
824    // Removes the maximum node from this subtree to be reused elsewhere
825    private AvlNode<E> removeMax(AvlNode<E> node) {
826      if (right == null) {
827        return left;
828      } else {
829        right = right.removeMax(node);
830        distinctElements--;
831        totalCount -= node.elemCount;
832        return rebalance();
833      }
834    }
835
836    private void recomputeMultiset() {
837      this.distinctElements =
838          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
839      this.totalCount = elemCount + totalCount(left) + totalCount(right);
840    }
841
842    private void recomputeHeight() {
843      this.height = 1 + Math.max(height(left), height(right));
844    }
845
846    private void recompute() {
847      recomputeMultiset();
848      recomputeHeight();
849    }
850
851    private AvlNode<E> rebalance() {
852      switch (balanceFactor()) {
853        case -2:
854          if (right.balanceFactor() > 0) {
855            right = right.rotateRight();
856          }
857          return rotateLeft();
858        case 2:
859          if (left.balanceFactor() < 0) {
860            left = left.rotateLeft();
861          }
862          return rotateRight();
863        default:
864          recomputeHeight();
865          return this;
866      }
867    }
868
869    private int balanceFactor() {
870      return height(left) - height(right);
871    }
872
873    private AvlNode<E> rotateLeft() {
874      checkState(right != null);
875      AvlNode<E> newTop = right;
876      this.right = newTop.left;
877      newTop.left = this;
878      newTop.totalCount = this.totalCount;
879      newTop.distinctElements = this.distinctElements;
880      this.recompute();
881      newTop.recomputeHeight();
882      return newTop;
883    }
884
885    private AvlNode<E> rotateRight() {
886      checkState(left != null);
887      AvlNode<E> newTop = left;
888      this.left = newTop.right;
889      newTop.right = this;
890      newTop.totalCount = this.totalCount;
891      newTop.distinctElements = this.distinctElements;
892      this.recompute();
893      newTop.recomputeHeight();
894      return newTop;
895    }
896
897    private static long totalCount(@Nullable AvlNode<?> node) {
898      return (node == null) ? 0 : node.totalCount;
899    }
900
901    private static int height(@Nullable AvlNode<?> node) {
902      return (node == null) ? 0 : node.height;
903    }
904
905    @Nullable
906    private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
907      int cmp = comparator.compare(e, elem);
908      if (cmp < 0) {
909        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
910      } else if (cmp == 0) {
911        return this;
912      } else {
913        return (right == null) ? null : right.ceiling(comparator, e);
914      }
915    }
916
917    @Nullable
918    private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
919      int cmp = comparator.compare(e, elem);
920      if (cmp > 0) {
921        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
922      } else if (cmp == 0) {
923        return this;
924      } else {
925        return (left == null) ? null : left.floor(comparator, e);
926      }
927    }
928
929    @Override
930    public E getElement() {
931      return elem;
932    }
933
934    @Override
935    public int getCount() {
936      return elemCount;
937    }
938
939    @Override
940    public String toString() {
941      return Multisets.immutableEntry(getElement(), getCount()).toString();
942    }
943  }
944
945  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
946    a.succ = b;
947    b.pred = a;
948  }
949
950  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
951    successor(a, b);
952    successor(b, c);
953  }
954
955  /*
956   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
957   * calls the comparator to compare the two keys. If that change is made,
958   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
959   */
960
961  /**
962   * @serialData the comparator, the number of distinct elements, the first element, its count, the
963   *             second element, its count, and so on
964   */
965  @GwtIncompatible("java.io.ObjectOutputStream")
966  private void writeObject(ObjectOutputStream stream) throws IOException {
967    stream.defaultWriteObject();
968    stream.writeObject(elementSet().comparator());
969    Serialization.writeMultiset(this, stream);
970  }
971
972  @GwtIncompatible("java.io.ObjectInputStream")
973  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
974    stream.defaultReadObject();
975    @SuppressWarnings("unchecked")
976    // reading data stored by writeObject
977    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
978    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
979    Serialization.getFieldSetter(TreeMultiset.class, "range")
980        .set(this, GeneralRange.all(comparator));
981    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
982        .set(this, new Reference<AvlNode<E>>());
983    AvlNode<E> header = new AvlNode<E>(null, 1);
984    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
985    successor(header, header);
986    Serialization.populateMultiset(this, stream);
987  }
988
989  @GwtIncompatible("not needed in emulated source")
990  private static final long serialVersionUID = 1;
991}