遇到的问题
Java 8 开始引入了 Stream,其中的 API 一直在不断地优化、更新和完善,Java 9 中更是新增了 `ofNullable`、`takeWhile` 和 `dropWhile` 这几个关键 API。有时候,我们想对 Stream 中的对象进行去重,默认可以用 `distinct` 这个 API,例如:
// Built-in de-duplication: Stream.distinct() compares elements with their own
// equals()/hashCode().
String[] parts = "test1,test2,test2,test3,test3".split(",");
List<String> collect = Arrays.stream(parts)
        .distinct()
        .collect(Collectors.toList());
以 OpenJDK 为例,`distinct` 在串行流中用 HashSet 记录已出现的元素,在并行的有序流中则借助 LinkedHashSet(底层即 LinkedHashMap)来去重并保持顺序。其实这个和下面的实现几乎是等价的:
// Same result via a LinkedHashSet: duplicates are dropped by hashCode()/equals(),
// and insertion order is preserved.
String[] tokens = "test1,test2,test2,test3,test3".split(",");
Set<String> collect = Arrays.stream(tokens)
        .collect(Collectors.toCollection(LinkedHashSet::new));
两者结果是一样的:都靠 `hashCode()` 方法定位哈希槽,再用 `equals()` 方法判断是否为同一个对象;如果是,就作为重复元素去掉,不是则保留,并通过 LinkedHashMap 保留原始顺序。
但是,对于同一类对象,我们的去重方式有时并不统一,所以最好能像 `sorted` 接口一样,允许我们传入比较逻辑,来控制如何判断两个对象相等、需要去重。
例如下面这个对象,我们有时候想按照 `id` 去重,有时候又想按照 `name` 去重。
/**
 * Sample entity for the de-duplication examples: it can be de-duplicated either by
 * {@code id} or by {@code name}. Accessors, equals/hashCode, toString and the no-argument
 * constructor come from Lombok's {@code @Data} / {@code @NoArgsConstructor}.
 */
@NoArgsConstructor
@Data
public class User {
    /** Numeric identifier; one possible de-duplication key. */
    private int id;

    /** Display name; another possible de-duplication key. */
    private String name;
}
解决思路
先来实现这个 `distinct` 方法。首先,我们定义一个 `Key` 类,用来代理 `hashCode` 和 `equals` 方法:
private static final class Key<E> { //要比较的对象 private final E e; //获取对象的hashcode的方法 private final ToIntFunction<E> hashCode; //判断两个对象是否相等的方法 private final BiPredicate<E, E> equals; public Key(E e, ToIntFunction<E> hashCode, BiPredicate<E, E> equals) { this.e = e; this.hashCode = hashCode; this.equals = equals; } @Override public int hashCode() { return hashCode.applyAsInt(e); } @Override public boolean equals(Object obj) { if (!(obj instanceof Key)) { return false; } @SuppressWarnings("unchecked") Key<E> that = (Key<E>) obj; return equals.test(this.e, that.e); } }
然后,增加新的 `distinct` 方法:
public Stream<T> distinct ( ToIntFunction<T> hashCode, BiPredicate<T, T> equals, //排重的时候,保留哪一个? BinaryOperator<T> merger ) { return this.collect(Collectors.toMap( t -> new Key<>(t, hashCode, equals), Function.identity(), merger, //通过LinkedHashMap来保持原有的顺序 LinkedHashMap::new)) .values() .stream(); }
那么,这个方法如何放入 Stream 呢?我们首先想到的就是代理 `Stream` 接口,最简单的实现如下:
/**
 * A {@link Stream} decorator that forwards every standard operation to a wrapped stream and
 * adds {@link #distinct(ToIntFunction, BiPredicate, BinaryOperator)}, a de-duplication step
 * whose notion of equality is supplied by the caller instead of the elements' own
 * {@code hashCode()}/{@code equals()}.
 *
 * <p>Operations that return a reference stream re-wrap the result, so the custom
 * {@code distinct} stays available further down the pipeline. Primitive-stream and terminal
 * operations are delegated unchanged.
 *
 * @param <T> element type
 */
public class EnhancedStream<T> implements Stream<T> {
    private final Stream<T> delegate;

    /**
     * Wraps {@code delegate}.
     *
     * @param delegate the stream every operation is forwarded to
     * @throws NullPointerException if {@code delegate} is {@code null}
     */
    public EnhancedStream(Stream<T> delegate) {
        this.delegate = java.util.Objects.requireNonNull(delegate, "delegate");
    }

    /**
     * Hash-map key that delegates hashing and equality to caller-supplied functions.
     *
     * @param <E> type of the wrapped element
     */
    private static final class Key<E> {
        /** The element being compared. */
        private final E e;
        /** Computes the element's hash code. */
        private final ToIntFunction<? super E> hashCode;
        /** Decides whether two elements count as duplicates. */
        private final BiPredicate<? super E, ? super E> equals;

        public Key(E e, ToIntFunction<? super E> hashCode, BiPredicate<? super E, ? super E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            // Unchecked cast: every Key in one map is built for the same element type.
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    /**
     * Removes elements that {@code equals} reports as duplicates of an earlier element,
     * keeping one element per group.
     *
     * <p>This step is eager: the wrapped stream is fully consumed when this method is called.
     * Survivors come back in the order their group was first encountered. {@code null}
     * elements are handed to {@code hashCode}/{@code equals} like any other element, and a
     * {@code null} returned by {@code merger} is kept as that group's element. Close handlers
     * and the parallel/sequential mode of the wrapped stream carry over to the result.
     *
     * @param hashCode hash function; must be consistent with {@code equals}
     * @param equals   decides whether two elements are duplicates
     * @param merger   given (kept so far, newly seen) for a duplicate, returns the element to
     *                 keep for that group (decides which duplicate survives)
     * @return a new enhanced stream with one element per group of duplicates
     */
    public EnhancedStream<T> distinct(
            ToIntFunction<? super T> hashCode,
            BiPredicate<? super T, ? super T> equals,
            BinaryOperator<T> merger) {
        boolean parallel = delegate.isParallel();
        // Collect by hand rather than via Collectors.toMap: toMap relies on Map.merge, which
        // throws NullPointerException for null elements and drops a group whose merger
        // returns null. LinkedHashMap keeps the original encounter order.
        LinkedHashMap<Key<T>, T> unique = delegate.collect(
                LinkedHashMap::new,
                (map, t) -> putMerged(map, new Key<T>(t, hashCode, equals), t, merger),
                (left, right) -> right.forEach((key, value) -> putMerged(left, key, value, merger)));
        // Closing the result must still run handlers registered on the source stream.
        Stream<T> result = unique.values().stream().onClose(delegate::close);
        return new EnhancedStream<>(parallel ? result.parallel() : result);
    }

    /** Stores {@code value} under {@code key}, combining it with any existing value via {@code merger}. */
    private static <E> void putMerged(
            LinkedHashMap<Key<E>, E> map, Key<E> key, E value, BinaryOperator<E> merger) {
        // Re-putting an existing key keeps its original insertion position.
        map.put(key, map.containsKey(key) ? merger.apply(map.get(key), value) : value);
    }

    @Override
    public EnhancedStream<T> filter(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.filter(predicate));
    }

    @Override
    public <R> EnhancedStream<R> map(Function<? super T, ? extends R> mapper) {
        return new EnhancedStream<>(delegate.map(mapper));
    }

    @Override
    public IntStream mapToInt(ToIntFunction<? super T> mapper) {
        return delegate.mapToInt(mapper);
    }

    @Override
    public LongStream mapToLong(ToLongFunction<? super T> mapper) {
        return delegate.mapToLong(mapper);
    }

    @Override
    public DoubleStream mapToDouble(ToDoubleFunction<? super T> mapper) {
        return delegate.mapToDouble(mapper);
    }

    @Override
    public <R> EnhancedStream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper) {
        return new EnhancedStream<>(delegate.flatMap(mapper));
    }

    @Override
    public IntStream flatMapToInt(Function<? super T, ? extends IntStream> mapper) {
        return delegate.flatMapToInt(mapper);
    }

    @Override
    public LongStream flatMapToLong(Function<? super T, ? extends LongStream> mapper) {
        return delegate.flatMapToLong(mapper);
    }

    @Override
    public DoubleStream flatMapToDouble(Function<? super T, ? extends DoubleStream> mapper) {
        return delegate.flatMapToDouble(mapper);
    }

    @Override
    public EnhancedStream<T> distinct() {
        return new EnhancedStream<>(delegate.distinct());
    }

    @Override
    public EnhancedStream<T> sorted() {
        return new EnhancedStream<>(delegate.sorted());
    }

    @Override
    public EnhancedStream<T> sorted(Comparator<? super T> comparator) {
        return new EnhancedStream<>(delegate.sorted(comparator));
    }

    @Override
    public EnhancedStream<T> peek(Consumer<? super T> action) {
        return new EnhancedStream<>(delegate.peek(action));
    }

    @Override
    public EnhancedStream<T> limit(long maxSize) {
        return new EnhancedStream<>(delegate.limit(maxSize));
    }

    @Override
    public EnhancedStream<T> skip(long n) {
        return new EnhancedStream<>(delegate.skip(n));
    }

    @Override
    public void forEach(Consumer<? super T> action) {
        delegate.forEach(action);
    }

    @Override
    public void forEachOrdered(Consumer<? super T> action) {
        delegate.forEachOrdered(action);
    }

    @Override
    public Object[] toArray() {
        return delegate.toArray();
    }

    @Override
    public <A> A[] toArray(IntFunction<A[]> generator) {
        return delegate.toArray(generator);
    }

    @Override
    public T reduce(T identity, BinaryOperator<T> accumulator) {
        return delegate.reduce(identity, accumulator);
    }

    @Override
    public Optional<T> reduce(BinaryOperator<T> accumulator) {
        return delegate.reduce(accumulator);
    }

    @Override
    public <U> U reduce(U identity, BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner) {
        return delegate.reduce(identity, accumulator, combiner);
    }

    @Override
    public <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super T> accumulator, BiConsumer<R, R> combiner) {
        return delegate.collect(supplier, accumulator, combiner);
    }

    @Override
    public <R, A> R collect(Collector<? super T, A, R> collector) {
        return delegate.collect(collector);
    }

    @Override
    public Optional<T> min(Comparator<? super T> comparator) {
        return delegate.min(comparator);
    }

    @Override
    public Optional<T> max(Comparator<? super T> comparator) {
        return delegate.max(comparator);
    }

    @Override
    public long count() {
        return delegate.count();
    }

    @Override
    public boolean anyMatch(Predicate<? super T> predicate) {
        return delegate.anyMatch(predicate);
    }

    @Override
    public boolean allMatch(Predicate<? super T> predicate) {
        return delegate.allMatch(predicate);
    }

    @Override
    public boolean noneMatch(Predicate<? super T> predicate) {
        return delegate.noneMatch(predicate);
    }

    @Override
    public Optional<T> findFirst() {
        return delegate.findFirst();
    }

    @Override
    public Optional<T> findAny() {
        return delegate.findAny();
    }

    @Override
    public Iterator<T> iterator() {
        return delegate.iterator();
    }

    @Override
    public Spliterator<T> spliterator() {
        return delegate.spliterator();
    }

    @Override
    public boolean isParallel() {
        return delegate.isParallel();
    }

    @Override
    public EnhancedStream<T> sequential() {
        return new EnhancedStream<>(delegate.sequential());
    }

    @Override
    public EnhancedStream<T> parallel() {
        return new EnhancedStream<>(delegate.parallel());
    }

    @Override
    public EnhancedStream<T> unordered() {
        return new EnhancedStream<>(delegate.unordered());
    }

    @Override
    public EnhancedStream<T> onClose(Runnable closeHandler) {
        return new EnhancedStream<>(delegate.onClose(closeHandler));
    }

    @Override
    public void close() {
        delegate.close();
    }
}