001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.CollectPreconditions.checkRemove;
023
024import com.google.common.annotations.GwtCompatible;
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.base.MoreObjects;
027import com.google.common.primitives.Ints;
028import com.google.errorprone.annotations.CanIgnoreReturnValue;
029import java.io.IOException;
030import java.io.ObjectInputStream;
031import java.io.ObjectOutputStream;
032import java.io.Serializable;
033import java.util.Comparator;
034import java.util.ConcurrentModificationException;
035import java.util.Iterator;
036import java.util.NoSuchElementException;
037import javax.annotation.Nullable;
038
039/**
040 * A multiset which maintains the ordering of its elements, according to either their natural order
041 * or an explicit {@link Comparator}. In all cases, this implementation uses
042 * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
043 * determine equivalence of instances.
044 *
045 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
046 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
047 * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
048 *
049 * <p>See the Guava User Guide article on <a href=
050 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">
051 * {@code Multiset}</a>.
052 *
053 * @author Louis Wasserman
054 * @author Jared Levy
055 * @since 2.0
056 */
057@GwtCompatible(emulated = true)
058public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
059
060  /**
061   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
062   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
063   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
064   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
065   * user attempts to add an element to the multiset that violates this constraint (for example,
066   * the user attempts to add a string element to a set whose elements are integers), the
067   * {@code add(Object)} call will throw a {@code ClassCastException}.
068   *
069   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
070   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
071   */
072  public static <E extends Comparable> TreeMultiset<E> create() {
073    return new TreeMultiset<E>(Ordering.natural());
074  }
075
076  /**
077   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
078   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
079   * {@code comparator.compare(e1,
080   * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in
081   * the multiset. If the user attempts to add an element to the multiset that violates this
082   * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}.
083   *
084   * @param comparator
085   *          the comparator that will be used to sort this multiset. A null value indicates that
086   *          the elements' <i>natural ordering</i> should be used.
087   */
088  @SuppressWarnings("unchecked")
089  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
090    return (comparator == null)
091        ? new TreeMultiset<E>((Comparator) Ordering.natural())
092        : new TreeMultiset<E>(comparator);
093  }
094
095  /**
096   * Creates an empty multiset containing the given initial elements, sorted according to the
097   * elements' natural order.
098   *
099   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
100   *
101   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
102   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
103   */
104  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
105    TreeMultiset<E> multiset = create();
106    Iterables.addAll(multiset, elements);
107    return multiset;
108  }
109
110  private final transient Reference<AvlNode<E>> rootReference;
111  private final transient GeneralRange<E> range;
112  private final transient AvlNode<E> header;
113
114  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
115    super(range.comparator());
116    this.rootReference = rootReference;
117    this.range = range;
118    this.header = endLink;
119  }
120
121  TreeMultiset(Comparator<? super E> comparator) {
122    super(comparator);
123    this.range = GeneralRange.all(comparator);
124    this.header = new AvlNode<E>(null, 1);
125    successor(header, header);
126    this.rootReference = new Reference<AvlNode<E>>();
127  }
128
129  /**
130   * A function which can be summed across a subtree.
131   */
132  private enum Aggregate {
133    SIZE {
134      @Override
135      int nodeAggregate(AvlNode<?> node) {
136        return node.elemCount;
137      }
138
139      @Override
140      long treeAggregate(@Nullable AvlNode<?> root) {
141        return (root == null) ? 0 : root.totalCount;
142      }
143    },
144    DISTINCT {
145      @Override
146      int nodeAggregate(AvlNode<?> node) {
147        return 1;
148      }
149
150      @Override
151      long treeAggregate(@Nullable AvlNode<?> root) {
152        return (root == null) ? 0 : root.distinctElements;
153      }
154    };
155
156    abstract int nodeAggregate(AvlNode<?> node);
157
158    abstract long treeAggregate(@Nullable AvlNode<?> root);
159  }
160
161  private long aggregateForEntries(Aggregate aggr) {
162    AvlNode<E> root = rootReference.get();
163    long total = aggr.treeAggregate(root);
164    if (range.hasLowerBound()) {
165      total -= aggregateBelowRange(aggr, root);
166    }
167    if (range.hasUpperBound()) {
168      total -= aggregateAboveRange(aggr, root);
169    }
170    return total;
171  }
172
173  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
174    if (node == null) {
175      return 0;
176    }
177    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
178    if (cmp < 0) {
179      return aggregateBelowRange(aggr, node.left);
180    } else if (cmp == 0) {
181      switch (range.getLowerBoundType()) {
182        case OPEN:
183          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
184        case CLOSED:
185          return aggr.treeAggregate(node.left);
186        default:
187          throw new AssertionError();
188      }
189    } else {
190      return aggr.treeAggregate(node.left)
191          + aggr.nodeAggregate(node)
192          + aggregateBelowRange(aggr, node.right);
193    }
194  }
195
196  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
197    if (node == null) {
198      return 0;
199    }
200    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
201    if (cmp > 0) {
202      return aggregateAboveRange(aggr, node.right);
203    } else if (cmp == 0) {
204      switch (range.getUpperBoundType()) {
205        case OPEN:
206          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
207        case CLOSED:
208          return aggr.treeAggregate(node.right);
209        default:
210          throw new AssertionError();
211      }
212    } else {
213      return aggr.treeAggregate(node.right)
214          + aggr.nodeAggregate(node)
215          + aggregateAboveRange(aggr, node.left);
216    }
217  }
218
219  @Override
220  public int size() {
221    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
222  }
223
224  @Override
225  int distinctElements() {
226    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
227  }
228
229  @Override
230  public int count(@Nullable Object element) {
231    try {
232      @SuppressWarnings("unchecked")
233      E e = (E) element;
234      AvlNode<E> root = rootReference.get();
235      if (!range.contains(e) || root == null) {
236        return 0;
237      }
238      return root.count(comparator(), e);
239    } catch (ClassCastException e) {
240      return 0;
241    } catch (NullPointerException e) {
242      return 0;
243    }
244  }
245
246  @CanIgnoreReturnValue
247  @Override
248  public int add(@Nullable E element, int occurrences) {
249    checkNonnegative(occurrences, "occurrences");
250    if (occurrences == 0) {
251      return count(element);
252    }
253    checkArgument(range.contains(element));
254    AvlNode<E> root = rootReference.get();
255    if (root == null) {
256      comparator().compare(element, element);
257      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
258      successor(header, newRoot, header);
259      rootReference.checkAndSet(root, newRoot);
260      return 0;
261    }
262    int[] result = new int[1]; // used as a mutable int reference to hold result
263    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
264    rootReference.checkAndSet(root, newRoot);
265    return result[0];
266  }
267
268  @CanIgnoreReturnValue
269  @Override
270  public int remove(@Nullable Object element, int occurrences) {
271    checkNonnegative(occurrences, "occurrences");
272    if (occurrences == 0) {
273      return count(element);
274    }
275    AvlNode<E> root = rootReference.get();
276    int[] result = new int[1]; // used as a mutable int reference to hold result
277    AvlNode<E> newRoot;
278    try {
279      @SuppressWarnings("unchecked")
280      E e = (E) element;
281      if (!range.contains(e) || root == null) {
282        return 0;
283      }
284      newRoot = root.remove(comparator(), e, occurrences, result);
285    } catch (ClassCastException e) {
286      return 0;
287    } catch (NullPointerException e) {
288      return 0;
289    }
290    rootReference.checkAndSet(root, newRoot);
291    return result[0];
292  }
293
294  @CanIgnoreReturnValue
295  @Override
296  public int setCount(@Nullable E element, int count) {
297    checkNonnegative(count, "count");
298    if (!range.contains(element)) {
299      checkArgument(count == 0);
300      return 0;
301    }
302
303    AvlNode<E> root = rootReference.get();
304    if (root == null) {
305      if (count > 0) {
306        add(element, count);
307      }
308      return 0;
309    }
310    int[] result = new int[1]; // used as a mutable int reference to hold result
311    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
312    rootReference.checkAndSet(root, newRoot);
313    return result[0];
314  }
315
316  @CanIgnoreReturnValue
317  @Override
318  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
319    checkNonnegative(newCount, "newCount");
320    checkNonnegative(oldCount, "oldCount");
321    checkArgument(range.contains(element));
322
323    AvlNode<E> root = rootReference.get();
324    if (root == null) {
325      if (oldCount == 0) {
326        if (newCount > 0) {
327          add(element, newCount);
328        }
329        return true;
330      } else {
331        return false;
332      }
333    }
334    int[] result = new int[1]; // used as a mutable int reference to hold result
335    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
336    rootReference.checkAndSet(root, newRoot);
337    return result[0] == oldCount;
338  }
339
340  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
341    return new Multisets.AbstractEntry<E>() {
342      @Override
343      public E getElement() {
344        return baseEntry.getElement();
345      }
346
347      @Override
348      public int getCount() {
349        int result = baseEntry.getCount();
350        if (result == 0) {
351          return count(getElement());
352        } else {
353          return result;
354        }
355      }
356    };
357  }
358
359  /**
360   * Returns the first node in the tree that is in range.
361   */
362  @Nullable
363  private AvlNode<E> firstNode() {
364    AvlNode<E> root = rootReference.get();
365    if (root == null) {
366      return null;
367    }
368    AvlNode<E> node;
369    if (range.hasLowerBound()) {
370      E endpoint = range.getLowerEndpoint();
371      node = rootReference.get().ceiling(comparator(), endpoint);
372      if (node == null) {
373        return null;
374      }
375      if (range.getLowerBoundType() == BoundType.OPEN
376          && comparator().compare(endpoint, node.getElement()) == 0) {
377        node = node.succ;
378      }
379    } else {
380      node = header.succ;
381    }
382    return (node == header || !range.contains(node.getElement())) ? null : node;
383  }
384
385  @Nullable
386  private AvlNode<E> lastNode() {
387    AvlNode<E> root = rootReference.get();
388    if (root == null) {
389      return null;
390    }
391    AvlNode<E> node;
392    if (range.hasUpperBound()) {
393      E endpoint = range.getUpperEndpoint();
394      node = rootReference.get().floor(comparator(), endpoint);
395      if (node == null) {
396        return null;
397      }
398      if (range.getUpperBoundType() == BoundType.OPEN
399          && comparator().compare(endpoint, node.getElement()) == 0) {
400        node = node.pred;
401      }
402    } else {
403      node = header.pred;
404    }
405    return (node == header || !range.contains(node.getElement())) ? null : node;
406  }
407
408  @Override
409  Iterator<Entry<E>> entryIterator() {
410    return new Iterator<Entry<E>>() {
411      AvlNode<E> current = firstNode();
412      Entry<E> prevEntry;
413
414      @Override
415      public boolean hasNext() {
416        if (current == null) {
417          return false;
418        } else if (range.tooHigh(current.getElement())) {
419          current = null;
420          return false;
421        } else {
422          return true;
423        }
424      }
425
426      @Override
427      public Entry<E> next() {
428        if (!hasNext()) {
429          throw new NoSuchElementException();
430        }
431        Entry<E> result = wrapEntry(current);
432        prevEntry = result;
433        if (current.succ == header) {
434          current = null;
435        } else {
436          current = current.succ;
437        }
438        return result;
439      }
440
441      @Override
442      public void remove() {
443        checkRemove(prevEntry != null);
444        setCount(prevEntry.getElement(), 0);
445        prevEntry = null;
446      }
447    };
448  }
449
450  @Override
451  Iterator<Entry<E>> descendingEntryIterator() {
452    return new Iterator<Entry<E>>() {
453      AvlNode<E> current = lastNode();
454      Entry<E> prevEntry = null;
455
456      @Override
457      public boolean hasNext() {
458        if (current == null) {
459          return false;
460        } else if (range.tooLow(current.getElement())) {
461          current = null;
462          return false;
463        } else {
464          return true;
465        }
466      }
467
468      @Override
469      public Entry<E> next() {
470        if (!hasNext()) {
471          throw new NoSuchElementException();
472        }
473        Entry<E> result = wrapEntry(current);
474        prevEntry = result;
475        if (current.pred == header) {
476          current = null;
477        } else {
478          current = current.pred;
479        }
480        return result;
481      }
482
483      @Override
484      public void remove() {
485        checkRemove(prevEntry != null);
486        setCount(prevEntry.getElement(), 0);
487        prevEntry = null;
488      }
489    };
490  }
491
492  @Override
493  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
494    return new TreeMultiset<E>(
495        rootReference,
496        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
497        header);
498  }
499
500  @Override
501  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
502    return new TreeMultiset<E>(
503        rootReference,
504        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
505        header);
506  }
507
508  static int distinctElements(@Nullable AvlNode<?> node) {
509    return (node == null) ? 0 : node.distinctElements;
510  }
511
512  private static final class Reference<T> {
513    @Nullable private T value;
514
515    @Nullable
516    public T get() {
517      return value;
518    }
519
520    public void checkAndSet(@Nullable T expected, T newValue) {
521      if (value != expected) {
522        throw new ConcurrentModificationException();
523      }
524      value = newValue;
525    }
526  }
527
528  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
529    @Nullable private final E elem;
530
531    // elemCount is 0 iff this node has been deleted.
532    private int elemCount;
533
534    private int distinctElements;
535    private long totalCount;
536    private int height;
537    private AvlNode<E> left;
538    private AvlNode<E> right;
539    private AvlNode<E> pred;
540    private AvlNode<E> succ;
541
542    AvlNode(@Nullable E elem, int elemCount) {
543      checkArgument(elemCount > 0);
544      this.elem = elem;
545      this.elemCount = elemCount;
546      this.totalCount = elemCount;
547      this.distinctElements = 1;
548      this.height = 1;
549      this.left = null;
550      this.right = null;
551    }
552
553    public int count(Comparator<? super E> comparator, E e) {
554      int cmp = comparator.compare(e, elem);
555      if (cmp < 0) {
556        return (left == null) ? 0 : left.count(comparator, e);
557      } else if (cmp > 0) {
558        return (right == null) ? 0 : right.count(comparator, e);
559      } else {
560        return elemCount;
561      }
562    }
563
564    private AvlNode<E> addRightChild(E e, int count) {
565      right = new AvlNode<E>(e, count);
566      successor(this, right, succ);
567      height = Math.max(2, height);
568      distinctElements++;
569      totalCount += count;
570      return this;
571    }
572
573    private AvlNode<E> addLeftChild(E e, int count) {
574      left = new AvlNode<E>(e, count);
575      successor(pred, left, this);
576      height = Math.max(2, height);
577      distinctElements++;
578      totalCount += count;
579      return this;
580    }
581
582    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
583      /*
584       * It speeds things up considerably to unconditionally add count to totalCount here,
585       * but that destroys failure atomicity in the case of count overflow. =(
586       */
587      int cmp = comparator.compare(e, elem);
588      if (cmp < 0) {
589        AvlNode<E> initLeft = left;
590        if (initLeft == null) {
591          result[0] = 0;
592          return addLeftChild(e, count);
593        }
594        int initHeight = initLeft.height;
595
596        left = initLeft.add(comparator, e, count, result);
597        if (result[0] == 0) {
598          distinctElements++;
599        }
600        this.totalCount += count;
601        return (left.height == initHeight) ? this : rebalance();
602      } else if (cmp > 0) {
603        AvlNode<E> initRight = right;
604        if (initRight == null) {
605          result[0] = 0;
606          return addRightChild(e, count);
607        }
608        int initHeight = initRight.height;
609
610        right = initRight.add(comparator, e, count, result);
611        if (result[0] == 0) {
612          distinctElements++;
613        }
614        this.totalCount += count;
615        return (right.height == initHeight) ? this : rebalance();
616      }
617
618      // adding count to me!  No rebalance possible.
619      result[0] = elemCount;
620      long resultCount = (long) elemCount + count;
621      checkArgument(resultCount <= Integer.MAX_VALUE);
622      this.elemCount += count;
623      this.totalCount += count;
624      return this;
625    }
626
627    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
628      int cmp = comparator.compare(e, elem);
629      if (cmp < 0) {
630        AvlNode<E> initLeft = left;
631        if (initLeft == null) {
632          result[0] = 0;
633          return this;
634        }
635
636        left = initLeft.remove(comparator, e, count, result);
637
638        if (result[0] > 0) {
639          if (count >= result[0]) {
640            this.distinctElements--;
641            this.totalCount -= result[0];
642          } else {
643            this.totalCount -= count;
644          }
645        }
646        return (result[0] == 0) ? this : rebalance();
647      } else if (cmp > 0) {
648        AvlNode<E> initRight = right;
649        if (initRight == null) {
650          result[0] = 0;
651          return this;
652        }
653
654        right = initRight.remove(comparator, e, count, result);
655
656        if (result[0] > 0) {
657          if (count >= result[0]) {
658            this.distinctElements--;
659            this.totalCount -= result[0];
660          } else {
661            this.totalCount -= count;
662          }
663        }
664        return rebalance();
665      }
666
667      // removing count from me!
668      result[0] = elemCount;
669      if (count >= elemCount) {
670        return deleteMe();
671      } else {
672        this.elemCount -= count;
673        this.totalCount -= count;
674        return this;
675      }
676    }
677
678    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
679      int cmp = comparator.compare(e, elem);
680      if (cmp < 0) {
681        AvlNode<E> initLeft = left;
682        if (initLeft == null) {
683          result[0] = 0;
684          return (count > 0) ? addLeftChild(e, count) : this;
685        }
686
687        left = initLeft.setCount(comparator, e, count, result);
688
689        if (count == 0 && result[0] != 0) {
690          this.distinctElements--;
691        } else if (count > 0 && result[0] == 0) {
692          this.distinctElements++;
693        }
694
695        this.totalCount += count - result[0];
696        return rebalance();
697      } else if (cmp > 0) {
698        AvlNode<E> initRight = right;
699        if (initRight == null) {
700          result[0] = 0;
701          return (count > 0) ? addRightChild(e, count) : this;
702        }
703
704        right = initRight.setCount(comparator, e, count, result);
705
706        if (count == 0 && result[0] != 0) {
707          this.distinctElements--;
708        } else if (count > 0 && result[0] == 0) {
709          this.distinctElements++;
710        }
711
712        this.totalCount += count - result[0];
713        return rebalance();
714      }
715
716      // setting my count
717      result[0] = elemCount;
718      if (count == 0) {
719        return deleteMe();
720      }
721      this.totalCount += count - elemCount;
722      this.elemCount = count;
723      return this;
724    }
725
726    AvlNode<E> setCount(
727        Comparator<? super E> comparator,
728        @Nullable E e,
729        int expectedCount,
730        int newCount,
731        int[] result) {
732      int cmp = comparator.compare(e, elem);
733      if (cmp < 0) {
734        AvlNode<E> initLeft = left;
735        if (initLeft == null) {
736          result[0] = 0;
737          if (expectedCount == 0 && newCount > 0) {
738            return addLeftChild(e, newCount);
739          }
740          return this;
741        }
742
743        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
744
745        if (result[0] == expectedCount) {
746          if (newCount == 0 && result[0] != 0) {
747            this.distinctElements--;
748          } else if (newCount > 0 && result[0] == 0) {
749            this.distinctElements++;
750          }
751          this.totalCount += newCount - result[0];
752        }
753        return rebalance();
754      } else if (cmp > 0) {
755        AvlNode<E> initRight = right;
756        if (initRight == null) {
757          result[0] = 0;
758          if (expectedCount == 0 && newCount > 0) {
759            return addRightChild(e, newCount);
760          }
761          return this;
762        }
763
764        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
765
766        if (result[0] == expectedCount) {
767          if (newCount == 0 && result[0] != 0) {
768            this.distinctElements--;
769          } else if (newCount > 0 && result[0] == 0) {
770            this.distinctElements++;
771          }
772          this.totalCount += newCount - result[0];
773        }
774        return rebalance();
775      }
776
777      // setting my count
778      result[0] = elemCount;
779      if (expectedCount == elemCount) {
780        if (newCount == 0) {
781          return deleteMe();
782        }
783        this.totalCount += newCount - elemCount;
784        this.elemCount = newCount;
785      }
786      return this;
787    }
788
789    private AvlNode<E> deleteMe() {
790      int oldElemCount = this.elemCount;
791      this.elemCount = 0;
792      successor(pred, succ);
793      if (left == null) {
794        return right;
795      } else if (right == null) {
796        return left;
797      } else if (left.height >= right.height) {
798        AvlNode<E> newTop = pred;
799        // newTop is the maximum node in my left subtree
800        newTop.left = left.removeMax(newTop);
801        newTop.right = right;
802        newTop.distinctElements = distinctElements - 1;
803        newTop.totalCount = totalCount - oldElemCount;
804        return newTop.rebalance();
805      } else {
806        AvlNode<E> newTop = succ;
807        newTop.right = right.removeMin(newTop);
808        newTop.left = left;
809        newTop.distinctElements = distinctElements - 1;
810        newTop.totalCount = totalCount - oldElemCount;
811        return newTop.rebalance();
812      }
813    }
814
815    // Removes the minimum node from this subtree to be reused elsewhere
816    private AvlNode<E> removeMin(AvlNode<E> node) {
817      if (left == null) {
818        return right;
819      } else {
820        left = left.removeMin(node);
821        distinctElements--;
822        totalCount -= node.elemCount;
823        return rebalance();
824      }
825    }
826
827    // Removes the maximum node from this subtree to be reused elsewhere
828    private AvlNode<E> removeMax(AvlNode<E> node) {
829      if (right == null) {
830        return left;
831      } else {
832        right = right.removeMax(node);
833        distinctElements--;
834        totalCount -= node.elemCount;
835        return rebalance();
836      }
837    }
838
839    private void recomputeMultiset() {
840      this.distinctElements =
841          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
842      this.totalCount = elemCount + totalCount(left) + totalCount(right);
843    }
844
845    private void recomputeHeight() {
846      this.height = 1 + Math.max(height(left), height(right));
847    }
848
849    private void recompute() {
850      recomputeMultiset();
851      recomputeHeight();
852    }
853
854    private AvlNode<E> rebalance() {
855      switch (balanceFactor()) {
856        case -2:
857          if (right.balanceFactor() > 0) {
858            right = right.rotateRight();
859          }
860          return rotateLeft();
861        case 2:
862          if (left.balanceFactor() < 0) {
863            left = left.rotateLeft();
864          }
865          return rotateRight();
866        default:
867          recomputeHeight();
868          return this;
869      }
870    }
871
872    private int balanceFactor() {
873      return height(left) - height(right);
874    }
875
876    private AvlNode<E> rotateLeft() {
877      checkState(right != null);
878      AvlNode<E> newTop = right;
879      this.right = newTop.left;
880      newTop.left = this;
881      newTop.totalCount = this.totalCount;
882      newTop.distinctElements = this.distinctElements;
883      this.recompute();
884      newTop.recomputeHeight();
885      return newTop;
886    }
887
888    private AvlNode<E> rotateRight() {
889      checkState(left != null);
890      AvlNode<E> newTop = left;
891      this.left = newTop.right;
892      newTop.right = this;
893      newTop.totalCount = this.totalCount;
894      newTop.distinctElements = this.distinctElements;
895      this.recompute();
896      newTop.recomputeHeight();
897      return newTop;
898    }
899
900    private static long totalCount(@Nullable AvlNode<?> node) {
901      return (node == null) ? 0 : node.totalCount;
902    }
903
904    private static int height(@Nullable AvlNode<?> node) {
905      return (node == null) ? 0 : node.height;
906    }
907
908    @Nullable
909    private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
910      int cmp = comparator.compare(e, elem);
911      if (cmp < 0) {
912        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
913      } else if (cmp == 0) {
914        return this;
915      } else {
916        return (right == null) ? null : right.ceiling(comparator, e);
917      }
918    }
919
920    @Nullable
921    private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
922      int cmp = comparator.compare(e, elem);
923      if (cmp > 0) {
924        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
925      } else if (cmp == 0) {
926        return this;
927      } else {
928        return (left == null) ? null : left.floor(comparator, e);
929      }
930    }
931
932    @Override
933    public E getElement() {
934      return elem;
935    }
936
937    @Override
938    public int getCount() {
939      return elemCount;
940    }
941
942    @Override
943    public String toString() {
944      return Multisets.immutableEntry(getElement(), getCount()).toString();
945    }
946  }
947
948  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
949    a.succ = b;
950    b.pred = a;
951  }
952
953  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
954    successor(a, b);
955    successor(b, c);
956  }
957
958  /*
959   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
960   * calls the comparator to compare the two keys. If that change is made,
961   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
962   */
963
964  /**
965   * @serialData the comparator, the number of distinct elements, the first element, its count, the
966   *             second element, its count, and so on
967   */
968  @GwtIncompatible // java.io.ObjectOutputStream
969  private void writeObject(ObjectOutputStream stream) throws IOException {
970    stream.defaultWriteObject();
971    stream.writeObject(elementSet().comparator());
972    Serialization.writeMultiset(this, stream);
973  }
974
975  @GwtIncompatible // java.io.ObjectInputStream
976  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
977    stream.defaultReadObject();
978    @SuppressWarnings("unchecked")
979    // reading data stored by writeObject
980    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
981    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
982    Serialization.getFieldSetter(TreeMultiset.class, "range")
983        .set(this, GeneralRange.all(comparator));
984    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
985        .set(this, new Reference<AvlNode<E>>());
986    AvlNode<E> header = new AvlNode<E>(null, 1);
987    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
988    successor(header, header);
989    Serialization.populateMultiset(this, stream);
990  }
991
992  @GwtIncompatible // not needed in emulated source
993  private static final long serialVersionUID = 1;
994}