/*
 * Copyright (C) 2007 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.CollectPreconditions.checkRemove;
023
024import com.google.common.annotations.GwtCompatible;
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.base.MoreObjects;
027import com.google.common.primitives.Ints;
028import com.google.errorprone.annotations.CanIgnoreReturnValue;
029import java.io.IOException;
030import java.io.ObjectInputStream;
031import java.io.ObjectOutputStream;
032import java.io.Serializable;
033import java.util.Comparator;
034import java.util.ConcurrentModificationException;
035import java.util.Iterator;
036import java.util.NoSuchElementException;
037import org.checkerframework.checker.nullness.compatqual.NullableDecl;
038
039/**
040 * A multiset which maintains the ordering of its elements, according to either their natural order
041 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
042 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
043 * equivalence of instances.
044 *
045 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
046 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
047 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
048 *
049 * <p>See the Guava User Guide article on <a href=
050 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset"> {@code
051 * Multiset}</a>.
052 *
053 * @author Louis Wasserman
054 * @author Jared Levy
055 * @since 2.0
056 */
057@GwtCompatible(emulated = true)
058public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
059
060  /**
061   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
062   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
063   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
064   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
065   * user attempts to add an element to the multiset that violates this constraint (for example, the
066   * user attempts to add a string element to a set whose elements are integers), the {@code
067   * add(Object)} call will throw a {@code ClassCastException}.
068   *
069   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
070   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
071   */
072  public static <E extends Comparable> TreeMultiset<E> create() {
073    return new TreeMultiset<E>(Ordering.natural());
074  }
075
076  /**
077   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
078   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
079   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
080   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
081   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
082   * ClassCastException}.
083   *
084   * @param comparator the comparator that will be used to sort this multiset. A null value
085   *     indicates that the elements' <i>natural ordering</i> should be used.
086   */
087  @SuppressWarnings("unchecked")
088  public static <E> TreeMultiset<E> create(@NullableDecl Comparator<? super E> comparator) {
089    return (comparator == null)
090        ? new TreeMultiset<E>((Comparator) Ordering.natural())
091        : new TreeMultiset<E>(comparator);
092  }
093
094  /**
095   * Creates an empty multiset containing the given initial elements, sorted according to the
096   * elements' natural order.
097   *
098   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
099   *
100   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
101   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
102   */
103  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
104    TreeMultiset<E> multiset = create();
105    Iterables.addAll(multiset, elements);
106    return multiset;
107  }
108
109  private final transient Reference<AvlNode<E>> rootReference;
110  private final transient GeneralRange<E> range;
111  private final transient AvlNode<E> header;
112
113  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
114    super(range.comparator());
115    this.rootReference = rootReference;
116    this.range = range;
117    this.header = endLink;
118  }
119
120  TreeMultiset(Comparator<? super E> comparator) {
121    super(comparator);
122    this.range = GeneralRange.all(comparator);
123    this.header = new AvlNode<E>(null, 1);
124    successor(header, header);
125    this.rootReference = new Reference<>();
126  }
127
128  /** A function which can be summed across a subtree. */
129  private enum Aggregate {
130    SIZE {
131      @Override
132      int nodeAggregate(AvlNode<?> node) {
133        return node.elemCount;
134      }
135
136      @Override
137      long treeAggregate(@NullableDecl AvlNode<?> root) {
138        return (root == null) ? 0 : root.totalCount;
139      }
140    },
141    DISTINCT {
142      @Override
143      int nodeAggregate(AvlNode<?> node) {
144        return 1;
145      }
146
147      @Override
148      long treeAggregate(@NullableDecl AvlNode<?> root) {
149        return (root == null) ? 0 : root.distinctElements;
150      }
151    };
152
153    abstract int nodeAggregate(AvlNode<?> node);
154
155    abstract long treeAggregate(@NullableDecl AvlNode<?> root);
156  }
157
158  private long aggregateForEntries(Aggregate aggr) {
159    AvlNode<E> root = rootReference.get();
160    long total = aggr.treeAggregate(root);
161    if (range.hasLowerBound()) {
162      total -= aggregateBelowRange(aggr, root);
163    }
164    if (range.hasUpperBound()) {
165      total -= aggregateAboveRange(aggr, root);
166    }
167    return total;
168  }
169
170  private long aggregateBelowRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
171    if (node == null) {
172      return 0;
173    }
174    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
175    if (cmp < 0) {
176      return aggregateBelowRange(aggr, node.left);
177    } else if (cmp == 0) {
178      switch (range.getLowerBoundType()) {
179        case OPEN:
180          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
181        case CLOSED:
182          return aggr.treeAggregate(node.left);
183        default:
184          throw new AssertionError();
185      }
186    } else {
187      return aggr.treeAggregate(node.left)
188          + aggr.nodeAggregate(node)
189          + aggregateBelowRange(aggr, node.right);
190    }
191  }
192
193  private long aggregateAboveRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
194    if (node == null) {
195      return 0;
196    }
197    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
198    if (cmp > 0) {
199      return aggregateAboveRange(aggr, node.right);
200    } else if (cmp == 0) {
201      switch (range.getUpperBoundType()) {
202        case OPEN:
203          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
204        case CLOSED:
205          return aggr.treeAggregate(node.right);
206        default:
207          throw new AssertionError();
208      }
209    } else {
210      return aggr.treeAggregate(node.right)
211          + aggr.nodeAggregate(node)
212          + aggregateAboveRange(aggr, node.left);
213    }
214  }
215
216  @Override
217  public int size() {
218    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
219  }
220
221  @Override
222  int distinctElements() {
223    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
224  }
225
226  static int distinctElements(@NullableDecl AvlNode<?> node) {
227    return (node == null) ? 0 : node.distinctElements;
228  }
229
230  @Override
231  public int count(@NullableDecl Object element) {
232    try {
233      @SuppressWarnings("unchecked")
234      E e = (E) element;
235      AvlNode<E> root = rootReference.get();
236      if (!range.contains(e) || root == null) {
237        return 0;
238      }
239      return root.count(comparator(), e);
240    } catch (ClassCastException | NullPointerException e) {
241      return 0;
242    }
243  }
244
245  @CanIgnoreReturnValue
246  @Override
247  public int add(@NullableDecl E element, int occurrences) {
248    checkNonnegative(occurrences, "occurrences");
249    if (occurrences == 0) {
250      return count(element);
251    }
252    checkArgument(range.contains(element));
253    AvlNode<E> root = rootReference.get();
254    if (root == null) {
255      comparator().compare(element, element);
256      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
257      successor(header, newRoot, header);
258      rootReference.checkAndSet(root, newRoot);
259      return 0;
260    }
261    int[] result = new int[1]; // used as a mutable int reference to hold result
262    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
263    rootReference.checkAndSet(root, newRoot);
264    return result[0];
265  }
266
267  @CanIgnoreReturnValue
268  @Override
269  public int remove(@NullableDecl Object element, int occurrences) {
270    checkNonnegative(occurrences, "occurrences");
271    if (occurrences == 0) {
272      return count(element);
273    }
274    AvlNode<E> root = rootReference.get();
275    int[] result = new int[1]; // used as a mutable int reference to hold result
276    AvlNode<E> newRoot;
277    try {
278      @SuppressWarnings("unchecked")
279      E e = (E) element;
280      if (!range.contains(e) || root == null) {
281        return 0;
282      }
283      newRoot = root.remove(comparator(), e, occurrences, result);
284    } catch (ClassCastException | NullPointerException e) {
285      return 0;
286    }
287    rootReference.checkAndSet(root, newRoot);
288    return result[0];
289  }
290
291  @CanIgnoreReturnValue
292  @Override
293  public int setCount(@NullableDecl E element, int count) {
294    checkNonnegative(count, "count");
295    if (!range.contains(element)) {
296      checkArgument(count == 0);
297      return 0;
298    }
299
300    AvlNode<E> root = rootReference.get();
301    if (root == null) {
302      if (count > 0) {
303        add(element, count);
304      }
305      return 0;
306    }
307    int[] result = new int[1]; // used as a mutable int reference to hold result
308    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
309    rootReference.checkAndSet(root, newRoot);
310    return result[0];
311  }
312
313  @CanIgnoreReturnValue
314  @Override
315  public boolean setCount(@NullableDecl E element, int oldCount, int newCount) {
316    checkNonnegative(newCount, "newCount");
317    checkNonnegative(oldCount, "oldCount");
318    checkArgument(range.contains(element));
319
320    AvlNode<E> root = rootReference.get();
321    if (root == null) {
322      if (oldCount == 0) {
323        if (newCount > 0) {
324          add(element, newCount);
325        }
326        return true;
327      } else {
328        return false;
329      }
330    }
331    int[] result = new int[1]; // used as a mutable int reference to hold result
332    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
333    rootReference.checkAndSet(root, newRoot);
334    return result[0] == oldCount;
335  }
336
337  @Override
338  public void clear() {
339    if (!range.hasLowerBound() && !range.hasUpperBound()) {
340      // We can do this in O(n) rather than removing one by one, which could force rebalancing.
341      for (AvlNode<E> current = header.succ; current != header; ) {
342        AvlNode<E> next = current.succ;
343
344        current.elemCount = 0;
345        // Also clear these fields so that one deleted Entry doesn't retain all elements.
346        current.left = null;
347        current.right = null;
348        current.pred = null;
349        current.succ = null;
350
351        current = next;
352      }
353      successor(header, header);
354      rootReference.clear();
355    } else {
356      // TODO(cpovirk): Perhaps we can optimize in this case, too?
357      Iterators.clear(entryIterator());
358    }
359  }
360
361  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
362    return new Multisets.AbstractEntry<E>() {
363      @Override
364      public E getElement() {
365        return baseEntry.getElement();
366      }
367
368      @Override
369      public int getCount() {
370        int result = baseEntry.getCount();
371        if (result == 0) {
372          return count(getElement());
373        } else {
374          return result;
375        }
376      }
377    };
378  }
379
380  /** Returns the first node in the tree that is in range. */
381  @NullableDecl
382  private AvlNode<E> firstNode() {
383    AvlNode<E> root = rootReference.get();
384    if (root == null) {
385      return null;
386    }
387    AvlNode<E> node;
388    if (range.hasLowerBound()) {
389      E endpoint = range.getLowerEndpoint();
390      node = rootReference.get().ceiling(comparator(), endpoint);
391      if (node == null) {
392        return null;
393      }
394      if (range.getLowerBoundType() == BoundType.OPEN
395          && comparator().compare(endpoint, node.getElement()) == 0) {
396        node = node.succ;
397      }
398    } else {
399      node = header.succ;
400    }
401    return (node == header || !range.contains(node.getElement())) ? null : node;
402  }
403
404  @NullableDecl
405  private AvlNode<E> lastNode() {
406    AvlNode<E> root = rootReference.get();
407    if (root == null) {
408      return null;
409    }
410    AvlNode<E> node;
411    if (range.hasUpperBound()) {
412      E endpoint = range.getUpperEndpoint();
413      node = rootReference.get().floor(comparator(), endpoint);
414      if (node == null) {
415        return null;
416      }
417      if (range.getUpperBoundType() == BoundType.OPEN
418          && comparator().compare(endpoint, node.getElement()) == 0) {
419        node = node.pred;
420      }
421    } else {
422      node = header.pred;
423    }
424    return (node == header || !range.contains(node.getElement())) ? null : node;
425  }
426
427  @Override
428  Iterator<E> elementIterator() {
429    return Multisets.elementIterator(entryIterator());
430  }
431
432  @Override
433  Iterator<Entry<E>> entryIterator() {
434    return new Iterator<Entry<E>>() {
435      AvlNode<E> current = firstNode();
436      @NullableDecl Entry<E> prevEntry;
437
438      @Override
439      public boolean hasNext() {
440        if (current == null) {
441          return false;
442        } else if (range.tooHigh(current.getElement())) {
443          current = null;
444          return false;
445        } else {
446          return true;
447        }
448      }
449
450      @Override
451      public Entry<E> next() {
452        if (!hasNext()) {
453          throw new NoSuchElementException();
454        }
455        Entry<E> result = wrapEntry(current);
456        prevEntry = result;
457        if (current.succ == header) {
458          current = null;
459        } else {
460          current = current.succ;
461        }
462        return result;
463      }
464
465      @Override
466      public void remove() {
467        checkRemove(prevEntry != null);
468        setCount(prevEntry.getElement(), 0);
469        prevEntry = null;
470      }
471    };
472  }
473
474  @Override
475  Iterator<Entry<E>> descendingEntryIterator() {
476    return new Iterator<Entry<E>>() {
477      AvlNode<E> current = lastNode();
478      Entry<E> prevEntry = null;
479
480      @Override
481      public boolean hasNext() {
482        if (current == null) {
483          return false;
484        } else if (range.tooLow(current.getElement())) {
485          current = null;
486          return false;
487        } else {
488          return true;
489        }
490      }
491
492      @Override
493      public Entry<E> next() {
494        if (!hasNext()) {
495          throw new NoSuchElementException();
496        }
497        Entry<E> result = wrapEntry(current);
498        prevEntry = result;
499        if (current.pred == header) {
500          current = null;
501        } else {
502          current = current.pred;
503        }
504        return result;
505      }
506
507      @Override
508      public void remove() {
509        checkRemove(prevEntry != null);
510        setCount(prevEntry.getElement(), 0);
511        prevEntry = null;
512      }
513    };
514  }
515
516  @Override
517  public Iterator<E> iterator() {
518    return Multisets.iteratorImpl(this);
519  }
520
521  @Override
522  public SortedMultiset<E> headMultiset(@NullableDecl E upperBound, BoundType boundType) {
523    return new TreeMultiset<E>(
524        rootReference,
525        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
526        header);
527  }
528
529  @Override
530  public SortedMultiset<E> tailMultiset(@NullableDecl E lowerBound, BoundType boundType) {
531    return new TreeMultiset<E>(
532        rootReference,
533        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
534        header);
535  }
536
537  private static final class Reference<T> {
538    @NullableDecl private T value;
539
540    @NullableDecl
541    public T get() {
542      return value;
543    }
544
545    public void checkAndSet(@NullableDecl T expected, T newValue) {
546      if (value != expected) {
547        throw new ConcurrentModificationException();
548      }
549      value = newValue;
550    }
551
552    void clear() {
553      value = null;
554    }
555  }
556
557  private static final class AvlNode<E> {
558    @NullableDecl private final E elem;
559
560    // elemCount is 0 iff this node has been deleted.
561    private int elemCount;
562
563    private int distinctElements;
564    private long totalCount;
565    private int height;
566    @NullableDecl private AvlNode<E> left;
567    @NullableDecl private AvlNode<E> right;
568    @NullableDecl private AvlNode<E> pred;
569    @NullableDecl private AvlNode<E> succ;
570
571    AvlNode(@NullableDecl E elem, int elemCount) {
572      checkArgument(elemCount > 0);
573      this.elem = elem;
574      this.elemCount = elemCount;
575      this.totalCount = elemCount;
576      this.distinctElements = 1;
577      this.height = 1;
578      this.left = null;
579      this.right = null;
580    }
581
582    public int count(Comparator<? super E> comparator, E e) {
583      int cmp = comparator.compare(e, elem);
584      if (cmp < 0) {
585        return (left == null) ? 0 : left.count(comparator, e);
586      } else if (cmp > 0) {
587        return (right == null) ? 0 : right.count(comparator, e);
588      } else {
589        return elemCount;
590      }
591    }
592
593    private AvlNode<E> addRightChild(E e, int count) {
594      right = new AvlNode<E>(e, count);
595      successor(this, right, succ);
596      height = Math.max(2, height);
597      distinctElements++;
598      totalCount += count;
599      return this;
600    }
601
602    private AvlNode<E> addLeftChild(E e, int count) {
603      left = new AvlNode<E>(e, count);
604      successor(pred, left, this);
605      height = Math.max(2, height);
606      distinctElements++;
607      totalCount += count;
608      return this;
609    }
610
611    AvlNode<E> add(Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
612      /*
613       * It speeds things up considerably to unconditionally add count to totalCount here,
614       * but that destroys failure atomicity in the case of count overflow. =(
615       */
616      int cmp = comparator.compare(e, elem);
617      if (cmp < 0) {
618        AvlNode<E> initLeft = left;
619        if (initLeft == null) {
620          result[0] = 0;
621          return addLeftChild(e, count);
622        }
623        int initHeight = initLeft.height;
624
625        left = initLeft.add(comparator, e, count, result);
626        if (result[0] == 0) {
627          distinctElements++;
628        }
629        this.totalCount += count;
630        return (left.height == initHeight) ? this : rebalance();
631      } else if (cmp > 0) {
632        AvlNode<E> initRight = right;
633        if (initRight == null) {
634          result[0] = 0;
635          return addRightChild(e, count);
636        }
637        int initHeight = initRight.height;
638
639        right = initRight.add(comparator, e, count, result);
640        if (result[0] == 0) {
641          distinctElements++;
642        }
643        this.totalCount += count;
644        return (right.height == initHeight) ? this : rebalance();
645      }
646
647      // adding count to me!  No rebalance possible.
648      result[0] = elemCount;
649      long resultCount = (long) elemCount + count;
650      checkArgument(resultCount <= Integer.MAX_VALUE);
651      this.elemCount += count;
652      this.totalCount += count;
653      return this;
654    }
655
656    AvlNode<E> remove(
657        Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
658      int cmp = comparator.compare(e, elem);
659      if (cmp < 0) {
660        AvlNode<E> initLeft = left;
661        if (initLeft == null) {
662          result[0] = 0;
663          return this;
664        }
665
666        left = initLeft.remove(comparator, e, count, result);
667
668        if (result[0] > 0) {
669          if (count >= result[0]) {
670            this.distinctElements--;
671            this.totalCount -= result[0];
672          } else {
673            this.totalCount -= count;
674          }
675        }
676        return (result[0] == 0) ? this : rebalance();
677      } else if (cmp > 0) {
678        AvlNode<E> initRight = right;
679        if (initRight == null) {
680          result[0] = 0;
681          return this;
682        }
683
684        right = initRight.remove(comparator, e, count, result);
685
686        if (result[0] > 0) {
687          if (count >= result[0]) {
688            this.distinctElements--;
689            this.totalCount -= result[0];
690          } else {
691            this.totalCount -= count;
692          }
693        }
694        return rebalance();
695      }
696
697      // removing count from me!
698      result[0] = elemCount;
699      if (count >= elemCount) {
700        return deleteMe();
701      } else {
702        this.elemCount -= count;
703        this.totalCount -= count;
704        return this;
705      }
706    }
707
708    AvlNode<E> setCount(
709        Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
710      int cmp = comparator.compare(e, elem);
711      if (cmp < 0) {
712        AvlNode<E> initLeft = left;
713        if (initLeft == null) {
714          result[0] = 0;
715          return (count > 0) ? addLeftChild(e, count) : this;
716        }
717
718        left = initLeft.setCount(comparator, e, count, result);
719
720        if (count == 0 && result[0] != 0) {
721          this.distinctElements--;
722        } else if (count > 0 && result[0] == 0) {
723          this.distinctElements++;
724        }
725
726        this.totalCount += count - result[0];
727        return rebalance();
728      } else if (cmp > 0) {
729        AvlNode<E> initRight = right;
730        if (initRight == null) {
731          result[0] = 0;
732          return (count > 0) ? addRightChild(e, count) : this;
733        }
734
735        right = initRight.setCount(comparator, e, count, result);
736
737        if (count == 0 && result[0] != 0) {
738          this.distinctElements--;
739        } else if (count > 0 && result[0] == 0) {
740          this.distinctElements++;
741        }
742
743        this.totalCount += count - result[0];
744        return rebalance();
745      }
746
747      // setting my count
748      result[0] = elemCount;
749      if (count == 0) {
750        return deleteMe();
751      }
752      this.totalCount += count - elemCount;
753      this.elemCount = count;
754      return this;
755    }
756
757    AvlNode<E> setCount(
758        Comparator<? super E> comparator,
759        @NullableDecl E e,
760        int expectedCount,
761        int newCount,
762        int[] result) {
763      int cmp = comparator.compare(e, elem);
764      if (cmp < 0) {
765        AvlNode<E> initLeft = left;
766        if (initLeft == null) {
767          result[0] = 0;
768          if (expectedCount == 0 && newCount > 0) {
769            return addLeftChild(e, newCount);
770          }
771          return this;
772        }
773
774        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
775
776        if (result[0] == expectedCount) {
777          if (newCount == 0 && result[0] != 0) {
778            this.distinctElements--;
779          } else if (newCount > 0 && result[0] == 0) {
780            this.distinctElements++;
781          }
782          this.totalCount += newCount - result[0];
783        }
784        return rebalance();
785      } else if (cmp > 0) {
786        AvlNode<E> initRight = right;
787        if (initRight == null) {
788          result[0] = 0;
789          if (expectedCount == 0 && newCount > 0) {
790            return addRightChild(e, newCount);
791          }
792          return this;
793        }
794
795        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
796
797        if (result[0] == expectedCount) {
798          if (newCount == 0 && result[0] != 0) {
799            this.distinctElements--;
800          } else if (newCount > 0 && result[0] == 0) {
801            this.distinctElements++;
802          }
803          this.totalCount += newCount - result[0];
804        }
805        return rebalance();
806      }
807
808      // setting my count
809      result[0] = elemCount;
810      if (expectedCount == elemCount) {
811        if (newCount == 0) {
812          return deleteMe();
813        }
814        this.totalCount += newCount - elemCount;
815        this.elemCount = newCount;
816      }
817      return this;
818    }
819
820    private AvlNode<E> deleteMe() {
821      int oldElemCount = this.elemCount;
822      this.elemCount = 0;
823      successor(pred, succ);
824      if (left == null) {
825        return right;
826      } else if (right == null) {
827        return left;
828      } else if (left.height >= right.height) {
829        AvlNode<E> newTop = pred;
830        // newTop is the maximum node in my left subtree
831        newTop.left = left.removeMax(newTop);
832        newTop.right = right;
833        newTop.distinctElements = distinctElements - 1;
834        newTop.totalCount = totalCount - oldElemCount;
835        return newTop.rebalance();
836      } else {
837        AvlNode<E> newTop = succ;
838        newTop.right = right.removeMin(newTop);
839        newTop.left = left;
840        newTop.distinctElements = distinctElements - 1;
841        newTop.totalCount = totalCount - oldElemCount;
842        return newTop.rebalance();
843      }
844    }
845
846    // Removes the minimum node from this subtree to be reused elsewhere
847    private AvlNode<E> removeMin(AvlNode<E> node) {
848      if (left == null) {
849        return right;
850      } else {
851        left = left.removeMin(node);
852        distinctElements--;
853        totalCount -= node.elemCount;
854        return rebalance();
855      }
856    }
857
858    // Removes the maximum node from this subtree to be reused elsewhere
859    private AvlNode<E> removeMax(AvlNode<E> node) {
860      if (right == null) {
861        return left;
862      } else {
863        right = right.removeMax(node);
864        distinctElements--;
865        totalCount -= node.elemCount;
866        return rebalance();
867      }
868    }
869
870    private void recomputeMultiset() {
871      this.distinctElements =
872          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
873      this.totalCount = elemCount + totalCount(left) + totalCount(right);
874    }
875
876    private void recomputeHeight() {
877      this.height = 1 + Math.max(height(left), height(right));
878    }
879
880    private void recompute() {
881      recomputeMultiset();
882      recomputeHeight();
883    }
884
885    private AvlNode<E> rebalance() {
886      switch (balanceFactor()) {
887        case -2:
888          if (right.balanceFactor() > 0) {
889            right = right.rotateRight();
890          }
891          return rotateLeft();
892        case 2:
893          if (left.balanceFactor() < 0) {
894            left = left.rotateLeft();
895          }
896          return rotateRight();
897        default:
898          recomputeHeight();
899          return this;
900      }
901    }
902
903    private int balanceFactor() {
904      return height(left) - height(right);
905    }
906
907    private AvlNode<E> rotateLeft() {
908      checkState(right != null);
909      AvlNode<E> newTop = right;
910      this.right = newTop.left;
911      newTop.left = this;
912      newTop.totalCount = this.totalCount;
913      newTop.distinctElements = this.distinctElements;
914      this.recompute();
915      newTop.recomputeHeight();
916      return newTop;
917    }
918
919    private AvlNode<E> rotateRight() {
920      checkState(left != null);
921      AvlNode<E> newTop = left;
922      this.left = newTop.right;
923      newTop.right = this;
924      newTop.totalCount = this.totalCount;
925      newTop.distinctElements = this.distinctElements;
926      this.recompute();
927      newTop.recomputeHeight();
928      return newTop;
929    }
930
931    private static long totalCount(@NullableDecl AvlNode<?> node) {
932      return (node == null) ? 0 : node.totalCount;
933    }
934
935    private static int height(@NullableDecl AvlNode<?> node) {
936      return (node == null) ? 0 : node.height;
937    }
938
939    @NullableDecl
940    private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
941      int cmp = comparator.compare(e, elem);
942      if (cmp < 0) {
943        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
944      } else if (cmp == 0) {
945        return this;
946      } else {
947        return (right == null) ? null : right.ceiling(comparator, e);
948      }
949    }
950
951    @NullableDecl
952    private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
953      int cmp = comparator.compare(e, elem);
954      if (cmp > 0) {
955        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
956      } else if (cmp == 0) {
957        return this;
958      } else {
959        return (left == null) ? null : left.floor(comparator, e);
960      }
961    }
962
963    E getElement() {
964      return elem;
965    }
966
967    int getCount() {
968      return elemCount;
969    }
970
971    @Override
972    public String toString() {
973      return Multisets.immutableEntry(getElement(), getCount()).toString();
974    }
975  }
976
977  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
978    a.succ = b;
979    b.pred = a;
980  }
981
982  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
983    successor(a, b);
984    successor(b, c);
985  }
986
987  /*
988   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
989   * calls the comparator to compare the two keys. If that change is made,
990   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
991   */
992
993  /**
994   * @serialData the comparator, the number of distinct elements, the first element, its count, the
995   *     second element, its count, and so on
996   */
997  @GwtIncompatible // java.io.ObjectOutputStream
998  private void writeObject(ObjectOutputStream stream) throws IOException {
999    stream.defaultWriteObject();
1000    stream.writeObject(elementSet().comparator());
1001    Serialization.writeMultiset(this, stream);
1002  }
1003
1004  @GwtIncompatible // java.io.ObjectInputStream
1005  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
1006    stream.defaultReadObject();
1007    @SuppressWarnings("unchecked")
1008    // reading data stored by writeObject
1009    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
1010    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
1011    Serialization.getFieldSetter(TreeMultiset.class, "range")
1012        .set(this, GeneralRange.all(comparator));
1013    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
1014        .set(this, new Reference<AvlNode<E>>());
1015    AvlNode<E> header = new AvlNode<E>(null, 1);
1016    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
1017    successor(header, header);
1018    Serialization.populateMultiset(this, stream);
1019  }
1020
1021  @GwtIncompatible // not needed in emulated source
1022  private static final long serialVersionUID = 1;
1023}