diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala index 6a102aacd94..4d06bdf0553 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala @@ -901,19 +901,26 @@ object GpuShuffledAsymmetricHashJoinExec { concatTime = metrics(CONCAT_TIME), opTime = metrics(OP_TIME), opName = "stream as build") - val streamBatch = streamBatchIter.next() - val singleStreamIter = new SingleGpuColumnarBatchIterator(streamBatch) - assert(!streamBatchIter.hasNext, "stream side not exhausted") - val streamStats = JoinBuildSideStats.fromBatch(streamBatch, exprs.boundStreamKeys) - if (buildStats.streamMagnificationFactor < - streamStats.streamMagnificationFactor) { - metrics(BUILD_DATA_SIZE).set(buildSize) - JoinInfo(joinType, buildSide, buildIter, buildSize, Some(buildStats), - singleStreamIter, exprs) + if (streamBatchIter.hasNext) { + val streamBatch = streamBatchIter.next() + val singleStreamIter = new SingleGpuColumnarBatchIterator(streamBatch) + assert(!streamBatchIter.hasNext, "stream side not exhausted") + val streamStats = JoinBuildSideStats.fromBatch(streamBatch, exprs.boundStreamKeys) + if (buildStats.streamMagnificationFactor < + streamStats.streamMagnificationFactor) { + metrics(BUILD_DATA_SIZE).set(buildSize) + JoinInfo(joinType, buildSide, buildIter, buildSize, Some(buildStats), + singleStreamIter, exprs) + } else { + metrics(BUILD_DATA_SIZE).set(streamSize) + val flippedSide = flipped(buildSide) + JoinInfo(joinType, flippedSide, singleStreamIter, streamSize, Some(streamStats), + buildIter, exprs.flipped(joinType, flippedSide, condition)) + } } else { metrics(BUILD_DATA_SIZE).set(streamSize) val flippedSide = flipped(buildSide) - JoinInfo(joinType, flippedSide, singleStreamIter, streamSize, Some(streamStats), + JoinInfo(joinType, flippedSide, streamBatchIter, streamSize, None, buildIter, exprs.flipped(joinType, flippedSide, condition)) } } else {