diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java index 5fd30daa2..29fcdda15 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java @@ -61,7 +61,7 @@ static SharedStorageBody.SharedStorageBodyResult sharedStorageBody(List afterBOp = - data.transform("b", TypeInformation.of(Long.class), bOp); + afterAOp.transform("b", TypeInformation.of(Long.class), bOp); Map, SharedStorageStreamOperator> ownerMap = new HashMap<>(); ownerMap.put(SUM, aOp); @@ -89,7 +89,9 @@ public void testSharedStorage() throws Exception { /** Operator A: add input elements to the shared {@link #SUM}. */ static class AOperator extends AbstractStreamOperator - implements OneInputStreamOperator, SharedStorageStreamOperator { + implements OneInputStreamOperator, + SharedStorageStreamOperator, + BoundedOneInput { private final String sharedStorageAccessorID; private SharedStorageContext sharedStorageContext; @@ -115,15 +117,18 @@ public void processElement(StreamRecord element) throws Exception { Long currentSum = getter.get(SUM); setter.set(SUM, currentSum + element.getValue()); }); - output.collect(element); + } + + @Override + public void endInput() throws Exception { + // Informs BOperator to get the value from shared {@link #SUM}. + output.collect(new StreamRecord<>(0L)); } } /** Operator B: when input ends, get the value from shared {@link #SUM}. */ static class BOperator extends AbstractStreamOperator - implements OneInputStreamOperator, - SharedStorageStreamOperator, - BoundedOneInput { + implements OneInputStreamOperator, SharedStorageStreamOperator { private final String sharedStorageAccessorID; private SharedStorageContext sharedStorageContext; @@ -143,10 +148,7 @@ public String getSharedStorageAccessorID() { } @Override - public void processElement(StreamRecord element) throws Exception {} - - @Override - public void endInput() throws Exception { + public void processElement(StreamRecord element) throws Exception { sharedStorageContext.invoke( (getter, setter) -> { output.collect(new StreamRecord<>(getter.get(SUM)));