001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkRemove;
022
023import com.google.common.annotations.GwtCompatible;
024import com.google.common.annotations.GwtIncompatible;
025import com.google.common.base.Objects;
026import com.google.common.primitives.Ints;
027
028import java.io.IOException;
029import java.io.ObjectInputStream;
030import java.io.ObjectOutputStream;
031import java.io.Serializable;
032import java.util.Comparator;
033import java.util.ConcurrentModificationException;
034import java.util.Iterator;
035import java.util.NoSuchElementException;
036
037import javax.annotation.Nullable;
038
039/**
040 * A multiset which maintains the ordering of its elements, according to either their natural order
041 * or an explicit {@link Comparator}. In all cases, this implementation uses
042 * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
043 * determine equivalence of instances.
044 *
045 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
046 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
047 * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
048 *
049 * <p>See the Guava User Guide article on <a href=
050 * "http://code.google.com/p/guava-libraries/wiki/NewCollectionTypesExplained#Multiset">
051 * {@code Multiset}</a>.
052 *
053 * @author Louis Wasserman
054 * @author Jared Levy
055 * @since 2.0 (imported from Google Collections Library)
056 */
057@GwtCompatible(emulated = true)
058public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
059
060  /**
061   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
062   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
063   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
064   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
065   * user attempts to add an element to the multiset that violates this constraint (for example,
066   * the user attempts to add a string element to a set whose elements are integers), the
067   * {@code add(Object)} call will throw a {@code ClassCastException}.
068   *
069   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
070   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
071   */
072  public static <E extends Comparable> TreeMultiset<E> create() {
073    return new TreeMultiset<E>(Ordering.natural());
074  }
075
076  /**
077   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
078   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
079   * {@code comparator.compare(e1,
080   * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in
081   * the multiset. If the user attempts to add an element to the multiset that violates this
082   * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}.
083   *
084   * @param comparator
085   *          the comparator that will be used to sort this multiset. A null value indicates that
086   *          the elements' <i>natural ordering</i> should be used.
087   */
088  @SuppressWarnings("unchecked")
089  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
090    return (comparator == null)
091        ? new TreeMultiset<E>((Comparator) Ordering.natural())
092        : new TreeMultiset<E>(comparator);
093  }
094
095  /**
096   * Creates an empty multiset containing the given initial elements, sorted according to the
097   * elements' natural order.
098   *
099   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
100   *
101   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
102   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
103   */
104  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
105    TreeMultiset<E> multiset = create();
106    Iterables.addAll(multiset, elements);
107    return multiset;
108  }
109
110  private final transient Reference<AvlNode<E>> rootReference;
111  private final transient GeneralRange<E> range;
112  private final transient AvlNode<E> header;
113
114  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
115    super(range.comparator());
116    this.rootReference = rootReference;
117    this.range = range;
118    this.header = endLink;
119  }
120
121  TreeMultiset(Comparator<? super E> comparator) {
122    super(comparator);
123    this.range = GeneralRange.all(comparator);
124    this.header = new AvlNode<E>(null, 1);
125    successor(header, header);
126    this.rootReference = new Reference<AvlNode<E>>();
127  }
128
129  /**
130   * A function which can be summed across a subtree.
131   */
132  private enum Aggregate {
133    SIZE {
134      @Override
135      int nodeAggregate(AvlNode<?> node) {
136        return node.elemCount;
137      }
138
139      @Override
140      long treeAggregate(@Nullable AvlNode<?> root) {
141        return (root == null) ? 0 : root.totalCount;
142      }
143    },
144    DISTINCT {
145      @Override
146      int nodeAggregate(AvlNode<?> node) {
147        return 1;
148      }
149
150      @Override
151      long treeAggregate(@Nullable AvlNode<?> root) {
152        return (root == null) ? 0 : root.distinctElements;
153      }
154    };
155    abstract int nodeAggregate(AvlNode<?> node);
156
157    abstract long treeAggregate(@Nullable AvlNode<?> root);
158  }
159
160  private long aggregateForEntries(Aggregate aggr) {
161    AvlNode<E> root = rootReference.get();
162    long total = aggr.treeAggregate(root);
163    if (range.hasLowerBound()) {
164      total -= aggregateBelowRange(aggr, root);
165    }
166    if (range.hasUpperBound()) {
167      total -= aggregateAboveRange(aggr, root);
168    }
169    return total;
170  }
171
172  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
173    if (node == null) {
174      return 0;
175    }
176    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
177    if (cmp < 0) {
178      return aggregateBelowRange(aggr, node.left);
179    } else if (cmp == 0) {
180      switch (range.getLowerBoundType()) {
181        case OPEN:
182          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
183        case CLOSED:
184          return aggr.treeAggregate(node.left);
185        default:
186          throw new AssertionError();
187      }
188    } else {
189      return aggr.treeAggregate(node.left) + aggr.nodeAggregate(node)
190          + aggregateBelowRange(aggr, node.right);
191    }
192  }
193
194  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
195    if (node == null) {
196      return 0;
197    }
198    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
199    if (cmp > 0) {
200      return aggregateAboveRange(aggr, node.right);
201    } else if (cmp == 0) {
202      switch (range.getUpperBoundType()) {
203        case OPEN:
204          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
205        case CLOSED:
206          return aggr.treeAggregate(node.right);
207        default:
208          throw new AssertionError();
209      }
210    } else {
211      return aggr.treeAggregate(node.right) + aggr.nodeAggregate(node)
212          + aggregateAboveRange(aggr, node.left);
213    }
214  }
215
216  @Override
217  public int size() {
218    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
219  }
220
221  @Override
222  int distinctElements() {
223    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
224  }
225
226  @Override
227  public int count(@Nullable Object element) {
228    try {
229      @SuppressWarnings("unchecked")
230      E e = (E) element;
231      AvlNode<E> root = rootReference.get();
232      if (!range.contains(e) || root == null) {
233        return 0;
234      }
235      return root.count(comparator(), e);
236    } catch (ClassCastException e) {
237      return 0;
238    } catch (NullPointerException e) {
239      return 0;
240    }
241  }
242
243  @Override
244  public int add(@Nullable E element, int occurrences) {
245    checkArgument(occurrences >= 0, "occurrences must be >= 0 but was %s", occurrences);
246    if (occurrences == 0) {
247      return count(element);
248    }
249    checkArgument(range.contains(element));
250    AvlNode<E> root = rootReference.get();
251    if (root == null) {
252      comparator().compare(element, element);
253      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
254      successor(header, newRoot, header);
255      rootReference.checkAndSet(root, newRoot);
256      return 0;
257    }
258    int[] result = new int[1]; // used as a mutable int reference to hold result
259    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
260    rootReference.checkAndSet(root, newRoot);
261    return result[0];
262  }
263
264  @Override
265  public int remove(@Nullable Object element, int occurrences) {
266    checkArgument(occurrences >= 0, "occurrences must be >= 0 but was %s", occurrences);
267    if (occurrences == 0) {
268      return count(element);
269    }
270    AvlNode<E> root = rootReference.get();
271    int[] result = new int[1]; // used as a mutable int reference to hold result
272    AvlNode<E> newRoot;
273    try {
274      @SuppressWarnings("unchecked")
275      E e = (E) element;
276      if (!range.contains(e) || root == null) {
277        return 0;
278      }
279      newRoot = root.remove(comparator(), e, occurrences, result);
280    } catch (ClassCastException e) {
281      return 0;
282    } catch (NullPointerException e) {
283      return 0;
284    }
285    rootReference.checkAndSet(root, newRoot);
286    return result[0];
287  }
288
289  @Override
290  public int setCount(@Nullable E element, int count) {
291    checkArgument(count >= 0);
292    if (!range.contains(element)) {
293      checkArgument(count == 0);
294      return 0;
295    }
296
297    AvlNode<E> root = rootReference.get();
298    if (root == null) {
299      if (count > 0) {
300        add(element, count);
301      }
302      return 0;
303    }
304    int[] result = new int[1]; // used as a mutable int reference to hold result
305    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
306    rootReference.checkAndSet(root, newRoot);
307    return result[0];
308  }
309
310  @Override
311  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
312    checkArgument(newCount >= 0);
313    checkArgument(oldCount >= 0);
314    checkArgument(range.contains(element));
315
316    AvlNode<E> root = rootReference.get();
317    if (root == null) {
318      if (oldCount == 0) {
319        if (newCount > 0) {
320          add(element, newCount);
321        }
322        return true;
323      } else {
324        return false;
325      }
326    }
327    int[] result = new int[1]; // used as a mutable int reference to hold result
328    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
329    rootReference.checkAndSet(root, newRoot);
330    return result[0] == oldCount;
331  }
332
333  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
334    return new Multisets.AbstractEntry<E>() {
335      @Override
336      public E getElement() {
337        return baseEntry.getElement();
338      }
339
340      @Override
341      public int getCount() {
342        int result = baseEntry.getCount();
343        if (result == 0) {
344          return count(getElement());
345        } else {
346          return result;
347        }
348      }
349    };
350  }
351
352  /**
353   * Returns the first node in the tree that is in range.
354   */
355  @Nullable private AvlNode<E> firstNode() {
356    AvlNode<E> root = rootReference.get();
357    if (root == null) {
358      return null;
359    }
360    AvlNode<E> node;
361    if (range.hasLowerBound()) {
362      E endpoint = range.getLowerEndpoint();
363      node = rootReference.get().ceiling(comparator(), endpoint);
364      if (node == null) {
365        return null;
366      }
367      if (range.getLowerBoundType() == BoundType.OPEN
368          && comparator().compare(endpoint, node.getElement()) == 0) {
369        node = node.succ;
370      }
371    } else {
372      node = header.succ;
373    }
374    return (node == header || !range.contains(node.getElement())) ? null : node;
375  }
376
377  @Nullable private AvlNode<E> lastNode() {
378    AvlNode<E> root = rootReference.get();
379    if (root == null) {
380      return null;
381    }
382    AvlNode<E> node;
383    if (range.hasUpperBound()) {
384      E endpoint = range.getUpperEndpoint();
385      node = rootReference.get().floor(comparator(), endpoint);
386      if (node == null) {
387        return null;
388      }
389      if (range.getUpperBoundType() == BoundType.OPEN
390          && comparator().compare(endpoint, node.getElement()) == 0) {
391        node = node.pred;
392      }
393    } else {
394      node = header.pred;
395    }
396    return (node == header || !range.contains(node.getElement())) ? null : node;
397  }
398
399  @Override
400  Iterator<Entry<E>> entryIterator() {
401    return new Iterator<Entry<E>>() {
402      AvlNode<E> current = firstNode();
403      Entry<E> prevEntry;
404
405      @Override
406      public boolean hasNext() {
407        if (current == null) {
408          return false;
409        } else if (range.tooHigh(current.getElement())) {
410          current = null;
411          return false;
412        } else {
413          return true;
414        }
415      }
416
417      @Override
418      public Entry<E> next() {
419        if (!hasNext()) {
420          throw new NoSuchElementException();
421        }
422        Entry<E> result = wrapEntry(current);
423        prevEntry = result;
424        if (current.succ == header) {
425          current = null;
426        } else {
427          current = current.succ;
428        }
429        return result;
430      }
431
432      @Override
433      public void remove() {
434        checkRemove(prevEntry != null);
435        setCount(prevEntry.getElement(), 0);
436        prevEntry = null;
437      }
438    };
439  }
440
441  @Override
442  Iterator<Entry<E>> descendingEntryIterator() {
443    return new Iterator<Entry<E>>() {
444      AvlNode<E> current = lastNode();
445      Entry<E> prevEntry = null;
446
447      @Override
448      public boolean hasNext() {
449        if (current == null) {
450          return false;
451        } else if (range.tooLow(current.getElement())) {
452          current = null;
453          return false;
454        } else {
455          return true;
456        }
457      }
458
459      @Override
460      public Entry<E> next() {
461        if (!hasNext()) {
462          throw new NoSuchElementException();
463        }
464        Entry<E> result = wrapEntry(current);
465        prevEntry = result;
466        if (current.pred == header) {
467          current = null;
468        } else {
469          current = current.pred;
470        }
471        return result;
472      }
473
474      @Override
475      public void remove() {
476        checkRemove(prevEntry != null);
477        setCount(prevEntry.getElement(), 0);
478        prevEntry = null;
479      }
480    };
481  }
482
483  @Override
484  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
485    return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.upTo(
486        comparator(),
487        upperBound,
488        boundType)), header);
489  }
490
491  @Override
492  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
493    return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.downTo(
494        comparator(),
495        lowerBound,
496        boundType)), header);
497  }
498
499  static int distinctElements(@Nullable AvlNode<?> node) {
500    return (node == null) ? 0 : node.distinctElements;
501  }
502
503  private static final class Reference<T> {
504    @Nullable private T value;
505
506    @Nullable public T get() {
507      return value;
508    }
509
510    public void checkAndSet(@Nullable T expected, T newValue) {
511      if (value != expected) {
512        throw new ConcurrentModificationException();
513      }
514      value = newValue;
515    }
516  }
517
518  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
519    @Nullable private final E elem;
520
521    // elemCount is 0 iff this node has been deleted.
522    private int elemCount;
523
524    private int distinctElements;
525    private long totalCount;
526    private int height;
527    private AvlNode<E> left;
528    private AvlNode<E> right;
529    private AvlNode<E> pred;
530    private AvlNode<E> succ;
531
532    AvlNode(@Nullable E elem, int elemCount) {
533      checkArgument(elemCount > 0);
534      this.elem = elem;
535      this.elemCount = elemCount;
536      this.totalCount = elemCount;
537      this.distinctElements = 1;
538      this.height = 1;
539      this.left = null;
540      this.right = null;
541    }
542
543    public int count(Comparator<? super E> comparator, E e) {
544      int cmp = comparator.compare(e, elem);
545      if (cmp < 0) {
546        return (left == null) ? 0 : left.count(comparator, e);
547      } else if (cmp > 0) {
548        return (right == null) ? 0 : right.count(comparator, e);
549      } else {
550        return elemCount;
551      }
552    }
553
554    private AvlNode<E> addRightChild(E e, int count) {
555      right = new AvlNode<E>(e, count);
556      successor(this, right, succ);
557      height = Math.max(2, height);
558      distinctElements++;
559      totalCount += count;
560      return this;
561    }
562
563    private AvlNode<E> addLeftChild(E e, int count) {
564      left = new AvlNode<E>(e, count);
565      successor(pred, left, this);
566      height = Math.max(2, height);
567      distinctElements++;
568      totalCount += count;
569      return this;
570    }
571
572    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
573      /*
574       * It speeds things up considerably to unconditionally add count to totalCount here,
575       * but that destroys failure atomicity in the case of count overflow. =(
576       */
577      int cmp = comparator.compare(e, elem);
578      if (cmp < 0) {
579        AvlNode<E> initLeft = left;
580        if (initLeft == null) {
581          result[0] = 0;
582          return addLeftChild(e, count);
583        }
584        int initHeight = initLeft.height;
585
586        left = initLeft.add(comparator, e, count, result);
587        if (result[0] == 0) {
588          distinctElements++;
589        }
590        this.totalCount += count;
591        return (left.height == initHeight) ? this : rebalance();
592      } else if (cmp > 0) {
593        AvlNode<E> initRight = right;
594        if (initRight == null) {
595          result[0] = 0;
596          return addRightChild(e, count);
597        }
598        int initHeight = initRight.height;
599
600        right = initRight.add(comparator, e, count, result);
601        if (result[0] == 0) {
602          distinctElements++;
603        }
604        this.totalCount += count;
605        return (right.height == initHeight) ? this : rebalance();
606      }
607
608      // adding count to me!  No rebalance possible.
609      result[0] = elemCount;
610      long resultCount = (long) elemCount + count;
611      checkArgument(resultCount <= Integer.MAX_VALUE);
612      this.elemCount += count;
613      this.totalCount += count;
614      return this;
615    }
616
617    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
618      int cmp = comparator.compare(e, elem);
619      if (cmp < 0) {
620        AvlNode<E> initLeft = left;
621        if (initLeft == null) {
622          result[0] = 0;
623          return this;
624        }
625
626        left = initLeft.remove(comparator, e, count, result);
627
628        if (result[0] > 0) {
629          if (count >= result[0]) {
630            this.distinctElements--;
631            this.totalCount -= result[0];
632          } else {
633            this.totalCount -= count;
634          }
635        }
636        return (result[0] == 0) ? this : rebalance();
637      } else if (cmp > 0) {
638        AvlNode<E> initRight = right;
639        if (initRight == null) {
640          result[0] = 0;
641          return this;
642        }
643
644        right = initRight.remove(comparator, e, count, result);
645
646        if (result[0] > 0) {
647          if (count >= result[0]) {
648            this.distinctElements--;
649            this.totalCount -= result[0];
650          } else {
651            this.totalCount -= count;
652          }
653        }
654        return rebalance();
655      }
656
657      // removing count from me!
658      result[0] = elemCount;
659      if (count >= elemCount) {
660        return deleteMe();
661      } else {
662        this.elemCount -= count;
663        this.totalCount -= count;
664        return this;
665      }
666    }
667
668    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
669      int cmp = comparator.compare(e, elem);
670      if (cmp < 0) {
671        AvlNode<E> initLeft = left;
672        if (initLeft == null) {
673          result[0] = 0;
674          return (count > 0) ? addLeftChild(e, count) : this;
675        }
676
677        left = initLeft.setCount(comparator, e, count, result);
678
679        if (count == 0 && result[0] != 0) {
680          this.distinctElements--;
681        } else if (count > 0 && result[0] == 0) {
682          this.distinctElements++;
683        }
684
685        this.totalCount += count - result[0];
686        return rebalance();
687      } else if (cmp > 0) {
688        AvlNode<E> initRight = right;
689        if (initRight == null) {
690          result[0] = 0;
691          return (count > 0) ? addRightChild(e, count) : this;
692        }
693
694        right = initRight.setCount(comparator, e, count, result);
695
696        if (count == 0 && result[0] != 0) {
697          this.distinctElements--;
698        } else if (count > 0 && result[0] == 0) {
699          this.distinctElements++;
700        }
701
702        this.totalCount += count - result[0];
703        return rebalance();
704      }
705
706      // setting my count
707      result[0] = elemCount;
708      if (count == 0) {
709        return deleteMe();
710      }
711      this.totalCount += count - elemCount;
712      this.elemCount = count;
713      return this;
714    }
715
716    AvlNode<E> setCount(
717        Comparator<? super E> comparator,
718        @Nullable E e,
719        int expectedCount,
720        int newCount,
721        int[] result) {
722      int cmp = comparator.compare(e, elem);
723      if (cmp < 0) {
724        AvlNode<E> initLeft = left;
725        if (initLeft == null) {
726          result[0] = 0;
727          if (expectedCount == 0 && newCount > 0) {
728            return addLeftChild(e, newCount);
729          }
730          return this;
731        }
732
733        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
734
735        if (result[0] == expectedCount) {
736          if (newCount == 0 && result[0] != 0) {
737            this.distinctElements--;
738          } else if (newCount > 0 && result[0] == 0) {
739            this.distinctElements++;
740          }
741          this.totalCount += newCount - result[0];
742        }
743        return rebalance();
744      } else if (cmp > 0) {
745        AvlNode<E> initRight = right;
746        if (initRight == null) {
747          result[0] = 0;
748          if (expectedCount == 0 && newCount > 0) {
749            return addRightChild(e, newCount);
750          }
751          return this;
752        }
753
754        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
755
756        if (result[0] == expectedCount) {
757          if (newCount == 0 && result[0] != 0) {
758            this.distinctElements--;
759          } else if (newCount > 0 && result[0] == 0) {
760            this.distinctElements++;
761          }
762          this.totalCount += newCount - result[0];
763        }
764        return rebalance();
765      }
766
767      // setting my count
768      result[0] = elemCount;
769      if (expectedCount == elemCount) {
770        if (newCount == 0) {
771          return deleteMe();
772        }
773        this.totalCount += newCount - elemCount;
774        this.elemCount = newCount;
775      }
776      return this;
777    }
778
779    private AvlNode<E> deleteMe() {
780      int oldElemCount = this.elemCount;
781      this.elemCount = 0;
782      successor(pred, succ);
783      if (left == null) {
784        return right;
785      } else if (right == null) {
786        return left;
787      } else if (left.height >= right.height) {
788        AvlNode<E> newTop = pred;
789        // newTop is the maximum node in my left subtree
790        newTop.left = left.removeMax(newTop);
791        newTop.right = right;
792        newTop.distinctElements = distinctElements - 1;
793        newTop.totalCount = totalCount - oldElemCount;
794        return newTop.rebalance();
795      } else {
796        AvlNode<E> newTop = succ;
797        newTop.right = right.removeMin(newTop);
798        newTop.left = left;
799        newTop.distinctElements = distinctElements - 1;
800        newTop.totalCount = totalCount - oldElemCount;
801        return newTop.rebalance();
802      }
803    }
804
805    // Removes the minimum node from this subtree to be reused elsewhere
806    private AvlNode<E> removeMin(AvlNode<E> node) {
807      if (left == null) {
808        return right;
809      } else {
810        left = left.removeMin(node);
811        distinctElements--;
812        totalCount -= node.elemCount;
813        return rebalance();
814      }
815    }
816
817    // Removes the maximum node from this subtree to be reused elsewhere
818    private AvlNode<E> removeMax(AvlNode<E> node) {
819      if (right == null) {
820        return left;
821      } else {
822        right = right.removeMax(node);
823        distinctElements--;
824        totalCount -= node.elemCount;
825        return rebalance();
826      }
827    }
828
829    private void recomputeMultiset() {
830      this.distinctElements = 1 + TreeMultiset.distinctElements(left)
831          + TreeMultiset.distinctElements(right);
832      this.totalCount = elemCount + totalCount(left) + totalCount(right);
833    }
834
835    private void recomputeHeight() {
836      this.height = 1 + Math.max(height(left), height(right));
837    }
838
839    private void recompute() {
840      recomputeMultiset();
841      recomputeHeight();
842    }
843
844    private AvlNode<E> rebalance() {
845      switch (balanceFactor()) {
846        case -2:
847          if (right.balanceFactor() > 0) {
848            right = right.rotateRight();
849          }
850          return rotateLeft();
851        case 2:
852          if (left.balanceFactor() < 0) {
853            left = left.rotateLeft();
854          }
855          return rotateRight();
856        default:
857          recomputeHeight();
858          return this;
859      }
860    }
861
862    private int balanceFactor() {
863      return height(left) - height(right);
864    }
865
866    private AvlNode<E> rotateLeft() {
867      checkState(right != null);
868      AvlNode<E> newTop = right;
869      this.right = newTop.left;
870      newTop.left = this;
871      newTop.totalCount = this.totalCount;
872      newTop.distinctElements = this.distinctElements;
873      this.recompute();
874      newTop.recomputeHeight();
875      return newTop;
876    }
877
878    private AvlNode<E> rotateRight() {
879      checkState(left != null);
880      AvlNode<E> newTop = left;
881      this.left = newTop.right;
882      newTop.right = this;
883      newTop.totalCount = this.totalCount;
884      newTop.distinctElements = this.distinctElements;
885      this.recompute();
886      newTop.recomputeHeight();
887      return newTop;
888    }
889
890    private static long totalCount(@Nullable AvlNode<?> node) {
891      return (node == null) ? 0 : node.totalCount;
892    }
893
894    private static int height(@Nullable AvlNode<?> node) {
895      return (node == null) ? 0 : node.height;
896    }
897
898    @Nullable private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
899      int cmp = comparator.compare(e, elem);
900      if (cmp < 0) {
901        return (left == null) ? this : Objects.firstNonNull(left.ceiling(comparator, e), this);
902      } else if (cmp == 0) {
903        return this;
904      } else {
905        return (right == null) ? null : right.ceiling(comparator, e);
906      }
907    }
908
909    @Nullable private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
910      int cmp = comparator.compare(e, elem);
911      if (cmp > 0) {
912        return (right == null) ? this : Objects.firstNonNull(right.floor(comparator, e), this);
913      } else if (cmp == 0) {
914        return this;
915      } else {
916        return (left == null) ? null : left.floor(comparator, e);
917      }
918    }
919
920    @Override
921    public E getElement() {
922      return elem;
923    }
924
925    @Override
926    public int getCount() {
927      return elemCount;
928    }
929
930    @Override
931    public String toString() {
932      return Multisets.immutableEntry(getElement(), getCount()).toString();
933    }
934  }
935
936  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
937    a.succ = b;
938    b.pred = a;
939  }
940
941  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
942    successor(a, b);
943    successor(b, c);
944  }
945
946  /*
947   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
948   * calls the comparator to compare the two keys. If that change is made,
949   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
950   */
951
952  /**
953   * @serialData the comparator, the number of distinct elements, the first element, its count, the
954   *             second element, its count, and so on
955   */
956  @GwtIncompatible("java.io.ObjectOutputStream")
957  private void writeObject(ObjectOutputStream stream) throws IOException {
958    stream.defaultWriteObject();
959    stream.writeObject(elementSet().comparator());
960    Serialization.writeMultiset(this, stream);
961  }
962
963  @GwtIncompatible("java.io.ObjectInputStream")
964  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
965    stream.defaultReadObject();
966    @SuppressWarnings("unchecked")
967    // reading data stored by writeObject
968    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
969    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
970    Serialization.getFieldSetter(TreeMultiset.class, "range").set(
971        this,
972        GeneralRange.all(comparator));
973    Serialization.getFieldSetter(TreeMultiset.class, "rootReference").set(
974        this,
975        new Reference<AvlNode<E>>());
976    AvlNode<E> header = new AvlNode<E>(null, 1);
977    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
978    successor(header, header);
979    Serialization.populateMultiset(this, stream);
980  }
981
982  @GwtIncompatible("not needed in emulated source") private static final long serialVersionUID = 1;
983}