001    /*
002     * Copyright (C) 2007 The Guava Authors
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package com.google.common.collect;
018    
019    import static com.google.common.base.Preconditions.checkArgument;
020    import static com.google.common.base.Preconditions.checkNotNull;
021    import static com.google.common.base.Preconditions.checkState;
022    import static com.google.common.collect.BstSide.LEFT;
023    import static com.google.common.collect.BstSide.RIGHT;
024    
025    import java.io.IOException;
026    import java.io.ObjectInputStream;
027    import java.io.ObjectOutputStream;
028    import java.io.Serializable;
029    import java.util.Comparator;
030    import java.util.ConcurrentModificationException;
031    import java.util.Iterator;
032    
033    import javax.annotation.Nullable;
034    
035    import com.google.common.annotations.GwtCompatible;
036    import com.google.common.annotations.GwtIncompatible;
037    import com.google.common.primitives.Ints;
038    
039    /**
040     * A multiset which maintains the ordering of its elements, according to either
041     * their natural order or an explicit {@link Comparator}. In all cases, this
042     * implementation uses {@link Comparable#compareTo} or {@link
043     * Comparator#compare} instead of {@link Object#equals} to determine
044     * equivalence of instances.
045     *
046     * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as
047     * explained by the {@link Comparable} class specification. Otherwise, the
048     * resulting multiset will violate the {@link java.util.Collection} contract,
049     * which is specified in terms of {@link Object#equals}.
050     *
051     * @author Louis Wasserman
052     * @author Jared Levy
053     * @since 2.0 (imported from Google Collections Library)
054     */
055    @GwtCompatible(emulated = true)
056    public final class TreeMultiset<E> extends AbstractSortedMultiset<E>
057        implements Serializable {
058    
059      /**
060       * Creates a new, empty multiset, sorted according to the elements' natural
061       * order. All elements inserted into the multiset must implement the
062       * {@code Comparable} interface. Furthermore, all such elements must be
063       * <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
064       * {@code ClassCastException} for any elements {@code e1} and {@code e2} in
065       * the multiset. If the user attempts to add an element to the multiset that
066       * violates this constraint (for example, the user attempts to add a string
067       * element to a set whose elements are integers), the {@code add(Object)}
068       * call will throw a {@code ClassCastException}.
069       *
070       * <p>The type specification is {@code <E extends Comparable>}, instead of the
071       * more specific {@code <E extends Comparable<? super E>>}, to support
072       * classes defined without generics.
073       */
074      public static <E extends Comparable> TreeMultiset<E> create() {
075        return new TreeMultiset<E>(Ordering.natural());
076      }
077    
078      /**
079       * Creates a new, empty multiset, sorted according to the specified
080       * comparator. All elements inserted into the multiset must be <i>mutually
081       * comparable</i> by the specified comparator: {@code comparator.compare(e1,
082       * e2)} must not throw a {@code ClassCastException} for any elements {@code
083       * e1} and {@code e2} in the multiset. If the user attempts to add an element
084       * to the multiset that violates this constraint, the {@code add(Object)} call
085       * will throw a {@code ClassCastException}.
086       *
087       * @param comparator the comparator that will be used to sort this multiset. A
088       *     null value indicates that the elements' <i>natural ordering</i> should
089       *     be used.
090       */
091      @SuppressWarnings("unchecked")
092      public static <E> TreeMultiset<E> create(
093          @Nullable Comparator<? super E> comparator) {
094        return (comparator == null)
095               ? new TreeMultiset<E>((Comparator) Ordering.natural())
096               : new TreeMultiset<E>(comparator);
097      }
098    
099      /**
100       * Creates an empty multiset containing the given initial elements, sorted
101       * according to the elements' natural order.
102       *
103       * <p>This implementation is highly efficient when {@code elements} is itself
104       * a {@link Multiset}.
105       *
106       * <p>The type specification is {@code <E extends Comparable>}, instead of the
107       * more specific {@code <E extends Comparable<? super E>>}, to support
108       * classes defined without generics.
109       */
110      public static <E extends Comparable> TreeMultiset<E> create(
111          Iterable<? extends E> elements) {
112        TreeMultiset<E> multiset = create();
113        Iterables.addAll(multiset, elements);
114        return multiset;
115      }
116    
117      /**
118       * Returns an iterator over the elements contained in this collection.
119       */
120      @Override
121      public Iterator<E> iterator() {
122        // Needed to avoid Javadoc bug.
123        return super.iterator();
124      }
125    
126      private TreeMultiset(Comparator<? super E> comparator) {
127        super(comparator);
128        this.range = GeneralRange.all(comparator);
129        this.rootReference = new Reference<Node<E>>();
130      }
131    
132      private TreeMultiset(GeneralRange<E> range, Reference<Node<E>> root) {
133        super(range.comparator());
134        this.range = range;
135        this.rootReference = root;
136      }
137    
138      @SuppressWarnings("unchecked")
139      E checkElement(Object o) {
140        return (E) o;
141      }
142    
143      private transient final GeneralRange<E> range;
144    
145      private transient final Reference<Node<E>> rootReference;
146    
147      static final class Reference<T> {
148        T value;
149    
150        public Reference() {}
151    
152        public T get() {
153          return value;
154        }
155    
156        public boolean compareAndSet(T expected, T newValue) {
157          if (value == expected) {
158            value = newValue;
159            return true;
160          }
161          return false;
162        }
163      }
164    
165      @Override
166      int distinctElements() {
167        Node<E> root = rootReference.get();
168        return Ints.checkedCast(BstRangeOps.totalInRange(distinctAggregate(), range, root));
169      }
170    
171      @Override
172      public int size() {
173        Node<E> root = rootReference.get();
174        return Ints.saturatedCast(BstRangeOps.totalInRange(sizeAggregate(), range, root));
175      }
176    
177      @Override
178      public int count(@Nullable Object element) {
179        try {
180          E e = checkElement(element);
181          if (range.contains(e)) {
182            Node<E> node = BstOperations.seek(comparator(), rootReference.get(), e);
183            return countOrZero(node);
184          }
185          return 0;
186        } catch (ClassCastException e) {
187          return 0;
188        } catch (NullPointerException e) {
189          return 0;
190        }
191      }
192    
193      private int mutate(@Nullable E e, MultisetModifier modifier) {
194        BstMutationRule<E, Node<E>> mutationRule = BstMutationRule.createRule(
195            modifier,
196            BstCountBasedBalancePolicies.
197              <E, Node<E>>singleRebalancePolicy(distinctAggregate()),
198            nodeFactory());
199        BstMutationResult<E, Node<E>> mutationResult =
200            BstOperations.mutate(comparator(), mutationRule, rootReference.get(), e);
201        if (!rootReference.compareAndSet(
202            mutationResult.getOriginalRoot(), mutationResult.getChangedRoot())) {
203          throw new ConcurrentModificationException();
204        }
205        Node<E> original = mutationResult.getOriginalTarget();
206        return countOrZero(original);
207      }
208    
209      @Override
210      public int add(E element, int occurrences) {
211        checkElement(element);
212        if (occurrences == 0) {
213          return count(element);
214        }
215        checkArgument(range.contains(element));
216        return mutate(element, new AddModifier(occurrences));
217      }
218    
219      @Override
220      public int remove(@Nullable Object element, int occurrences) {
221        if (element == null) {
222          return 0;
223        } else if (occurrences == 0) {
224          return count(element);
225        }
226        try {
227          E e = checkElement(element);
228          return range.contains(e) ? mutate(e, new RemoveModifier(occurrences)) : 0;
229        } catch (ClassCastException e) {
230          return 0;
231        }
232      }
233    
234      @Override
235      public boolean setCount(E element, int oldCount, int newCount) {
236        checkElement(element);
237        checkArgument(range.contains(element));
238        return mutate(element, new ConditionalSetCountModifier(oldCount, newCount))
239            == oldCount;
240      }
241    
242      @Override
243      public int setCount(E element, int count) {
244        checkElement(element);
245        checkArgument(range.contains(element));
246        return mutate(element, new SetCountModifier(count));
247      }
248    
249      private BstPathFactory<Node<E>, BstInOrderPath<Node<E>>> pathFactory() {
250        return BstInOrderPath.inOrderFactory();
251      }
252    
253      @Override
254      Iterator<Entry<E>> entryIterator() {
255        Node<E> root = rootReference.get();
256        final BstInOrderPath<Node<E>> startingPath =
257            BstRangeOps.furthestPath(range, LEFT, pathFactory(), root);
258        return iteratorInDirection(startingPath, RIGHT);
259      }
260    
261      @Override
262      Iterator<Entry<E>> descendingEntryIterator() {
263        Node<E> root = rootReference.get();
264        final BstInOrderPath<Node<E>> startingPath =
265            BstRangeOps.furthestPath(range, RIGHT, pathFactory(), root);
266        return iteratorInDirection(startingPath, LEFT);
267      }
268    
269      private Iterator<Entry<E>> iteratorInDirection(
270          @Nullable BstInOrderPath<Node<E>> start, final BstSide direction) {
271        final Iterator<BstInOrderPath<Node<E>>> pathIterator =
272            new AbstractLinkedIterator<BstInOrderPath<Node<E>>>(start) {
273              @Override
274              protected BstInOrderPath<Node<E>> computeNext(BstInOrderPath<Node<E>> previous) {
275                if (!previous.hasNext(direction)) {
276                  return null;
277                }
278                BstInOrderPath<Node<E>> next = previous.next(direction);
279                // TODO(user): only check against one side
280                return range.contains(next.getTip().getKey()) ? next : null;
281              }
282            };
283        return new Iterator<Entry<E>>() {
284          E toRemove = null;
285    
286          @Override
287          public boolean hasNext() {
288            return pathIterator.hasNext();
289          }
290    
291          @Override
292          public Entry<E> next() {
293            BstInOrderPath<Node<E>> path = pathIterator.next();
294            return new LiveEntry(
295                toRemove = path.getTip().getKey(), path.getTip().elemCount());
296          }
297    
298          @Override
299          public void remove() {
300            checkState(toRemove != null);
301            setCount(toRemove, 0);
302            toRemove = null;
303          }
304        };
305      }
306    
307      class LiveEntry extends Multisets.AbstractEntry<E> {
308        private Node<E> expectedRoot;
309        private final E element;
310        private int count;
311    
312        private LiveEntry(E element, int count) {
313          this.expectedRoot = rootReference.get();
314          this.element = element;
315          this.count = count;
316        }
317    
318        @Override
319        public E getElement() {
320          return element;
321        }
322    
323        @Override
324        public int getCount() {
325          if (rootReference.get() == expectedRoot) {
326            return count;
327          } else {
328            // check for updates
329            expectedRoot = rootReference.get();
330            return count = TreeMultiset.this.count(element);
331          }
332        }
333      }
334    
335      @Override
336      public void clear() {
337        Node<E> root = rootReference.get();
338        Node<E> cleared = BstRangeOps.minusRange(range,
339            BstCountBasedBalancePolicies.<E, Node<E>>fullRebalancePolicy(distinctAggregate()),
340            nodeFactory(), root);
341        if (!rootReference.compareAndSet(root, cleared)) {
342          throw new ConcurrentModificationException();
343        }
344      }
345    
346      @Override
347      public SortedMultiset<E> headMultiset(E upperBound, BoundType boundType) {
348        checkNotNull(upperBound);
349        return new TreeMultiset<E>(
350            range.intersect(GeneralRange.upTo(comparator, upperBound, boundType)), rootReference);
351      }
352    
353      @Override
354      public SortedMultiset<E> tailMultiset(E lowerBound, BoundType boundType) {
355        checkNotNull(lowerBound);
356        return new TreeMultiset<E>(
357            range.intersect(GeneralRange.downTo(comparator, lowerBound, boundType)), rootReference);
358      }
359    
360      /**
361       * {@inheritDoc}
362       *
363       * @since 11.0
364       */
365      @Override
366      public Comparator<? super E> comparator() {
367        return super.comparator();
368      }
369    
370      private static final class Node<E> extends BstNode<E, Node<E>> implements Serializable {
371        private final long size;
372        private final int distinct;
373    
374        private Node(E key, int elemCount, @Nullable Node<E> left,
375            @Nullable Node<E> right) {
376          super(key, left, right);
377          checkArgument(elemCount > 0);
378          this.size = (long) elemCount + sizeOrZero(left) + sizeOrZero(right);
379          this.distinct = 1 + distinctOrZero(left) + distinctOrZero(right);
380        }
381    
382        int elemCount() {
383          long result = size - sizeOrZero(childOrNull(LEFT))
384              - sizeOrZero(childOrNull(RIGHT));
385          return Ints.checkedCast(result);
386        }
387    
388        private Node(E key, int elemCount) {
389          this(key, elemCount, null, null);
390        }
391    
392        private static final long serialVersionUID = 0;
393      }
394    
395      private static long sizeOrZero(@Nullable Node<?> node) {
396        return (node == null) ? 0 : node.size;
397      }
398    
399      private static int distinctOrZero(@Nullable Node<?> node) {
400        return (node == null) ? 0 : node.distinct;
401      }
402    
403      private static int countOrZero(@Nullable Node<?> entry) {
404        return (entry == null) ? 0 : entry.elemCount();
405      }
406    
407      @SuppressWarnings("unchecked")
408      private BstAggregate<Node<E>> distinctAggregate() {
409        return (BstAggregate) DISTINCT_AGGREGATE;
410      }
411    
412      private static final BstAggregate<Node<Object>> DISTINCT_AGGREGATE =
413          new BstAggregate<Node<Object>>() {
414        @Override
415        public int entryValue(Node<Object> entry) {
416          return 1;
417        }
418    
419        @Override
420        public long treeValue(@Nullable Node<Object> tree) {
421          return distinctOrZero(tree);
422        }
423      };
424    
425      @SuppressWarnings("unchecked")
426      private BstAggregate<Node<E>> sizeAggregate() {
427        return (BstAggregate) SIZE_AGGREGATE;
428      }
429    
430      private static final BstAggregate<Node<Object>> SIZE_AGGREGATE =
431          new BstAggregate<Node<Object>>() {
432            @Override
433            public int entryValue(Node<Object> entry) {
434              return entry.elemCount();
435            }
436    
437            @Override
438            public long treeValue(@Nullable Node<Object> tree) {
439              return sizeOrZero(tree);
440            }
441          };
442    
443      @SuppressWarnings("unchecked")
444      private BstNodeFactory<Node<E>> nodeFactory() {
445        return (BstNodeFactory) NODE_FACTORY;
446      }
447    
448      private static final BstNodeFactory<Node<Object>> NODE_FACTORY =
449          new BstNodeFactory<Node<Object>>() {
450            @Override
451            public Node<Object> createNode(Node<Object> source, @Nullable Node<Object> left,
452                @Nullable Node<Object> right) {
453              return new Node<Object>(source.getKey(), source.elemCount(), left, right);
454            }
455          };
456    
457      private abstract class MultisetModifier implements BstModifier<E, Node<E>> {
458        abstract int newCount(int oldCount);
459    
460        @Nullable
461        @Override
462        public BstModificationResult<Node<E>> modify(E key, @Nullable Node<E> originalEntry) {
463          int oldCount = countOrZero(originalEntry);
464          int newCount = newCount(oldCount);
465          if (oldCount == newCount) {
466            return BstModificationResult.identity(originalEntry);
467          } else if (newCount == 0) {
468            return BstModificationResult.rebalancingChange(originalEntry, null);
469          } else if (oldCount == 0) {
470            return BstModificationResult.rebalancingChange(null, new Node<E>(key, newCount));
471          } else {
472            return BstModificationResult.rebuildingChange(originalEntry,
473                new Node<E>(originalEntry.getKey(), newCount));
474          }
475        }
476      }
477    
478      private final class AddModifier extends MultisetModifier {
479        private final int countToAdd;
480    
481        private AddModifier(int countToAdd) {
482          checkArgument(countToAdd > 0);
483          this.countToAdd = countToAdd;
484        }
485    
486        @Override
487        int newCount(int oldCount) {
488          checkArgument(countToAdd <= Integer.MAX_VALUE - oldCount, "Cannot add this many elements");
489          return oldCount + countToAdd;
490        }
491      }
492    
493      private final class RemoveModifier extends MultisetModifier {
494        private final int countToRemove;
495    
496        private RemoveModifier(int countToRemove) {
497          checkArgument(countToRemove > 0);
498          this.countToRemove = countToRemove;
499        }
500    
501        @Override
502        int newCount(int oldCount) {
503          return Math.max(0, oldCount - countToRemove);
504        }
505      }
506    
507      private final class SetCountModifier extends MultisetModifier {
508        private final int countToSet;
509    
510        private SetCountModifier(int countToSet) {
511          checkArgument(countToSet >= 0);
512          this.countToSet = countToSet;
513        }
514    
515        @Override
516        int newCount(int oldCount) {
517          return countToSet;
518        }
519      }
520    
521      private final class ConditionalSetCountModifier extends MultisetModifier {
522        private final int expectedCount;
523        private final int setCount;
524    
525        private ConditionalSetCountModifier(int expectedCount, int setCount) {
526          checkArgument(setCount >= 0 & expectedCount >= 0);
527          this.expectedCount = expectedCount;
528          this.setCount = setCount;
529        }
530    
531        @Override
532        int newCount(int oldCount) {
533          return (oldCount == expectedCount) ? setCount : oldCount;
534        }
535      }
536    
537      /*
538       * TODO(jlevy): Decide whether entrySet() should return entries with an
539       * equals() method that calls the comparator to compare the two keys. If that
540       * change is made, AbstractMultiset.equals() can simply check whether two
541       * multisets have equal entry sets.
542       */
543    
544      /**
545       * @serialData the comparator, the number of distinct elements, the first
546       *     element, its count, the second element, its count, and so on
547       */
548      @GwtIncompatible("java.io.ObjectOutputStream")
549      private void writeObject(ObjectOutputStream stream) throws IOException {
550        stream.defaultWriteObject();
551        stream.writeObject(elementSet().comparator());
552        Serialization.writeMultiset(this, stream);
553      }
554    
555      @GwtIncompatible("java.io.ObjectInputStream")
556      private void readObject(ObjectInputStream stream)
557          throws IOException, ClassNotFoundException {
558        stream.defaultReadObject();
559        @SuppressWarnings("unchecked") // reading data stored by writeObject
560        Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
561        Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
562        Serialization.getFieldSetter(TreeMultiset.class, "range").set(this,
563            GeneralRange.all(comparator));
564        Serialization.getFieldSetter(TreeMultiset.class, "rootReference").set(this,
565            new Reference<Node<E>>());
566        Serialization.populateMultiset(this, stream);
567      }
568    
569      @GwtIncompatible("not needed in emulated source")
570      private static final long serialVersionUID = 1;
571    }