001/* 002 * Copyright (C) 2007 The Guava Authors 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package com.google.common.collect; 018 019import static com.google.common.base.Preconditions.checkArgument; 020import static com.google.common.base.Preconditions.checkNotNull; 021import static com.google.common.base.Preconditions.checkState; 022import static com.google.common.collect.CollectPreconditions.checkNonnegative; 023import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT; 024import static java.lang.Math.max; 025import static java.util.Objects.requireNonNull; 026 027import com.google.common.annotations.GwtCompatible; 028import com.google.common.annotations.GwtIncompatible; 029import com.google.common.annotations.J2ktIncompatible; 030import com.google.common.base.MoreObjects; 031import com.google.common.primitives.Ints; 032import com.google.errorprone.annotations.CanIgnoreReturnValue; 033import java.io.IOException; 034import java.io.ObjectInputStream; 035import java.io.ObjectOutputStream; 036import java.io.Serializable; 037import java.util.Comparator; 038import java.util.ConcurrentModificationException; 039import java.util.Iterator; 040import java.util.NoSuchElementException; 041import java.util.function.ObjIntConsumer; 042import org.jspecify.annotations.Nullable; 043 044/** 045 * A multiset which maintains the ordering of its elements, according to either their natural order 046 * or an explicit {@link Comparator}. In all cases, this implementation uses {@link 047 * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine 048 * equivalence of instances. 049 * 050 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the 051 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link 052 * java.util.Collection} contract, which is specified in terms of {@link Object#equals}. 053 * 054 * <p>See the Guava User Guide article on <a href= 055 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">{@code Multiset}</a>. 056 * 057 * @author Louis Wasserman 058 * @author Jared Levy 059 * @since 2.0 060 */ 061@GwtCompatible(emulated = true) 062public final class TreeMultiset<E extends @Nullable Object> extends AbstractSortedMultiset<E> 063 implements Serializable { 064 065 /** 066 * Creates a new, empty multiset, sorted according to the elements' natural order. All elements 067 * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all 068 * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a 069 * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the 070 * user attempts to add an element to the multiset that violates this constraint (for example, the 071 * user attempts to add a string element to a set whose elements are integers), the {@code 072 * add(Object)} call will throw a {@code ClassCastException}. 073 * 074 * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific 075 * {@code <E extends Comparable<? super E>>}, to support classes defined without generics. 076 */ 077 @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989 078 public static <E extends Comparable> TreeMultiset<E> create() { 079 return new TreeMultiset<>(Ordering.natural()); 080 } 081 082 /** 083 * Creates a new, empty multiset, sorted according to the specified comparator. All elements 084 * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator: 085 * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements 086 * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the 087 * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code 088 * ClassCastException}. 089 * 090 * @param comparator the comparator that will be used to sort this multiset. A null value 091 * indicates that the elements' <i>natural ordering</i> should be used. 092 */ 093 @SuppressWarnings("unchecked") 094 public static <E extends @Nullable Object> TreeMultiset<E> create( 095 @Nullable Comparator<? super E> comparator) { 096 return (comparator == null) 097 ? new TreeMultiset<E>((Comparator) Ordering.natural()) 098 : new TreeMultiset<E>(comparator); 099 } 100 101 /** 102 * Creates an empty multiset containing the given initial elements, sorted according to the 103 * elements' natural order. 104 * 105 * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}. 106 * 107 * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific 108 * {@code <E extends Comparable<? super E>>}, to support classes defined without generics. 109 */ 110 @SuppressWarnings("rawtypes") // https://github.com/google/guava/issues/989 111 public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) { 112 TreeMultiset<E> multiset = create(); 113 Iterables.addAll(multiset, elements); 114 return multiset; 115 } 116 117 private final transient Reference<AvlNode<E>> rootReference; 118 private final transient GeneralRange<E> range; 119 private final transient AvlNode<E> header; 120 121 TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) { 122 super(range.comparator()); 123 this.rootReference = rootReference; 124 this.range = range; 125 this.header = endLink; 126 } 127 128 TreeMultiset(Comparator<? super E> comparator) { 129 super(comparator); 130 this.range = GeneralRange.all(comparator); 131 this.header = new AvlNode<>(); 132 successor(header, header); 133 this.rootReference = new Reference<>(); 134 } 135 136 /** A function which can be summed across a subtree. */ 137 private enum Aggregate { 138 SIZE { 139 @Override 140 int nodeAggregate(AvlNode<?> node) { 141 return node.elemCount; 142 } 143 144 @Override 145 long treeAggregate(@Nullable AvlNode<?> root) { 146 return (root == null) ? 0 : root.totalCount; 147 } 148 }, 149 DISTINCT { 150 @Override 151 int nodeAggregate(AvlNode<?> node) { 152 return 1; 153 } 154 155 @Override 156 long treeAggregate(@Nullable AvlNode<?> root) { 157 return (root == null) ? 0 : root.distinctElements; 158 } 159 }; 160 161 abstract int nodeAggregate(AvlNode<?> node); 162 163 abstract long treeAggregate(@Nullable AvlNode<?> root); 164 } 165 166 private long aggregateForEntries(Aggregate aggr) { 167 AvlNode<E> root = rootReference.get(); 168 long total = aggr.treeAggregate(root); 169 if (range.hasLowerBound()) { 170 total -= aggregateBelowRange(aggr, root); 171 } 172 if (range.hasUpperBound()) { 173 total -= aggregateAboveRange(aggr, root); 174 } 175 return total; 176 } 177 178 private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) { 179 if (node == null) { 180 return 0; 181 } 182 // The cast is safe because we call this method only if hasLowerBound(). 183 int cmp = 184 comparator() 185 .compare(uncheckedCastNullableTToT(range.getLowerEndpoint()), node.getElement()); 186 if (cmp < 0) { 187 return aggregateBelowRange(aggr, node.left); 188 } else if (cmp == 0) { 189 switch (range.getLowerBoundType()) { 190 case OPEN: 191 return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left); 192 case CLOSED: 193 return aggr.treeAggregate(node.left); 194 } 195 throw new AssertionError(); 196 } else { 197 return aggr.treeAggregate(node.left) 198 + aggr.nodeAggregate(node) 199 + aggregateBelowRange(aggr, node.right); 200 } 201 } 202 203 private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) { 204 if (node == null) { 205 return 0; 206 } 207 // The cast is safe because we call this method only if hasUpperBound(). 208 int cmp = 209 comparator() 210 .compare(uncheckedCastNullableTToT(range.getUpperEndpoint()), node.getElement()); 211 if (cmp > 0) { 212 return aggregateAboveRange(aggr, node.right); 213 } else if (cmp == 0) { 214 switch (range.getUpperBoundType()) { 215 case OPEN: 216 return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right); 217 case CLOSED: 218 return aggr.treeAggregate(node.right); 219 } 220 throw new AssertionError(); 221 } else { 222 return aggr.treeAggregate(node.right) 223 + aggr.nodeAggregate(node) 224 + aggregateAboveRange(aggr, node.left); 225 } 226 } 227 228 @Override 229 public int size() { 230 return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE)); 231 } 232 233 @Override 234 int distinctElements() { 235 return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT)); 236 } 237 238 static int distinctElements(@Nullable AvlNode<?> node) { 239 return (node == null) ? 0 : node.distinctElements; 240 } 241 242 @Override 243 public int count(@Nullable Object element) { 244 try { 245 @SuppressWarnings("unchecked") 246 E e = (E) element; 247 AvlNode<E> root = rootReference.get(); 248 if (!range.contains(e) || root == null) { 249 return 0; 250 } 251 return root.count(comparator(), e); 252 } catch (ClassCastException | NullPointerException e) { 253 return 0; 254 } 255 } 256 257 @CanIgnoreReturnValue 258 @Override 259 public int add(@ParametricNullness E element, int occurrences) { 260 checkNonnegative(occurrences, "occurrences"); 261 if (occurrences == 0) { 262 return count(element); 263 } 264 checkArgument(range.contains(element)); 265 AvlNode<E> root = rootReference.get(); 266 if (root == null) { 267 int unused = comparator().compare(element, element); 268 AvlNode<E> newRoot = new AvlNode<>(element, occurrences); 269 successor(header, newRoot, header); 270 rootReference.checkAndSet(root, newRoot); 271 return 0; 272 } 273 int[] result = new int[1]; // used as a mutable int reference to hold result 274 AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result); 275 rootReference.checkAndSet(root, newRoot); 276 return result[0]; 277 } 278 279 @CanIgnoreReturnValue 280 @Override 281 public int remove(@Nullable Object element, int occurrences) { 282 checkNonnegative(occurrences, "occurrences"); 283 if (occurrences == 0) { 284 return count(element); 285 } 286 AvlNode<E> root = rootReference.get(); 287 int[] result = new int[1]; // used as a mutable int reference to hold result 288 AvlNode<E> newRoot; 289 try { 290 @SuppressWarnings("unchecked") 291 E e = (E) element; 292 if (!range.contains(e) || root == null) { 293 return 0; 294 } 295 newRoot = root.remove(comparator(), e, occurrences, result); 296 } catch (ClassCastException | NullPointerException e) { 297 return 0; 298 } 299 rootReference.checkAndSet(root, newRoot); 300 return result[0]; 301 } 302 303 @CanIgnoreReturnValue 304 @Override 305 public int setCount(@ParametricNullness E element, int count) { 306 checkNonnegative(count, "count"); 307 if (!range.contains(element)) { 308 checkArgument(count == 0); 309 return 0; 310 } 311 312 AvlNode<E> root = rootReference.get(); 313 if (root == null) { 314 if (count > 0) { 315 add(element, count); 316 } 317 return 0; 318 } 319 int[] result = new int[1]; // used as a mutable int reference to hold result 320 AvlNode<E> newRoot = root.setCount(comparator(), element, count, result); 321 rootReference.checkAndSet(root, newRoot); 322 return result[0]; 323 } 324 325 @CanIgnoreReturnValue 326 @Override 327 public boolean setCount(@ParametricNullness E element, int oldCount, int newCount) { 328 checkNonnegative(newCount, "newCount"); 329 checkNonnegative(oldCount, "oldCount"); 330 checkArgument(range.contains(element)); 331 332 AvlNode<E> root = rootReference.get(); 333 if (root == null) { 334 if (oldCount == 0) { 335 if (newCount > 0) { 336 add(element, newCount); 337 } 338 return true; 339 } else { 340 return false; 341 } 342 } 343 int[] result = new int[1]; // used as a mutable int reference to hold result 344 AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result); 345 rootReference.checkAndSet(root, newRoot); 346 return result[0] == oldCount; 347 } 348 349 @Override 350 public void clear() { 351 if (!range.hasLowerBound() && !range.hasUpperBound()) { 352 // We can do this in O(n) rather than removing one by one, which could force rebalancing. 353 for (AvlNode<E> current = header.succ(); current != header; ) { 354 AvlNode<E> next = current.succ(); 355 356 current.elemCount = 0; 357 // Also clear these fields so that one deleted Entry doesn't retain all elements. 358 current.left = null; 359 current.right = null; 360 current.pred = null; 361 current.succ = null; 362 363 current = next; 364 } 365 successor(header, header); 366 rootReference.clear(); 367 } else { 368 // TODO(cpovirk): Perhaps we can optimize in this case, too? 369 Iterators.clear(entryIterator()); 370 } 371 } 372 373 private Entry<E> wrapEntry(AvlNode<E> baseEntry) { 374 return new Multisets.AbstractEntry<E>() { 375 @Override 376 @ParametricNullness 377 public E getElement() { 378 return baseEntry.getElement(); 379 } 380 381 @Override 382 public int getCount() { 383 int result = baseEntry.getCount(); 384 if (result == 0) { 385 return count(getElement()); 386 } else { 387 return result; 388 } 389 } 390 }; 391 } 392 393 /** Returns the first node in the tree that is in range. */ 394 private @Nullable AvlNode<E> firstNode() { 395 AvlNode<E> root = rootReference.get(); 396 if (root == null) { 397 return null; 398 } 399 AvlNode<E> node; 400 if (range.hasLowerBound()) { 401 // The cast is safe because of the hasLowerBound check. 402 E endpoint = uncheckedCastNullableTToT(range.getLowerEndpoint()); 403 node = root.ceiling(comparator(), endpoint); 404 if (node == null) { 405 return null; 406 } 407 if (range.getLowerBoundType() == BoundType.OPEN 408 && comparator().compare(endpoint, node.getElement()) == 0) { 409 node = node.succ(); 410 } 411 } else { 412 node = header.succ(); 413 } 414 return (node == header || !range.contains(node.getElement())) ? null : node; 415 } 416 417 private @Nullable AvlNode<E> lastNode() { 418 AvlNode<E> root = rootReference.get(); 419 if (root == null) { 420 return null; 421 } 422 AvlNode<E> node; 423 if (range.hasUpperBound()) { 424 // The cast is safe because of the hasUpperBound check. 425 E endpoint = uncheckedCastNullableTToT(range.getUpperEndpoint()); 426 node = root.floor(comparator(), endpoint); 427 if (node == null) { 428 return null; 429 } 430 if (range.getUpperBoundType() == BoundType.OPEN 431 && comparator().compare(endpoint, node.getElement()) == 0) { 432 node = node.pred(); 433 } 434 } else { 435 node = header.pred(); 436 } 437 return (node == header || !range.contains(node.getElement())) ? null : node; 438 } 439 440 @Override 441 Iterator<E> elementIterator() { 442 return Multisets.elementIterator(entryIterator()); 443 } 444 445 @Override 446 Iterator<Entry<E>> entryIterator() { 447 return new Iterator<Entry<E>>() { 448 @Nullable AvlNode<E> current = firstNode(); 449 @Nullable Entry<E> prevEntry; 450 451 @Override 452 public boolean hasNext() { 453 if (current == null) { 454 return false; 455 } else if (range.tooHigh(current.getElement())) { 456 current = null; 457 return false; 458 } else { 459 return true; 460 } 461 } 462 463 @Override 464 public Entry<E> next() { 465 if (!hasNext()) { 466 throw new NoSuchElementException(); 467 } 468 // requireNonNull is safe because current is only nulled out after iteration is complete. 469 Entry<E> result = wrapEntry(requireNonNull(current)); 470 prevEntry = result; 471 if (current.succ() == header) { 472 current = null; 473 } else { 474 current = current.succ(); 475 } 476 return result; 477 } 478 479 @Override 480 public void remove() { 481 checkState(prevEntry != null, "no calls to next() since the last call to remove()"); 482 setCount(prevEntry.getElement(), 0); 483 prevEntry = null; 484 } 485 }; 486 } 487 488 @Override 489 Iterator<Entry<E>> descendingEntryIterator() { 490 return new Iterator<Entry<E>>() { 491 @Nullable AvlNode<E> current = lastNode(); 492 @Nullable Entry<E> prevEntry = null; 493 494 @Override 495 public boolean hasNext() { 496 if (current == null) { 497 return false; 498 } else if (range.tooLow(current.getElement())) { 499 current = null; 500 return false; 501 } else { 502 return true; 503 } 504 } 505 506 @Override 507 public Entry<E> next() { 508 if (!hasNext()) { 509 throw new NoSuchElementException(); 510 } 511 // requireNonNull is safe because current is only nulled out after iteration is complete. 512 requireNonNull(current); 513 Entry<E> result = wrapEntry(current); 514 prevEntry = result; 515 if (current.pred() == header) { 516 current = null; 517 } else { 518 current = current.pred(); 519 } 520 return result; 521 } 522 523 @Override 524 public void remove() { 525 checkState(prevEntry != null, "no calls to next() since the last call to remove()"); 526 setCount(prevEntry.getElement(), 0); 527 prevEntry = null; 528 } 529 }; 530 } 531 532 @Override 533 public void forEachEntry(ObjIntConsumer<? super E> action) { 534 checkNotNull(action); 535 for (AvlNode<E> node = firstNode(); 536 node != header && node != null && !range.tooHigh(node.getElement()); 537 node = node.succ()) { 538 action.accept(node.getElement(), node.getCount()); 539 } 540 } 541 542 @Override 543 public Iterator<E> iterator() { 544 return Multisets.iteratorImpl(this); 545 } 546 547 @Override 548 public SortedMultiset<E> headMultiset(@ParametricNullness E upperBound, BoundType boundType) { 549 return new TreeMultiset<>( 550 rootReference, 551 range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)), 552 header); 553 } 554 555 @Override 556 public SortedMultiset<E> tailMultiset(@ParametricNullness E lowerBound, BoundType boundType) { 557 return new TreeMultiset<>( 558 rootReference, 559 range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)), 560 header); 561 } 562 563 private static final class Reference<T> { 564 private @Nullable T value; 565 566 public @Nullable T get() { 567 return value; 568 } 569 570 public void checkAndSet(@Nullable T expected, @Nullable T newValue) { 571 if (value != expected) { 572 throw new ConcurrentModificationException(); 573 } 574 value = newValue; 575 } 576 577 void clear() { 578 value = null; 579 } 580 } 581 582 private static final class AvlNode<E extends @Nullable Object> { 583 /* 584 * For "normal" nodes, the type of this field is `E`, not `@Nullable E` (though note that E is a 585 * type that can include null, as in a TreeMultiset<@Nullable String>). 586 * 587 * For the header node, though, this field contains `null`, regardless of the type of the 588 * multiset. 589 * 590 * Most code that operates on an AvlNode never operates on the header node. Such code can access 591 * the elem field without a null check by calling getElement(). 592 */ 593 private final @Nullable E elem; 594 595 // elemCount is 0 iff this node has been deleted. 596 private int elemCount; 597 598 private int distinctElements; 599 private long totalCount; 600 private int height; 601 private @Nullable AvlNode<E> left; 602 private @Nullable AvlNode<E> right; 603 /* 604 * pred and succ are nullable after construction, but we always call successor() to initialize 605 * them immediately thereafter. 606 * 607 * They may be subsequently nulled out by TreeMultiset.clear(). I think that the only place that 608 * we can reference a node whose fields have been cleared is inside the iterator (and presumably 609 * only under concurrent modification). 610 * 611 * To access these fields when you know that they are not null, call the pred() and succ() 612 * methods, which perform null checks before returning the fields. 613 */ 614 private @Nullable AvlNode<E> pred; 615 private @Nullable AvlNode<E> succ; 616 617 AvlNode(@ParametricNullness E elem, int elemCount) { 618 checkArgument(elemCount > 0); 619 this.elem = elem; 620 this.elemCount = elemCount; 621 this.totalCount = elemCount; 622 this.distinctElements = 1; 623 this.height = 1; 624 this.left = null; 625 this.right = null; 626 } 627 628 /** Constructor for the header node. */ 629 AvlNode() { 630 this.elem = null; 631 this.elemCount = 1; 632 } 633 634 // For discussion of pred() and succ(), see the comment on the pred and succ fields. 635 636 private AvlNode<E> pred() { 637 return requireNonNull(pred); 638 } 639 640 private AvlNode<E> succ() { 641 return requireNonNull(succ); 642 } 643 644 int count(Comparator<? super E> comparator, @ParametricNullness E e) { 645 int cmp = comparator.compare(e, getElement()); 646 if (cmp < 0) { 647 return (left == null) ? 0 : left.count(comparator, e); 648 } else if (cmp > 0) { 649 return (right == null) ? 0 : right.count(comparator, e); 650 } else { 651 return elemCount; 652 } 653 } 654 655 @CanIgnoreReturnValue 656 private AvlNode<E> addRightChild(@ParametricNullness E e, int count) { 657 right = new AvlNode<>(e, count); 658 successor(this, right, succ()); 659 height = max(2, height); 660 distinctElements++; 661 totalCount += count; 662 return this; 663 } 664 665 @CanIgnoreReturnValue 666 private AvlNode<E> addLeftChild(@ParametricNullness E e, int count) { 667 left = new AvlNode<>(e, count); 668 successor(pred(), left, this); 669 height = max(2, height); 670 distinctElements++; 671 totalCount += count; 672 return this; 673 } 674 675 AvlNode<E> add( 676 Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) { 677 /* 678 * It speeds things up considerably to unconditionally add count to totalCount here, 679 * but that destroys failure atomicity in the case of count overflow. =( 680 */ 681 int cmp = comparator.compare(e, getElement()); 682 if (cmp < 0) { 683 AvlNode<E> initLeft = left; 684 if (initLeft == null) { 685 result[0] = 0; 686 return addLeftChild(e, count); 687 } 688 int initHeight = initLeft.height; 689 690 left = initLeft.add(comparator, e, count, result); 691 if (result[0] == 0) { 692 distinctElements++; 693 } 694 this.totalCount += count; 695 return (left.height == initHeight) ? this : rebalance(); 696 } else if (cmp > 0) { 697 AvlNode<E> initRight = right; 698 if (initRight == null) { 699 result[0] = 0; 700 return addRightChild(e, count); 701 } 702 int initHeight = initRight.height; 703 704 right = initRight.add(comparator, e, count, result); 705 if (result[0] == 0) { 706 distinctElements++; 707 } 708 this.totalCount += count; 709 return (right.height == initHeight) ? this : rebalance(); 710 } 711 712 // adding count to me! No rebalance possible. 713 result[0] = elemCount; 714 long resultCount = (long) elemCount + count; 715 checkArgument(resultCount <= Integer.MAX_VALUE); 716 this.elemCount += count; 717 this.totalCount += count; 718 return this; 719 } 720 721 @Nullable AvlNode<E> remove( 722 Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) { 723 int cmp = comparator.compare(e, getElement()); 724 if (cmp < 0) { 725 AvlNode<E> initLeft = left; 726 if (initLeft == null) { 727 result[0] = 0; 728 return this; 729 } 730 731 left = initLeft.remove(comparator, e, count, result); 732 733 if (result[0] > 0) { 734 if (count >= result[0]) { 735 this.distinctElements--; 736 this.totalCount -= result[0]; 737 } else { 738 this.totalCount -= count; 739 } 740 } 741 return (result[0] == 0) ? this : rebalance(); 742 } else if (cmp > 0) { 743 AvlNode<E> initRight = right; 744 if (initRight == null) { 745 result[0] = 0; 746 return this; 747 } 748 749 right = initRight.remove(comparator, e, count, result); 750 751 if (result[0] > 0) { 752 if (count >= result[0]) { 753 this.distinctElements--; 754 this.totalCount -= result[0]; 755 } else { 756 this.totalCount -= count; 757 } 758 } 759 return rebalance(); 760 } 761 762 // removing count from me! 763 result[0] = elemCount; 764 if (count >= elemCount) { 765 return deleteMe(); 766 } else { 767 this.elemCount -= count; 768 this.totalCount -= count; 769 return this; 770 } 771 } 772 773 @Nullable AvlNode<E> setCount( 774 Comparator<? super E> comparator, @ParametricNullness E e, int count, int[] result) { 775 int cmp = comparator.compare(e, getElement()); 776 if (cmp < 0) { 777 AvlNode<E> initLeft = left; 778 if (initLeft == null) { 779 result[0] = 0; 780 return (count > 0) ? addLeftChild(e, count) : this; 781 } 782 783 left = initLeft.setCount(comparator, e, count, result); 784 785 if (count == 0 && result[0] != 0) { 786 this.distinctElements--; 787 } else if (count > 0 && result[0] == 0) { 788 this.distinctElements++; 789 } 790 791 this.totalCount += count - result[0]; 792 return rebalance(); 793 } else if (cmp > 0) { 794 AvlNode<E> initRight = right; 795 if (initRight == null) { 796 result[0] = 0; 797 return (count > 0) ? addRightChild(e, count) : this; 798 } 799 800 right = initRight.setCount(comparator, e, count, result); 801 802 if (count == 0 && result[0] != 0) { 803 this.distinctElements--; 804 } else if (count > 0 && result[0] == 0) { 805 this.distinctElements++; 806 } 807 808 this.totalCount += count - result[0]; 809 return rebalance(); 810 } 811 812 // setting my count 813 result[0] = elemCount; 814 if (count == 0) { 815 return deleteMe(); 816 } 817 this.totalCount += count - elemCount; 818 this.elemCount = count; 819 return this; 820 } 821 822 @Nullable AvlNode<E> setCount( 823 Comparator<? super E> comparator, 824 @ParametricNullness E e, 825 int expectedCount, 826 int newCount, 827 int[] result) { 828 int cmp = comparator.compare(e, getElement()); 829 if (cmp < 0) { 830 AvlNode<E> initLeft = left; 831 if (initLeft == null) { 832 result[0] = 0; 833 if (expectedCount == 0 && newCount > 0) { 834 return addLeftChild(e, newCount); 835 } 836 return this; 837 } 838 839 left = initLeft.setCount(comparator, e, expectedCount, newCount, result); 840 841 if (result[0] == expectedCount) { 842 if (newCount == 0 && result[0] != 0) { 843 this.distinctElements--; 844 } else if (newCount > 0 && result[0] == 0) { 845 this.distinctElements++; 846 } 847 this.totalCount += newCount - result[0]; 848 } 849 return rebalance(); 850 } else if (cmp > 0) { 851 AvlNode<E> initRight = right; 852 if (initRight == null) { 853 result[0] = 0; 854 if (expectedCount == 0 && newCount > 0) { 855 return addRightChild(e, newCount); 856 } 857 return this; 858 } 859 860 right = initRight.setCount(comparator, e, expectedCount, newCount, result); 861 862 if (result[0] == expectedCount) { 863 if (newCount == 0 && result[0] != 0) { 864 this.distinctElements--; 865 } else if (newCount > 0 && result[0] == 0) { 866 this.distinctElements++; 867 } 868 this.totalCount += newCount - result[0]; 869 } 870 return rebalance(); 871 } 872 873 // setting my count 874 result[0] = elemCount; 875 if (expectedCount == elemCount) { 876 if (newCount == 0) { 877 return deleteMe(); 878 } 879 this.totalCount += newCount - elemCount; 880 this.elemCount = newCount; 881 } 882 return this; 883 } 884 885 private @Nullable AvlNode<E> deleteMe() { 886 int oldElemCount = this.elemCount; 887 this.elemCount = 0; 888 successor(pred(), succ()); 889 if (left == null) { 890 return right; 891 } else if (right == null) { 892 return left; 893 } else if (left.height >= right.height) { 894 AvlNode<E> newTop = pred(); 895 // newTop is the maximum node in my left subtree 896 newTop.left = left.removeMax(newTop); 897 newTop.right = right; 898 newTop.distinctElements = distinctElements - 1; 899 newTop.totalCount = totalCount - oldElemCount; 900 return newTop.rebalance(); 901 } else { 902 AvlNode<E> newTop = succ(); 903 newTop.right = right.removeMin(newTop); 904 newTop.left = left; 905 newTop.distinctElements = distinctElements - 1; 906 newTop.totalCount = totalCount - oldElemCount; 907 return newTop.rebalance(); 908 } 909 } 910 911 // Removes the minimum node from this subtree to be reused elsewhere 912 private @Nullable AvlNode<E> removeMin(AvlNode<E> node) { 913 if (left == null) { 914 return right; 915 } else { 916 left = left.removeMin(node); 917 distinctElements--; 918 totalCount -= node.elemCount; 919 return rebalance(); 920 } 921 } 922 923 // Removes the maximum node from this subtree to be reused elsewhere 924 private @Nullable AvlNode<E> removeMax(AvlNode<E> node) { 925 if (right == null) { 926 return left; 927 } else { 928 right = right.removeMax(node); 929 distinctElements--; 930 totalCount -= node.elemCount; 931 return rebalance(); 932 } 933 } 934 935 private void recomputeMultiset() { 936 this.distinctElements = 937 1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right); 938 this.totalCount = elemCount + totalCount(left) + totalCount(right); 939 } 940 941 private void recomputeHeight() { 942 this.height = 1 + max(height(left), height(right)); 943 } 944 945 private void recompute() { 946 recomputeMultiset(); 947 recomputeHeight(); 948 } 949 950 private AvlNode<E> rebalance() { 951 switch (balanceFactor()) { 952 case -2: 953 // requireNonNull is safe because right must exist in order to get a negative factor. 954 requireNonNull(right); 955 if (right.balanceFactor() > 0) { 956 right = right.rotateRight(); 957 } 958 return rotateLeft(); 959 case 2: 960 // requireNonNull is safe because left must exist in order to get a positive factor. 961 requireNonNull(left); 962 if (left.balanceFactor() < 0) { 963 left = left.rotateLeft(); 964 } 965 return rotateRight(); 966 default: 967 recomputeHeight(); 968 return this; 969 } 970 } 971 972 private int balanceFactor() { 973 return height(left) - height(right); 974 } 975 976 private AvlNode<E> rotateLeft() { 977 checkState(right != null); 978 AvlNode<E> newTop = right; 979 this.right = newTop.left; 980 newTop.left = this; 981 newTop.totalCount = this.totalCount; 982 newTop.distinctElements = this.distinctElements; 983 this.recompute(); 984 newTop.recomputeHeight(); 985 return newTop; 986 } 987 988 private AvlNode<E> rotateRight() { 989 checkState(left != null); 990 AvlNode<E> newTop = left; 991 this.left = newTop.right; 992 newTop.right = this; 993 newTop.totalCount = this.totalCount; 994 newTop.distinctElements = this.distinctElements; 995 this.recompute(); 996 newTop.recomputeHeight(); 997 return newTop; 998 } 999 1000 private static long totalCount(@Nullable AvlNode<?> node) { 1001 return (node == null) ? 0 : node.totalCount; 1002 } 1003 1004 private static int height(@Nullable AvlNode<?> node) { 1005 return (node == null) ? 0 : node.height; 1006 } 1007 1008 private @Nullable AvlNode<E> ceiling( 1009 Comparator<? super E> comparator, @ParametricNullness E e) { 1010 int cmp = comparator.compare(e, getElement()); 1011 if (cmp < 0) { 1012 return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this); 1013 } else if (cmp == 0) { 1014 return this; 1015 } else { 1016 return (right == null) ? null : right.ceiling(comparator, e); 1017 } 1018 } 1019 1020 private @Nullable AvlNode<E> floor(Comparator<? super E> comparator, @ParametricNullness E e) { 1021 int cmp = comparator.compare(e, getElement()); 1022 if (cmp > 0) { 1023 return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this); 1024 } else if (cmp == 0) { 1025 return this; 1026 } else { 1027 return (left == null) ? null : left.floor(comparator, e); 1028 } 1029 } 1030 1031 @ParametricNullness 1032 E getElement() { 1033 // For discussion of this cast, see the comment on the elem field. 1034 return uncheckedCastNullableTToT(elem); 1035 } 1036 1037 int getCount() { 1038 return elemCount; 1039 } 1040 1041 @Override 1042 public String toString() { 1043 return Multisets.immutableEntry(getElement(), getCount()).toString(); 1044 } 1045 } 1046 1047 private static <T extends @Nullable Object> void successor(AvlNode<T> a, AvlNode<T> b) { 1048 a.succ = b; 1049 b.pred = a; 1050 } 1051 1052 private static <T extends @Nullable Object> void successor( 1053 AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) { 1054 successor(a, b); 1055 successor(b, c); 1056 } 1057 1058 /* 1059 * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that 1060 * calls the comparator to compare the two keys. If that change is made, 1061 * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets. 1062 */ 1063 1064 /** 1065 * @serialData the comparator, the number of distinct elements, the first element, its count, the 1066 * second element, its count, and so on 1067 */ 1068 @J2ktIncompatible 1069 @GwtIncompatible // java.io.ObjectOutputStream 1070 private void writeObject(ObjectOutputStream stream) throws IOException { 1071 stream.defaultWriteObject(); 1072 stream.writeObject(elementSet().comparator()); 1073 Serialization.writeMultiset(this, stream); 1074 } 1075 1076 @J2ktIncompatible 1077 @GwtIncompatible // java.io.ObjectInputStream 1078 private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException { 1079 stream.defaultReadObject(); 1080 @SuppressWarnings("unchecked") 1081 // reading data stored by writeObject 1082 Comparator<? super E> comparator = (Comparator<? super E>) requireNonNull(stream.readObject()); 1083 Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator); 1084 Serialization.getFieldSetter(TreeMultiset.class, "range") 1085 .set(this, GeneralRange.all(comparator)); 1086 Serialization.getFieldSetter(TreeMultiset.class, "rootReference") 1087 .set(this, new Reference<AvlNode<E>>()); 1088 AvlNode<E> header = new AvlNode<>(); 1089 Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header); 1090 successor(header, header); 1091 Serialization.populateMultiset(this, stream); 1092 } 1093 1094 @GwtIncompatible @J2ktIncompatible private static final long serialVersionUID = 1; 1095}