001/* 002 * Copyright (C) 2007 The Guava Authors 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package com.google.common.collect; 018 019import static com.google.common.base.Preconditions.checkArgument; 020import static com.google.common.base.Preconditions.checkNotNull; 021import static com.google.common.base.Preconditions.checkState; 022import static com.google.common.collect.CollectPreconditions.checkNonnegative; 023import static com.google.common.collect.CollectPreconditions.checkRemove; 024 025import com.google.common.annotations.GwtCompatible; 026import com.google.common.annotations.GwtIncompatible; 027import com.google.common.base.MoreObjects; 028import com.google.common.primitives.Ints; 029import com.google.errorprone.annotations.CanIgnoreReturnValue; 030import java.io.IOException; 031import java.io.ObjectInputStream; 032import java.io.ObjectOutputStream; 033import java.io.Serializable; 034import java.util.Comparator; 035import java.util.ConcurrentModificationException; 036import java.util.Iterator; 037import java.util.NoSuchElementException; 038import java.util.function.ObjIntConsumer; 039import org.checkerframework.checker.nullness.compatqual.NullableDecl; 040 041/** 042 * A multiset which maintains the ordering of its elements, according to either their natural order 043 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link 044 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine 045 * equivalence of instances. 046 * 047 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the 048 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link 049 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}. 050 * 051 * <p>See the Guava User Guide article on <a href= 052 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset"> {@code 053 * Multiset}</a>. 054 * 055 * @author Louis Wasserman 056 * @author Jared Levy 057 * @since 2.0 058 */ 059@GwtCompatible(emulated = true) 060public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable { 061 062 /** 063 * Creates a new, empty multiset, sorted according to the elements' natural order. All elements 064 * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all 065 * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a 066 * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the 067 * user attempts to add an element to the multiset that violates this constraint (for example, the 068 * user attempts to add a string element to a set whose elements are integers), the {@code 069 * add(Object)} call will throw a {@code ClassCastException}. 070 * 071 * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific 072 * {@code <E extends Comparable<? super E>>}, to support classes defined without generics. 073 */ 074 public static <E extends Comparable> TreeMultiset<E> create() { 075 return new TreeMultiset<E>(Ordering.natural()); 076 } 077 078 /** 079 * Creates a new, empty multiset, sorted according to the specified comparator. All elements 080 * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator: 081 * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements 082 * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the 083 * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code 084 * ClassCastException}. 085 * 086 * @param comparator the comparator that will be used to sort this multiset. A null value 087 * indicates that the elements' <i>natural ordering</i> should be used. 088 */ 089 @SuppressWarnings("unchecked") 090 public static <E> TreeMultiset<E> create(@NullableDecl Comparator<? super E> comparator) { 091 return (comparator == null) 092 ? new TreeMultiset<E>((Comparator) Ordering.natural()) 093 : new TreeMultiset<E>(comparator); 094 } 095 096 /** 097 * Creates an empty multiset containing the given initial elements, sorted according to the 098 * elements' natural order. 099 * 100 * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}. 101 * 102 * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific 103 * {@code <E extends Comparable<? super E>>}, to support classes defined without generics. 104 */ 105 public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) { 106 TreeMultiset<E> multiset = create(); 107 Iterables.addAll(multiset, elements); 108 return multiset; 109 } 110 111 private final transient Reference<AvlNode<E>> rootReference; 112 private final transient GeneralRange<E> range; 113 private final transient AvlNode<E> header; 114 115 TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) { 116 super(range.comparator()); 117 this.rootReference = rootReference; 118 this.range = range; 119 this.header = endLink; 120 } 121 122 TreeMultiset(Comparator<? super E> comparator) { 123 super(comparator); 124 this.range = GeneralRange.all(comparator); 125 this.header = new AvlNode<E>(null, 1); 126 successor(header, header); 127 this.rootReference = new Reference<>(); 128 } 129 130 /** A function which can be summed across a subtree. */ 131 private enum Aggregate { 132 SIZE { 133 @Override 134 int nodeAggregate(AvlNode<?> node) { 135 return node.elemCount; 136 } 137 138 @Override 139 long treeAggregate(@NullableDecl AvlNode<?> root) { 140 return (root == null) ? 0 : root.totalCount; 141 } 142 }, 143 DISTINCT { 144 @Override 145 int nodeAggregate(AvlNode<?> node) { 146 return 1; 147 } 148 149 @Override 150 long treeAggregate(@NullableDecl AvlNode<?> root) { 151 return (root == null) ? 0 : root.distinctElements; 152 } 153 }; 154 155 abstract int nodeAggregate(AvlNode<?> node); 156 157 abstract long treeAggregate(@NullableDecl AvlNode<?> root); 158 } 159 160 private long aggregateForEntries(Aggregate aggr) { 161 AvlNode<E> root = rootReference.get(); 162 long total = aggr.treeAggregate(root); 163 if (range.hasLowerBound()) { 164 total -= aggregateBelowRange(aggr, root); 165 } 166 if (range.hasUpperBound()) { 167 total -= aggregateAboveRange(aggr, root); 168 } 169 return total; 170 } 171 172 private long aggregateBelowRange(Aggregate aggr, @NullableDecl AvlNode<E> node) { 173 if (node == null) { 174 return 0; 175 } 176 int cmp = comparator().compare(range.getLowerEndpoint(), node.elem); 177 if (cmp < 0) { 178 return aggregateBelowRange(aggr, node.left); 179 } else if (cmp == 0) { 180 switch (range.getLowerBoundType()) { 181 case OPEN: 182 return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left); 183 case CLOSED: 184 return aggr.treeAggregate(node.left); 185 default: 186 throw new AssertionError(); 187 } 188 } else { 189 return aggr.treeAggregate(node.left) 190 + aggr.nodeAggregate(node) 191 + aggregateBelowRange(aggr, node.right); 192 } 193 } 194 195 private long aggregateAboveRange(Aggregate aggr, @NullableDecl AvlNode<E> node) { 196 if (node == null) { 197 return 0; 198 } 199 int cmp = comparator().compare(range.getUpperEndpoint(), node.elem); 200 if (cmp > 0) { 201 return aggregateAboveRange(aggr, node.right); 202 } else if (cmp == 0) { 203 switch (range.getUpperBoundType()) { 204 case OPEN: 205 return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right); 206 case CLOSED: 207 return aggr.treeAggregate(node.right); 208 default: 209 throw new AssertionError(); 210 } 211 } else { 212 return aggr.treeAggregate(node.right) 213 + aggr.nodeAggregate(node) 214 + aggregateAboveRange(aggr, node.left); 215 } 216 } 217 218 @Override 219 public int size() { 220 return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE)); 221 } 222 223 @Override 224 int distinctElements() { 225 return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT)); 226 } 227 228 @Override 229 public int count(@NullableDecl Object element) { 230 try { 231 @SuppressWarnings("unchecked") 232 E e = (E) element; 233 AvlNode<E> root = rootReference.get(); 234 if (!range.contains(e) || root == null) { 235 return 0; 236 } 237 return root.count(comparator(), e); 238 } catch (ClassCastException | NullPointerException e) { 239 return 0; 240 } 241 } 242 243 @CanIgnoreReturnValue 244 @Override 245 public int add(@NullableDecl E element, int occurrences) { 246 checkNonnegative(occurrences, "occurrences"); 247 if (occurrences == 0) { 248 return count(element); 249 } 250 checkArgument(range.contains(element)); 251 AvlNode<E> root = rootReference.get(); 252 if (root == null) { 253 comparator().compare(element, element); 254 AvlNode<E> newRoot = new AvlNode<E>(element, occurrences); 255 successor(header, newRoot, header); 256 rootReference.checkAndSet(root, newRoot); 257 return 0; 258 } 259 int[] result = new int[1]; // used as a mutable int reference to hold result 260 AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result); 261 rootReference.checkAndSet(root, newRoot); 262 return result[0]; 263 } 264 265 @CanIgnoreReturnValue 266 @Override 267 public int remove(@NullableDecl Object element, int occurrences) { 268 checkNonnegative(occurrences, "occurrences"); 269 if (occurrences == 0) { 270 return count(element); 271 } 272 AvlNode<E> root = rootReference.get(); 273 int[] result = new int[1]; // used as a mutable int reference to hold result 274 AvlNode<E> newRoot; 275 try { 276 @SuppressWarnings("unchecked") 277 E e = (E) element; 278 if (!range.contains(e) || root == null) { 279 return 0; 280 } 281 newRoot = root.remove(comparator(), e, occurrences, result); 282 } catch (ClassCastException | NullPointerException e) { 283 return 0; 284 } 285 rootReference.checkAndSet(root, newRoot); 286 return result[0]; 287 } 288 289 @CanIgnoreReturnValue 290 @Override 291 public int setCount(@NullableDecl E element, int count) { 292 checkNonnegative(count, "count"); 293 if (!range.contains(element)) { 294 checkArgument(count == 0); 295 return 0; 296 } 297 298 AvlNode<E> root = rootReference.get(); 299 if (root == null) { 300 if (count > 0) { 301 add(element, count); 302 } 303 return 0; 304 } 305 int[] result = new int[1]; // used as a mutable int reference to hold result 306 AvlNode<E> newRoot = root.setCount(comparator(), element, count, result); 307 rootReference.checkAndSet(root, newRoot); 308 return result[0]; 309 } 310 311 @CanIgnoreReturnValue 312 @Override 313 public boolean setCount(@NullableDecl E element, int oldCount, int newCount) { 314 checkNonnegative(newCount, "newCount"); 315 checkNonnegative(oldCount, "oldCount"); 316 checkArgument(range.contains(element)); 317 318 AvlNode<E> root = rootReference.get(); 319 if (root == null) { 320 if (oldCount == 0) { 321 if (newCount > 0) { 322 add(element, newCount); 323 } 324 return true; 325 } else { 326 return false; 327 } 328 } 329 int[] result = new int[1]; // used as a mutable int reference to hold result 330 AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result); 331 rootReference.checkAndSet(root, newRoot); 332 return result[0] == oldCount; 333 } 334 335 @Override 336 public void clear() { 337 if (!range.hasLowerBound() && !range.hasUpperBound()) { 338 // We can do this in O(n) rather than removing one by one, which could force rebalancing. 339 for (AvlNode<E> current = header.succ; current != header; ) { 340 AvlNode<E> next = current.succ; 341 342 current.elemCount = 0; 343 // Also clear these fields so that one deleted Entry doesn't retain all elements. 344 current.left = null; 345 current.right = null; 346 current.pred = null; 347 current.succ = null; 348 349 current = next; 350 } 351 successor(header, header); 352 rootReference.clear(); 353 } else { 354 // TODO(cpovirk): Perhaps we can optimize in this case, too? 355 Iterators.clear(entryIterator()); 356 } 357 } 358 359 private Entry<E> wrapEntry(final AvlNode<E> baseEntry) { 360 return new Multisets.AbstractEntry<E>() { 361 @Override 362 public E getElement() { 363 return baseEntry.getElement(); 364 } 365 366 @Override 367 public int getCount() { 368 int result = baseEntry.getCount(); 369 if (result == 0) { 370 return count(getElement()); 371 } else { 372 return result; 373 } 374 } 375 }; 376 } 377 378 /** Returns the first node in the tree that is in range. */ 379 @NullableDecl 380 private AvlNode<E> firstNode() { 381 AvlNode<E> root = rootReference.get(); 382 if (root == null) { 383 return null; 384 } 385 AvlNode<E> node; 386 if (range.hasLowerBound()) { 387 E endpoint = range.getLowerEndpoint(); 388 node = rootReference.get().ceiling(comparator(), endpoint); 389 if (node == null) { 390 return null; 391 } 392 if (range.getLowerBoundType() == BoundType.OPEN 393 && comparator().compare(endpoint, node.getElement()) == 0) { 394 node = node.succ; 395 } 396 } else { 397 node = header.succ; 398 } 399 return (node == header || !range.contains(node.getElement())) ? null : node; 400 } 401 402 @NullableDecl 403 private AvlNode<E> lastNode() { 404 AvlNode<E> root = rootReference.get(); 405 if (root == null) { 406 return null; 407 } 408 AvlNode<E> node; 409 if (range.hasUpperBound()) { 410 E endpoint = range.getUpperEndpoint(); 411 node = rootReference.get().floor(comparator(), endpoint); 412 if (node == null) { 413 return null; 414 } 415 if (range.getUpperBoundType() == BoundType.OPEN 416 && comparator().compare(endpoint, node.getElement()) == 0) { 417 node = node.pred; 418 } 419 } else { 420 node = header.pred; 421 } 422 return (node == header || !range.contains(node.getElement())) ? null : node; 423 } 424 425 @Override 426 Iterator<E> elementIterator() { 427 return Multisets.elementIterator(entryIterator()); 428 } 429 430 @Override 431 Iterator<Entry<E>> entryIterator() { 432 return new Iterator<Entry<E>>() { 433 AvlNode<E> current = firstNode(); 434 @NullableDecl Entry<E> prevEntry; 435 436 @Override 437 public boolean hasNext() { 438 if (current == null) { 439 return false; 440 } else if (range.tooHigh(current.getElement())) { 441 current = null; 442 return false; 443 } else { 444 return true; 445 } 446 } 447 448 @Override 449 public Entry<E> next() { 450 if (!hasNext()) { 451 throw new NoSuchElementException(); 452 } 453 Entry<E> result = wrapEntry(current); 454 prevEntry = result; 455 if (current.succ == header) { 456 current = null; 457 } else { 458 current = current.succ; 459 } 460 return result; 461 } 462 463 @Override 464 public void remove() { 465 checkRemove(prevEntry != null); 466 setCount(prevEntry.getElement(), 0); 467 prevEntry = null; 468 } 469 }; 470 } 471 472 @Override 473 Iterator<Entry<E>> descendingEntryIterator() { 474 return new Iterator<Entry<E>>() { 475 AvlNode<E> current = lastNode(); 476 Entry<E> prevEntry = null; 477 478 @Override 479 public boolean hasNext() { 480 if (current == null) { 481 return false; 482 } else if (range.tooLow(current.getElement())) { 483 current = null; 484 return false; 485 } else { 486 return true; 487 } 488 } 489 490 @Override 491 public Entry<E> next() { 492 if (!hasNext()) { 493 throw new NoSuchElementException(); 494 } 495 Entry<E> result = wrapEntry(current); 496 prevEntry = result; 497 if (current.pred == header) { 498 current = null; 499 } else { 500 current = current.pred; 501 } 502 return result; 503 } 504 505 @Override 506 public void remove() { 507 checkRemove(prevEntry != null); 508 setCount(prevEntry.getElement(), 0); 509 prevEntry = null; 510 } 511 }; 512 } 513 514 @Override 515 public void forEachEntry(ObjIntConsumer<? super E> action) { 516 checkNotNull(action); 517 for (AvlNode<E> node = firstNode(); 518 node != header && node != null && !range.tooHigh(node.getElement()); 519 node = node.succ) { 520 action.accept(node.getElement(), node.getCount()); 521 } 522 } 523 524 @Override 525 public Iterator<E> iterator() { 526 return Multisets.iteratorImpl(this); 527 } 528 529 @Override 530 public SortedMultiset<E> headMultiset(@NullableDecl E upperBound, BoundType boundType) { 531 return new TreeMultiset<E>( 532 rootReference, 533 range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)), 534 header); 535 } 536 537 @Override 538 public SortedMultiset<E> tailMultiset(@NullableDecl E lowerBound, BoundType boundType) { 539 return new TreeMultiset<E>( 540 rootReference, 541 range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)), 542 header); 543 } 544 545 static int distinctElements(@NullableDecl AvlNode<?> node) { 546 return (node == null) ? 0 : node.distinctElements; 547 } 548 549 private static final class Reference<T> { 550 @NullableDecl private T value; 551 552 @NullableDecl 553 public T get() { 554 return value; 555 } 556 557 public void checkAndSet(@NullableDecl T expected, T newValue) { 558 if (value != expected) { 559 throw new ConcurrentModificationException(); 560 } 561 value = newValue; 562 } 563 564 void clear() { 565 value = null; 566 } 567 } 568 569 private static final class AvlNode<E> { 570 @NullableDecl private final E elem; 571 572 // elemCount is 0 iff this node has been deleted. 573 private int elemCount; 574 575 private int distinctElements; 576 private long totalCount; 577 private int height; 578 @NullableDecl private AvlNode<E> left; 579 @NullableDecl private AvlNode<E> right; 580 @NullableDecl private AvlNode<E> pred; 581 @NullableDecl private AvlNode<E> succ; 582 583 AvlNode(@NullableDecl E elem, int elemCount) { 584 checkArgument(elemCount > 0); 585 this.elem = elem; 586 this.elemCount = elemCount; 587 this.totalCount = elemCount; 588 this.distinctElements = 1; 589 this.height = 1; 590 this.left = null; 591 this.right = null; 592 } 593 594 public int count(Comparator<? super E> comparator, E e) { 595 int cmp = comparator.compare(e, elem); 596 if (cmp < 0) { 597 return (left == null) ? 0 : left.count(comparator, e); 598 } else if (cmp > 0) { 599 return (right == null) ? 0 : right.count(comparator, e); 600 } else { 601 return elemCount; 602 } 603 } 604 605 private AvlNode<E> addRightChild(E e, int count) { 606 right = new AvlNode<E>(e, count); 607 successor(this, right, succ); 608 height = Math.max(2, height); 609 distinctElements++; 610 totalCount += count; 611 return this; 612 } 613 614 private AvlNode<E> addLeftChild(E e, int count) { 615 left = new AvlNode<E>(e, count); 616 successor(pred, left, this); 617 height = Math.max(2, height); 618 distinctElements++; 619 totalCount += count; 620 return this; 621 } 622 623 AvlNode<E> add(Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) { 624 /* 625 * It speeds things up considerably to unconditionally add count to totalCount here, 626 * but that destroys failure atomicity in the case of count overflow. =( 627 */ 628 int cmp = comparator.compare(e, elem); 629 if (cmp < 0) { 630 AvlNode<E> initLeft = left; 631 if (initLeft == null) { 632 result[0] = 0; 633 return addLeftChild(e, count); 634 } 635 int initHeight = initLeft.height; 636 637 left = initLeft.add(comparator, e, count, result); 638 if (result[0] == 0) { 639 distinctElements++; 640 } 641 this.totalCount += count; 642 return (left.height == initHeight) ? this : rebalance(); 643 } else if (cmp > 0) { 644 AvlNode<E> initRight = right; 645 if (initRight == null) { 646 result[0] = 0; 647 return addRightChild(e, count); 648 } 649 int initHeight = initRight.height; 650 651 right = initRight.add(comparator, e, count, result); 652 if (result[0] == 0) { 653 distinctElements++; 654 } 655 this.totalCount += count; 656 return (right.height == initHeight) ? this : rebalance(); 657 } 658 659 // adding count to me! No rebalance possible. 660 result[0] = elemCount; 661 long resultCount = (long) elemCount + count; 662 checkArgument(resultCount <= Integer.MAX_VALUE); 663 this.elemCount += count; 664 this.totalCount += count; 665 return this; 666 } 667 668 AvlNode<E> remove( 669 Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) { 670 int cmp = comparator.compare(e, elem); 671 if (cmp < 0) { 672 AvlNode<E> initLeft = left; 673 if (initLeft == null) { 674 result[0] = 0; 675 return this; 676 } 677 678 left = initLeft.remove(comparator, e, count, result); 679 680 if (result[0] > 0) { 681 if (count >= result[0]) { 682 this.distinctElements--; 683 this.totalCount -= result[0]; 684 } else { 685 this.totalCount -= count; 686 } 687 } 688 return (result[0] == 0) ? this : rebalance(); 689 } else if (cmp > 0) { 690 AvlNode<E> initRight = right; 691 if (initRight == null) { 692 result[0] = 0; 693 return this; 694 } 695 696 right = initRight.remove(comparator, e, count, result); 697 698 if (result[0] > 0) { 699 if (count >= result[0]) { 700 this.distinctElements--; 701 this.totalCount -= result[0]; 702 } else { 703 this.totalCount -= count; 704 } 705 } 706 return rebalance(); 707 } 708 709 // removing count from me! 710 result[0] = elemCount; 711 if (count >= elemCount) { 712 return deleteMe(); 713 } else { 714 this.elemCount -= count; 715 this.totalCount -= count; 716 return this; 717 } 718 } 719 720 AvlNode<E> setCount( 721 Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) { 722 int cmp = comparator.compare(e, elem); 723 if (cmp < 0) { 724 AvlNode<E> initLeft = left; 725 if (initLeft == null) { 726 result[0] = 0; 727 return (count > 0) ? addLeftChild(e, count) : this; 728 } 729 730 left = initLeft.setCount(comparator, e, count, result); 731 732 if (count == 0 && result[0] != 0) { 733 this.distinctElements--; 734 } else if (count > 0 && result[0] == 0) { 735 this.distinctElements++; 736 } 737 738 this.totalCount += count - result[0]; 739 return rebalance(); 740 } else if (cmp > 0) { 741 AvlNode<E> initRight = right; 742 if (initRight == null) { 743 result[0] = 0; 744 return (count > 0) ? addRightChild(e, count) : this; 745 } 746 747 right = initRight.setCount(comparator, e, count, result); 748 749 if (count == 0 && result[0] != 0) { 750 this.distinctElements--; 751 } else if (count > 0 && result[0] == 0) { 752 this.distinctElements++; 753 } 754 755 this.totalCount += count - result[0]; 756 return rebalance(); 757 } 758 759 // setting my count 760 result[0] = elemCount; 761 if (count == 0) { 762 return deleteMe(); 763 } 764 this.totalCount += count - elemCount; 765 this.elemCount = count; 766 return this; 767 } 768 769 AvlNode<E> setCount( 770 Comparator<? super E> comparator, 771 @NullableDecl E e, 772 int expectedCount, 773 int newCount, 774 int[] result) { 775 int cmp = comparator.compare(e, elem); 776 if (cmp < 0) { 777 AvlNode<E> initLeft = left; 778 if (initLeft == null) { 779 result[0] = 0; 780 if (expectedCount == 0 && newCount > 0) { 781 return addLeftChild(e, newCount); 782 } 783 return this; 784 } 785 786 left = initLeft.setCount(comparator, e, expectedCount, newCount, result); 787 788 if (result[0] == expectedCount) { 789 if (newCount == 0 && result[0] != 0) { 790 this.distinctElements--; 791 } else if (newCount > 0 && result[0] == 0) { 792 this.distinctElements++; 793 } 794 this.totalCount += newCount - result[0]; 795 } 796 return rebalance(); 797 } else if (cmp > 0) { 798 AvlNode<E> initRight = right; 799 if (initRight == null) { 800 result[0] = 0; 801 if (expectedCount == 0 && newCount > 0) { 802 return addRightChild(e, newCount); 803 } 804 return this; 805 } 806 807 right = initRight.setCount(comparator, e, expectedCount, newCount, result); 808 809 if (result[0] == expectedCount) { 810 if (newCount == 0 && result[0] != 0) { 811 this.distinctElements--; 812 } else if (newCount > 0 && result[0] == 0) { 813 this.distinctElements++; 814 } 815 this.totalCount += newCount - result[0]; 816 } 817 return rebalance(); 818 } 819 820 // setting my count 821 result[0] = elemCount; 822 if (expectedCount == elemCount) { 823 if (newCount == 0) { 824 return deleteMe(); 825 } 826 this.totalCount += newCount - elemCount; 827 this.elemCount = newCount; 828 } 829 return this; 830 } 831 832 private AvlNode<E> deleteMe() { 833 int oldElemCount = this.elemCount; 834 this.elemCount = 0; 835 successor(pred, succ); 836 if (left == null) { 837 return right; 838 } else if (right == null) { 839 return left; 840 } else if (left.height >= right.height) { 841 AvlNode<E> newTop = pred; 842 // newTop is the maximum node in my left subtree 843 newTop.left = left.removeMax(newTop); 844 newTop.right = right; 845 newTop.distinctElements = distinctElements - 1; 846 newTop.totalCount = totalCount - oldElemCount; 847 return newTop.rebalance(); 848 } else { 849 AvlNode<E> newTop = succ; 850 newTop.right = right.removeMin(newTop); 851 newTop.left = left; 852 newTop.distinctElements = distinctElements - 1; 853 newTop.totalCount = totalCount - oldElemCount; 854 return newTop.rebalance(); 855 } 856 } 857 858 // Removes the minimum node from this subtree to be reused elsewhere 859 private AvlNode<E> removeMin(AvlNode<E> node) { 860 if (left == null) { 861 return right; 862 } else { 863 left = left.removeMin(node); 864 distinctElements--; 865 totalCount -= node.elemCount; 866 return rebalance(); 867 } 868 } 869 870 // Removes the maximum node from this subtree to be reused elsewhere 871 private AvlNode<E> removeMax(AvlNode<E> node) { 872 if (right == null) { 873 return left; 874 } else { 875 right = right.removeMax(node); 876 distinctElements--; 877 totalCount -= node.elemCount; 878 return rebalance(); 879 } 880 } 881 882 private void recomputeMultiset() { 883 this.distinctElements = 884 1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right); 885 this.totalCount = elemCount + totalCount(left) + totalCount(right); 886 } 887 888 private void recomputeHeight() { 889 this.height = 1 + Math.max(height(left), height(right)); 890 } 891 892 private void recompute() { 893 recomputeMultiset(); 894 recomputeHeight(); 895 } 896 897 private AvlNode<E> rebalance() { 898 switch (balanceFactor()) { 899 case -2: 900 if (right.balanceFactor() > 0) { 901 right = right.rotateRight(); 902 } 903 return rotateLeft(); 904 case 2: 905 if (left.balanceFactor() < 0) { 906 left = left.rotateLeft(); 907 } 908 return rotateRight(); 909 default: 910 recomputeHeight(); 911 return this; 912 } 913 } 914 915 private int balanceFactor() { 916 return height(left) - height(right); 917 } 918 919 private AvlNode<E> rotateLeft() { 920 checkState(right != null); 921 AvlNode<E> newTop = right; 922 this.right = newTop.left; 923 newTop.left = this; 924 newTop.totalCount = this.totalCount; 925 newTop.distinctElements = this.distinctElements; 926 this.recompute(); 927 newTop.recomputeHeight(); 928 return newTop; 929 } 930 931 private AvlNode<E> rotateRight() { 932 checkState(left != null); 933 AvlNode<E> newTop = left; 934 this.left = newTop.right; 935 newTop.right = this; 936 newTop.totalCount = this.totalCount; 937 newTop.distinctElements = this.distinctElements; 938 this.recompute(); 939 newTop.recomputeHeight(); 940 return newTop; 941 } 942 943 private static long totalCount(@NullableDecl AvlNode<?> node) { 944 return (node == null) ? 0 : node.totalCount; 945 } 946 947 private static int height(@NullableDecl AvlNode<?> node) { 948 return (node == null) ? 0 : node.height; 949 } 950 951 @NullableDecl 952 private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) { 953 int cmp = comparator.compare(e, elem); 954 if (cmp < 0) { 955 return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this); 956 } else if (cmp == 0) { 957 return this; 958 } else { 959 return (right == null) ? null : right.ceiling(comparator, e); 960 } 961 } 962 963 @NullableDecl 964 private AvlNode<E> floor(Comparator<? super E> comparator, E e) { 965 int cmp = comparator.compare(e, elem); 966 if (cmp > 0) { 967 return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this); 968 } else if (cmp == 0) { 969 return this; 970 } else { 971 return (left == null) ? null : left.floor(comparator, e); 972 } 973 } 974 975 E getElement() { 976 return elem; 977 } 978 979 int getCount() { 980 return elemCount; 981 } 982 983 @Override 984 public String toString() { 985 return Multisets.immutableEntry(getElement(), getCount()).toString(); 986 } 987 } 988 989 private static <T> void successor(AvlNode<T> a, AvlNode<T> b) { 990 a.succ = b; 991 b.pred = a; 992 } 993 994 private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) { 995 successor(a, b); 996 successor(b, c); 997 } 998 999 /* 1000 * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that 1001 * calls the comparator to compare the two keys. If that change is made, 1002 * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets. 1003 */ 1004 1005 /** 1006 * @serialData the comparator, the number of distinct elements, the first element, its count, the 1007 * second element, its count, and so on 1008 */ 1009 @GwtIncompatible // java.io.ObjectOutputStream 1010 private void writeObject(ObjectOutputStream stream) throws IOException { 1011 stream.defaultWriteObject(); 1012 stream.writeObject(elementSet().comparator()); 1013 Serialization.writeMultiset(this, stream); 1014 } 1015 1016 @GwtIncompatible // java.io.ObjectInputStream 1017 private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException { 1018 stream.defaultReadObject(); 1019 @SuppressWarnings("unchecked") 1020 // reading data stored by writeObject 1021 Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject(); 1022 Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator); 1023 Serialization.getFieldSetter(TreeMultiset.class, "range") 1024 .set(this, GeneralRange.all(comparator)); 1025 Serialization.getFieldSetter(TreeMultiset.class, "rootReference") 1026 .set(this, new Reference<AvlNode<E>>()); 1027 AvlNode<E> header = new AvlNode<E>(null, 1); 1028 Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header); 1029 successor(header, header); 1030 Serialization.populateMultiset(this, stream); 1031 } 1032 1033 @GwtIncompatible // not needed in emulated source 1034 private static final long serialVersionUID = 1; 1035}