001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021
022import com.google.common.annotations.GwtCompatible;
023import com.google.common.annotations.GwtIncompatible;
024import com.google.common.base.Objects;
025import com.google.common.primitives.Ints;
026
027import java.io.IOException;
028import java.io.ObjectInputStream;
029import java.io.ObjectOutputStream;
030import java.io.Serializable;
031import java.util.Comparator;
032import java.util.ConcurrentModificationException;
033import java.util.Iterator;
034import java.util.NoSuchElementException;
035
036import javax.annotation.Nullable;
037
038/**
039 * A multiset which maintains the ordering of its elements, according to either their natural order
040 * or an explicit {@link Comparator}. In all cases, this implementation uses
041 * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
042 * determine equivalence of instances.
043 *
044 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
045 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
046 * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
047 *
048 * <p>See the Guava User Guide article on <a href=
049 * "http://code.google.com/p/guava-libraries/wiki/NewCollectionTypesExplained#Multiset">
050 * {@code Multiset}</a>.
051 *
052 * @author Louis Wasserman
053 * @author Jared Levy
054 * @since 2.0 (imported from Google Collections Library)
055 */
056@GwtCompatible(emulated = true)
057public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
058
059  /**
060   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
061   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
062   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
063   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
064   * user attempts to add an element to the multiset that violates this constraint (for example,
065   * the user attempts to add a string element to a set whose elements are integers), the
066   * {@code add(Object)} call will throw a {@code ClassCastException}.
067   *
068   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
069   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
070   */
071  public static <E extends Comparable> TreeMultiset<E> create() {
072    return new TreeMultiset<E>(Ordering.natural());
073  }
074
075  /**
076   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
077   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
078   * {@code comparator.compare(e1,
079   * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in
080   * the multiset. If the user attempts to add an element to the multiset that violates this
081   * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}.
082   *
083   * @param comparator
084   *          the comparator that will be used to sort this multiset. A null value indicates that
085   *          the elements' <i>natural ordering</i> should be used.
086   */
087  @SuppressWarnings("unchecked")
088  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
089    return (comparator == null)
090        ? new TreeMultiset<E>((Comparator) Ordering.natural())
091        : new TreeMultiset<E>(comparator);
092  }
093
094  /**
095   * Creates an empty multiset containing the given initial elements, sorted according to the
096   * elements' natural order.
097   *
098   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
099   *
100   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
101   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
102   */
103  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
104    TreeMultiset<E> multiset = create();
105    Iterables.addAll(multiset, elements);
106    return multiset;
107  }
108
109  private final transient Reference<AvlNode<E>> rootReference;
110  private final transient GeneralRange<E> range;
111  private final transient AvlNode<E> header;
112
113  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
114    super(range.comparator());
115    this.rootReference = rootReference;
116    this.range = range;
117    this.header = endLink;
118  }
119
120  TreeMultiset(Comparator<? super E> comparator) {
121    super(comparator);
122    this.range = GeneralRange.all(comparator);
123    this.header = new AvlNode<E>(null, 1);
124    successor(header, header);
125    this.rootReference = new Reference<AvlNode<E>>();
126  }
127
128  /**
129   * A function which can be summed across a subtree.
130   */
131  private enum Aggregate {
132    SIZE {
133      @Override
134      int nodeAggregate(AvlNode<?> node) {
135        return node.elemCount;
136      }
137
138      @Override
139      long treeAggregate(@Nullable AvlNode<?> root) {
140        return (root == null) ? 0 : root.totalCount;
141      }
142    },
143    DISTINCT {
144      @Override
145      int nodeAggregate(AvlNode<?> node) {
146        return 1;
147      }
148
149      @Override
150      long treeAggregate(@Nullable AvlNode<?> root) {
151        return (root == null) ? 0 : root.distinctElements;
152      }
153    };
154    abstract int nodeAggregate(AvlNode<?> node);
155
156    abstract long treeAggregate(@Nullable AvlNode<?> root);
157  }
158
159  private long aggregateForEntries(Aggregate aggr) {
160    AvlNode<E> root = rootReference.get();
161    long total = aggr.treeAggregate(root);
162    if (range.hasLowerBound()) {
163      total -= aggregateBelowRange(aggr, root);
164    }
165    if (range.hasUpperBound()) {
166      total -= aggregateAboveRange(aggr, root);
167    }
168    return total;
169  }
170
171  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
172    if (node == null) {
173      return 0;
174    }
175    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
176    if (cmp < 0) {
177      return aggregateBelowRange(aggr, node.left);
178    } else if (cmp == 0) {
179      switch (range.getLowerBoundType()) {
180        case OPEN:
181          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
182        case CLOSED:
183          return aggr.treeAggregate(node.left);
184        default:
185          throw new AssertionError();
186      }
187    } else {
188      return aggr.treeAggregate(node.left) + aggr.nodeAggregate(node)
189          + aggregateBelowRange(aggr, node.right);
190    }
191  }
192
193  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
194    if (node == null) {
195      return 0;
196    }
197    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
198    if (cmp > 0) {
199      return aggregateAboveRange(aggr, node.right);
200    } else if (cmp == 0) {
201      switch (range.getUpperBoundType()) {
202        case OPEN:
203          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
204        case CLOSED:
205          return aggr.treeAggregate(node.right);
206        default:
207          throw new AssertionError();
208      }
209    } else {
210      return aggr.treeAggregate(node.right) + aggr.nodeAggregate(node)
211          + aggregateAboveRange(aggr, node.left);
212    }
213  }
214
215  @Override
216  public int size() {
217    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
218  }
219
220  @Override
221  int distinctElements() {
222    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
223  }
224
225  @Override
226  public int count(@Nullable Object element) {
227    try {
228      @SuppressWarnings("unchecked")
229      E e = (E) element;
230      AvlNode<E> root = rootReference.get();
231      if (!range.contains(e) || root == null) {
232        return 0;
233      }
234      return root.count(comparator(), e);
235    } catch (ClassCastException e) {
236      return 0;
237    } catch (NullPointerException e) {
238      return 0;
239    }
240  }
241
242  @Override
243  public int add(@Nullable E element, int occurrences) {
244    checkArgument(occurrences >= 0, "occurrences must be >= 0 but was %s", occurrences);
245    if (occurrences == 0) {
246      return count(element);
247    }
248    checkArgument(range.contains(element));
249    AvlNode<E> root = rootReference.get();
250    if (root == null) {
251      comparator().compare(element, element);
252      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
253      successor(header, newRoot, header);
254      rootReference.checkAndSet(root, newRoot);
255      return 0;
256    }
257    int[] result = new int[1]; // used as a mutable int reference to hold result
258    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
259    rootReference.checkAndSet(root, newRoot);
260    return result[0];
261  }
262
263  @Override
264  public int remove(@Nullable Object element, int occurrences) {
265    checkArgument(occurrences >= 0, "occurrences must be >= 0 but was %s", occurrences);
266    if (occurrences == 0) {
267      return count(element);
268    }
269    AvlNode<E> root = rootReference.get();
270    int[] result = new int[1]; // used as a mutable int reference to hold result
271    AvlNode<E> newRoot;
272    try {
273      @SuppressWarnings("unchecked")
274      E e = (E) element;
275      if (!range.contains(e) || root == null) {
276        return 0;
277      }
278      newRoot = root.remove(comparator(), e, occurrences, result);
279    } catch (ClassCastException e) {
280      return 0;
281    } catch (NullPointerException e) {
282      return 0;
283    }
284    rootReference.checkAndSet(root, newRoot);
285    return result[0];
286  }
287
288  @Override
289  public int setCount(@Nullable E element, int count) {
290    checkArgument(count >= 0);
291    if (!range.contains(element)) {
292      checkArgument(count == 0);
293      return 0;
294    }
295
296    AvlNode<E> root = rootReference.get();
297    if (root == null) {
298      if (count > 0) {
299        add(element, count);
300      }
301      return 0;
302    }
303    int[] result = new int[1]; // used as a mutable int reference to hold result
304    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
305    rootReference.checkAndSet(root, newRoot);
306    return result[0];
307  }
308
309  @Override
310  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
311    checkArgument(newCount >= 0);
312    checkArgument(oldCount >= 0);
313    checkArgument(range.contains(element));
314
315    AvlNode<E> root = rootReference.get();
316    if (root == null) {
317      if (oldCount == 0) {
318        if (newCount > 0) {
319          add(element, newCount);
320        }
321        return true;
322      } else {
323        return false;
324      }
325    }
326    int[] result = new int[1]; // used as a mutable int reference to hold result
327    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
328    rootReference.checkAndSet(root, newRoot);
329    return result[0] == oldCount;
330  }
331
332  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
333    return new Multisets.AbstractEntry<E>() {
334      @Override
335      public E getElement() {
336        return baseEntry.getElement();
337      }
338
339      @Override
340      public int getCount() {
341        int result = baseEntry.getCount();
342        if (result == 0) {
343          return count(getElement());
344        } else {
345          return result;
346        }
347      }
348    };
349  }
350
351  /**
352   * Returns the first node in the tree that is in range.
353   */
354  @Nullable private AvlNode<E> firstNode() {
355    AvlNode<E> root = rootReference.get();
356    if (root == null) {
357      return null;
358    }
359    AvlNode<E> node;
360    if (range.hasLowerBound()) {
361      E endpoint = range.getLowerEndpoint();
362      node = rootReference.get().ceiling(comparator(), endpoint);
363      if (node == null) {
364        return null;
365      }
366      if (range.getLowerBoundType() == BoundType.OPEN
367          && comparator().compare(endpoint, node.getElement()) == 0) {
368        node = node.succ;
369      }
370    } else {
371      node = header.succ;
372    }
373    return (node == header || !range.contains(node.getElement())) ? null : node;
374  }
375
376  @Nullable private AvlNode<E> lastNode() {
377    AvlNode<E> root = rootReference.get();
378    if (root == null) {
379      return null;
380    }
381    AvlNode<E> node;
382    if (range.hasUpperBound()) {
383      E endpoint = range.getUpperEndpoint();
384      node = rootReference.get().floor(comparator(), endpoint);
385      if (node == null) {
386        return null;
387      }
388      if (range.getUpperBoundType() == BoundType.OPEN
389          && comparator().compare(endpoint, node.getElement()) == 0) {
390        node = node.pred;
391      }
392    } else {
393      node = header.pred;
394    }
395    return (node == header || !range.contains(node.getElement())) ? null : node;
396  }
397
398  @Override
399  Iterator<Entry<E>> entryIterator() {
400    return new Iterator<Entry<E>>() {
401      AvlNode<E> current = firstNode();
402      Entry<E> prevEntry;
403
404      @Override
405      public boolean hasNext() {
406        if (current == null) {
407          return false;
408        } else if (range.tooHigh(current.getElement())) {
409          current = null;
410          return false;
411        } else {
412          return true;
413        }
414      }
415
416      @Override
417      public Entry<E> next() {
418        if (!hasNext()) {
419          throw new NoSuchElementException();
420        }
421        Entry<E> result = wrapEntry(current);
422        prevEntry = result;
423        if (current.succ == header) {
424          current = null;
425        } else {
426          current = current.succ;
427        }
428        return result;
429      }
430
431      @Override
432      public void remove() {
433        checkState(prevEntry != null);
434        setCount(prevEntry.getElement(), 0);
435        prevEntry = null;
436      }
437    };
438  }
439
440  @Override
441  Iterator<Entry<E>> descendingEntryIterator() {
442    return new Iterator<Entry<E>>() {
443      AvlNode<E> current = lastNode();
444      Entry<E> prevEntry = null;
445
446      @Override
447      public boolean hasNext() {
448        if (current == null) {
449          return false;
450        } else if (range.tooLow(current.getElement())) {
451          current = null;
452          return false;
453        } else {
454          return true;
455        }
456      }
457
458      @Override
459      public Entry<E> next() {
460        if (!hasNext()) {
461          throw new NoSuchElementException();
462        }
463        Entry<E> result = wrapEntry(current);
464        prevEntry = result;
465        if (current.pred == header) {
466          current = null;
467        } else {
468          current = current.pred;
469        }
470        return result;
471      }
472
473      @Override
474      public void remove() {
475        checkState(prevEntry != null);
476        setCount(prevEntry.getElement(), 0);
477        prevEntry = null;
478      }
479    };
480  }
481
482  @Override
483  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
484    return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.upTo(
485        comparator(),
486        upperBound,
487        boundType)), header);
488  }
489
490  @Override
491  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
492    return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.downTo(
493        comparator(),
494        lowerBound,
495        boundType)), header);
496  }
497
498  static int distinctElements(@Nullable AvlNode<?> node) {
499    return (node == null) ? 0 : node.distinctElements;
500  }
501
502  private static final class Reference<T> {
503    @Nullable private T value;
504
505    @Nullable public T get() {
506      return value;
507    }
508
509    public void checkAndSet(@Nullable T expected, T newValue) {
510      if (value != expected) {
511        throw new ConcurrentModificationException();
512      }
513      value = newValue;
514    }
515  }
516
517  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
518    @Nullable private final E elem;
519
520    // elemCount is 0 iff this node has been deleted.
521    private int elemCount;
522
523    private int distinctElements;
524    private long totalCount;
525    private int height;
526    private AvlNode<E> left;
527    private AvlNode<E> right;
528    private AvlNode<E> pred;
529    private AvlNode<E> succ;
530
531    AvlNode(@Nullable E elem, int elemCount) {
532      checkArgument(elemCount > 0);
533      this.elem = elem;
534      this.elemCount = elemCount;
535      this.totalCount = elemCount;
536      this.distinctElements = 1;
537      this.height = 1;
538      this.left = null;
539      this.right = null;
540    }
541
542    public int count(Comparator<? super E> comparator, E e) {
543      int cmp = comparator.compare(e, elem);
544      if (cmp < 0) {
545        return (left == null) ? 0 : left.count(comparator, e);
546      } else if (cmp > 0) {
547        return (right == null) ? 0 : right.count(comparator, e);
548      } else {
549        return elemCount;
550      }
551    }
552
553    private AvlNode<E> addRightChild(E e, int count) {
554      right = new AvlNode<E>(e, count);
555      successor(this, right, succ);
556      height = Math.max(2, height);
557      distinctElements++;
558      totalCount += count;
559      return this;
560    }
561
562    private AvlNode<E> addLeftChild(E e, int count) {
563      left = new AvlNode<E>(e, count);
564      successor(pred, left, this);
565      height = Math.max(2, height);
566      distinctElements++;
567      totalCount += count;
568      return this;
569    }
570
571    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
572      /*
573       * It speeds things up considerably to unconditionally add count to totalCount here,
574       * but that destroys failure atomicity in the case of count overflow. =(
575       */
576      int cmp = comparator.compare(e, elem);
577      if (cmp < 0) {
578        AvlNode<E> initLeft = left;
579        if (initLeft == null) {
580          result[0] = 0;
581          return addLeftChild(e, count);
582        }
583        int initHeight = initLeft.height;
584
585        left = initLeft.add(comparator, e, count, result);
586        if (result[0] == 0) {
587          distinctElements++;
588        }
589        this.totalCount += count;
590        return (left.height == initHeight) ? this : rebalance();
591      } else if (cmp > 0) {
592        AvlNode<E> initRight = right;
593        if (initRight == null) {
594          result[0] = 0;
595          return addRightChild(e, count);
596        }
597        int initHeight = initRight.height;
598
599        right = initRight.add(comparator, e, count, result);
600        if (result[0] == 0) {
601          distinctElements++;
602        }
603        this.totalCount += count;
604        return (right.height == initHeight) ? this : rebalance();
605      }
606
607      // adding count to me!  No rebalance possible.
608      result[0] = elemCount;
609      long resultCount = (long) elemCount + count;
610      checkArgument(resultCount <= Integer.MAX_VALUE);
611      this.elemCount += count;
612      this.totalCount += count;
613      return this;
614    }
615
616    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
617      int cmp = comparator.compare(e, elem);
618      if (cmp < 0) {
619        AvlNode<E> initLeft = left;
620        if (initLeft == null) {
621          result[0] = 0;
622          return this;
623        }
624
625        left = initLeft.remove(comparator, e, count, result);
626
627        if (result[0] > 0) {
628          if (count >= result[0]) {
629            this.distinctElements--;
630            this.totalCount -= result[0];
631          } else {
632            this.totalCount -= count;
633          }
634        }
635        return (result[0] == 0) ? this : rebalance();
636      } else if (cmp > 0) {
637        AvlNode<E> initRight = right;
638        if (initRight == null) {
639          result[0] = 0;
640          return this;
641        }
642
643        right = initRight.remove(comparator, e, count, result);
644
645        if (result[0] > 0) {
646          if (count >= result[0]) {
647            this.distinctElements--;
648            this.totalCount -= result[0];
649          } else {
650            this.totalCount -= count;
651          }
652        }
653        return rebalance();
654      }
655
656      // removing count from me!
657      result[0] = elemCount;
658      if (count >= elemCount) {
659        return deleteMe();
660      } else {
661        this.elemCount -= count;
662        this.totalCount -= count;
663        return this;
664      }
665    }
666
667    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
668      int cmp = comparator.compare(e, elem);
669      if (cmp < 0) {
670        AvlNode<E> initLeft = left;
671        if (initLeft == null) {
672          result[0] = 0;
673          return (count > 0) ? addLeftChild(e, count) : this;
674        }
675
676        left = initLeft.setCount(comparator, e, count, result);
677
678        if (count == 0 && result[0] != 0) {
679          this.distinctElements--;
680        } else if (count > 0 && result[0] == 0) {
681          this.distinctElements++;
682        }
683
684        this.totalCount += count - result[0];
685        return rebalance();
686      } else if (cmp > 0) {
687        AvlNode<E> initRight = right;
688        if (initRight == null) {
689          result[0] = 0;
690          return (count > 0) ? addRightChild(e, count) : this;
691        }
692
693        right = initRight.setCount(comparator, e, count, result);
694
695        if (count == 0 && result[0] != 0) {
696          this.distinctElements--;
697        } else if (count > 0 && result[0] == 0) {
698          this.distinctElements++;
699        }
700
701        this.totalCount += count - result[0];
702        return rebalance();
703      }
704
705      // setting my count
706      result[0] = elemCount;
707      if (count == 0) {
708        return deleteMe();
709      }
710      this.totalCount += count - elemCount;
711      this.elemCount = count;
712      return this;
713    }
714
715    AvlNode<E> setCount(
716        Comparator<? super E> comparator,
717        @Nullable E e,
718        int expectedCount,
719        int newCount,
720        int[] result) {
721      int cmp = comparator.compare(e, elem);
722      if (cmp < 0) {
723        AvlNode<E> initLeft = left;
724        if (initLeft == null) {
725          result[0] = 0;
726          if (expectedCount == 0 && newCount > 0) {
727            return addLeftChild(e, newCount);
728          }
729          return this;
730        }
731
732        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
733
734        if (result[0] == expectedCount) {
735          if (newCount == 0 && result[0] != 0) {
736            this.distinctElements--;
737          } else if (newCount > 0 && result[0] == 0) {
738            this.distinctElements++;
739          }
740          this.totalCount += newCount - result[0];
741        }
742        return rebalance();
743      } else if (cmp > 0) {
744        AvlNode<E> initRight = right;
745        if (initRight == null) {
746          result[0] = 0;
747          if (expectedCount == 0 && newCount > 0) {
748            return addRightChild(e, newCount);
749          }
750          return this;
751        }
752
753        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
754
755        if (result[0] == expectedCount) {
756          if (newCount == 0 && result[0] != 0) {
757            this.distinctElements--;
758          } else if (newCount > 0 && result[0] == 0) {
759            this.distinctElements++;
760          }
761          this.totalCount += newCount - result[0];
762        }
763        return rebalance();
764      }
765
766      // setting my count
767      result[0] = elemCount;
768      if (expectedCount == elemCount) {
769        if (newCount == 0) {
770          return deleteMe();
771        }
772        this.totalCount += newCount - elemCount;
773        this.elemCount = newCount;
774      }
775      return this;
776    }
777
778    private AvlNode<E> deleteMe() {
779      int oldElemCount = this.elemCount;
780      this.elemCount = 0;
781      successor(pred, succ);
782      if (left == null) {
783        return right;
784      } else if (right == null) {
785        return left;
786      } else if (left.height >= right.height) {
787        AvlNode<E> newTop = pred;
788        // newTop is the maximum node in my left subtree
789        newTop.left = left.removeMax(newTop);
790        newTop.right = right;
791        newTop.distinctElements = distinctElements - 1;
792        newTop.totalCount = totalCount - oldElemCount;
793        return newTop.rebalance();
794      } else {
795        AvlNode<E> newTop = succ;
796        newTop.right = right.removeMin(newTop);
797        newTop.left = left;
798        newTop.distinctElements = distinctElements - 1;
799        newTop.totalCount = totalCount - oldElemCount;
800        return newTop.rebalance();
801      }
802    }
803
804    // Removes the minimum node from this subtree to be reused elsewhere
805    private AvlNode<E> removeMin(AvlNode<E> node) {
806      if (left == null) {
807        return right;
808      } else {
809        left = left.removeMin(node);
810        distinctElements--;
811        totalCount -= node.elemCount;
812        return rebalance();
813      }
814    }
815
816    // Removes the maximum node from this subtree to be reused elsewhere
817    private AvlNode<E> removeMax(AvlNode<E> node) {
818      if (right == null) {
819        return left;
820      } else {
821        right = right.removeMax(node);
822        distinctElements--;
823        totalCount -= node.elemCount;
824        return rebalance();
825      }
826    }
827
828    private void recomputeMultiset() {
829      this.distinctElements = 1 + TreeMultiset.distinctElements(left)
830          + TreeMultiset.distinctElements(right);
831      this.totalCount = elemCount + totalCount(left) + totalCount(right);
832    }
833
834    private void recomputeHeight() {
835      this.height = 1 + Math.max(height(left), height(right));
836    }
837
838    private void recompute() {
839      recomputeMultiset();
840      recomputeHeight();
841    }
842
843    private AvlNode<E> rebalance() {
844      switch (balanceFactor()) {
845        case -2:
846          if (right.balanceFactor() > 0) {
847            right = right.rotateRight();
848          }
849          return rotateLeft();
850        case 2:
851          if (left.balanceFactor() < 0) {
852            left = left.rotateLeft();
853          }
854          return rotateRight();
855        default:
856          recomputeHeight();
857          return this;
858      }
859    }
860
861    private int balanceFactor() {
862      return height(left) - height(right);
863    }
864
865    private AvlNode<E> rotateLeft() {
866      checkState(right != null);
867      AvlNode<E> newTop = right;
868      this.right = newTop.left;
869      newTop.left = this;
870      newTop.totalCount = this.totalCount;
871      newTop.distinctElements = this.distinctElements;
872      this.recompute();
873      newTop.recomputeHeight();
874      return newTop;
875    }
876
877    private AvlNode<E> rotateRight() {
878      checkState(left != null);
879      AvlNode<E> newTop = left;
880      this.left = newTop.right;
881      newTop.right = this;
882      newTop.totalCount = this.totalCount;
883      newTop.distinctElements = this.distinctElements;
884      this.recompute();
885      newTop.recomputeHeight();
886      return newTop;
887    }
888
889    private static long totalCount(@Nullable AvlNode<?> node) {
890      return (node == null) ? 0 : node.totalCount;
891    }
892
893    private static int height(@Nullable AvlNode<?> node) {
894      return (node == null) ? 0 : node.height;
895    }
896
897    @Nullable private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
898      int cmp = comparator.compare(e, elem);
899      if (cmp < 0) {
900        return (left == null) ? this : Objects.firstNonNull(left.ceiling(comparator, e), this);
901      } else if (cmp == 0) {
902        return this;
903      } else {
904        return (right == null) ? null : right.ceiling(comparator, e);
905      }
906    }
907
908    @Nullable private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
909      int cmp = comparator.compare(e, elem);
910      if (cmp > 0) {
911        return (right == null) ? this : Objects.firstNonNull(right.floor(comparator, e), this);
912      } else if (cmp == 0) {
913        return this;
914      } else {
915        return (left == null) ? null : left.floor(comparator, e);
916      }
917    }
918
919    @Override
920    public E getElement() {
921      return elem;
922    }
923
924    @Override
925    public int getCount() {
926      return elemCount;
927    }
928
929    @Override
930    public String toString() {
931      return Multisets.immutableEntry(getElement(), getCount()).toString();
932    }
933  }
934
935  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
936    a.succ = b;
937    b.pred = a;
938  }
939
940  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
941    successor(a, b);
942    successor(b, c);
943  }
944
945  /*
946   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
947   * calls the comparator to compare the two keys. If that change is made,
948   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
949   */
950
951  /**
952   * @serialData the comparator, the number of distinct elements, the first element, its count, the
953   *             second element, its count, and so on
954   */
955  @GwtIncompatible("java.io.ObjectOutputStream")
956  private void writeObject(ObjectOutputStream stream) throws IOException {
957    stream.defaultWriteObject();
958    stream.writeObject(elementSet().comparator());
959    Serialization.writeMultiset(this, stream);
960  }
961
962  @GwtIncompatible("java.io.ObjectInputStream")
963  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
964    stream.defaultReadObject();
965    @SuppressWarnings("unchecked")
966    // reading data stored by writeObject
967    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
968    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
969    Serialization.getFieldSetter(TreeMultiset.class, "range").set(
970        this,
971        GeneralRange.all(comparator));
972    Serialization.getFieldSetter(TreeMultiset.class, "rootReference").set(
973        this,
974        new Reference<AvlNode<E>>());
975    AvlNode<E> header = new AvlNode<E>(null, 1);
976    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
977    successor(header, header);
978    Serialization.populateMultiset(this, stream);
979  }
980
981  @GwtIncompatible("not needed in emulated source") private static final long serialVersionUID = 1;
982}