Skip to content

Commit

Permalink
fix: perform retries in a way that is compatible with driver retry lo…
Browse files Browse the repository at this point in the history
…gic (#662)
  • Loading branch information
ali-ince authored Sep 11, 2024
1 parent 354dfeb commit 2d1247e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 44 deletions.
20 changes: 15 additions & 5 deletions common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.SparkSession
import org.jetbrains.annotations.TestOnly
import org.neo4j.driver.Config.TrustStrategy
import org.neo4j.driver._
import org.neo4j.driver.exceptions.Neo4jException
import org.neo4j.driver.net.ServerAddress
import org.neo4j.driver.net.ServerAddressResolver
import org.neo4j.spark.util.Neo4jImplicits.StringMapImplicits
Expand Down Expand Up @@ -219,14 +220,14 @@ class Neo4jOptions(private val options: java.util.Map[String, String]) extends S
Neo4jNodeMetadata(labels, nodeKeys, nodeProps)
}

val transactionMetadata: Neo4jTransactionMetadata = initNeo4jTransactionMetadata()
val transactionSettings: Neo4jTransactionSettings = initNeo4jTransactionSettings()

val script: Array[String] = getParameter(SCRIPT)
.split(";")
.map(_.trim)
.filterNot(_.isEmpty)

private def initNeo4jTransactionMetadata(): Neo4jTransactionMetadata = {
private def initNeo4jTransactionSettings(): Neo4jTransactionSettings = {
val retries = getParameter(TRANSACTION_RETRIES, DEFAULT_TRANSACTION_RETRIES.toString).toInt
val failOnTransactionCodes = getParameter(TRANSACTION_CODES_FAIL, DEFAULT_EMPTY)
.split(",")
Expand All @@ -235,7 +236,7 @@ class Neo4jOptions(private val options: java.util.Map[String, String]) extends S
.toSet
val batchSize = getParameter(BATCH_SIZE, DEFAULT_BATCH_SIZE.toString).toInt
val retryTimeout = getParameter(TRANSACTION_RETRY_TIMEOUT, DEFAULT_TRANSACTION_RETRY_TIMEOUT.toString).toInt
Neo4jTransactionMetadata(retries, failOnTransactionCodes, batchSize, retryTimeout)
Neo4jTransactionSettings(retries, failOnTransactionCodes, batchSize, retryTimeout)
}

val relationshipMetadata: Neo4jRelationshipMetadata = initNeo4jRelationshipMetadata()
Expand Down Expand Up @@ -356,12 +357,21 @@ case class Neo4jSchemaMetadata(
mapGroupDuplicateKeys: Boolean
)

case class Neo4jTransactionMetadata(
case class Neo4jTransactionSettings(
retries: Int,
failOnTransactionCodes: Set[String],
batchSize: Int,
retryTimeout: Long
)
) {

def shouldFailOn(exception: Throwable): Boolean = {
exception match {
case e: Neo4jException => failOnTransactionCodes.contains(e.code())
case _ => false
}
}

}

case class Neo4jNodeMetadata(labels: Seq[String], nodeKeys: Map[String, String], properties: Map[String, String]) {
def includesProperty(name: String): Boolean = nodeKeys.contains(name) || properties.contains(name)
Expand Down
27 changes: 13 additions & 14 deletions common/src/main/scala/org/neo4j/spark/util/Neo4jUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,10 @@ import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.SerializerProvider
import com.fasterxml.jackson.databind.module.SimpleModule
import org.apache.spark.sql.sources._
import org.neo4j.cypherdsl.core.Condition
import org.neo4j.cypherdsl.core.Cypher
import org.neo4j.cypherdsl.core.Expression
import org.neo4j.cypherdsl.core.Functions
import org.neo4j.cypherdsl.core.Property
import org.neo4j.cypherdsl.core.PropertyContainer
import org.neo4j.cypherdsl.core._
import org.neo4j.driver.Session
import org.neo4j.driver.Transaction
import org.neo4j.driver.exceptions.Neo4jException
import org.neo4j.driver.exceptions.ServiceUnavailableException
import org.neo4j.driver.exceptions.SessionExpiredException
import org.neo4j.driver.exceptions.TransientException
import org.neo4j.driver.internal.retry.ExponentialBackoffRetryLogic
import org.neo4j.driver.types.Entity
import org.neo4j.driver.types.Path
import org.neo4j.spark.service.SchemaService
Expand All @@ -44,6 +36,8 @@ import org.slf4j.Logger
import java.time.temporal.Temporal
import java.util.Properties

import scala.annotation.tailrec

object Neo4jUtil {

val NODE_ALIAS = "n"
Expand Down Expand Up @@ -237,9 +231,14 @@ object Neo4jUtil {
}
}

def isRetryableException(neo4jTransientException: Neo4jException) =
(neo4jTransientException.isInstanceOf[SessionExpiredException]
|| neo4jTransientException.isInstanceOf[TransientException]
|| neo4jTransientException.isInstanceOf[ServiceUnavailableException])
@tailrec
def isRetryableException(exception: Throwable): Boolean = {
if (exception == null) {
false
} else
ExponentialBackoffRetryLogic.isRetryable(exception) || isRetryableException(
exception.getCause
)
}

}
44 changes: 19 additions & 25 deletions common/src/main/scala/org/neo4j/spark/writer/BaseDataWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import org.apache.spark.sql.types.StructType
import org.neo4j.driver.Session
import org.neo4j.driver.Transaction
import org.neo4j.driver.Values
import org.neo4j.driver.exceptions.ClientException
import org.neo4j.driver.exceptions.Neo4jException
import org.neo4j.driver.exceptions.ServiceUnavailableException
import org.neo4j.spark.service._
import org.neo4j.spark.util.DriverCache
Expand All @@ -40,6 +38,7 @@ import java.util
import java.util.concurrent.CountDownLatch
import java.util.concurrent.locks.LockSupport

import scala.annotation.tailrec
import scala.collection.JavaConverters._

abstract class BaseDataWriter(
Expand All @@ -63,19 +62,20 @@ abstract class BaseDataWriter(

private val batch: util.List[java.util.Map[String, Object]] = new util.ArrayList[util.Map[String, Object]]()

private val retries = new CountDownLatch(options.transactionMetadata.retries)
private val retries = new CountDownLatch(options.transactionSettings.retries)

private val query: String = new Neo4jQueryService(options, new Neo4jQueryWriteStrategy(saveMode)).createQuery()

private val metrics = DataWriterMetrics()

def write(record: InternalRow): Unit = {
batch.add(mappingService.convert(record, structType))
if (batch.size() == options.transactionMetadata.batchSize) {
if (batch.size() == options.transactionSettings.batchSize) {
writeBatch()
}
}

@tailrec
private def writeBatch(): Unit = {
try {
if (session == null || !session.isOpen) {
Expand Down Expand Up @@ -120,45 +120,40 @@ abstract class BaseDataWriter(
closeSafely(transaction)
batch.clear()
} catch {
case neo4jTransientException: Neo4jException =>
val code = neo4jTransientException.code()
if (
isRetryableException(neo4jTransientException)
&& !options.transactionMetadata.failOnTransactionCodes.contains(code)
&& retries.getCount > 0
) {
case e: Throwable =>
if (options.transactionSettings.shouldFailOn(e)) {
log.error("unable to write batch due to explicitly configured failure condition", e)
throw e
}

if (isRetryableException(e) && retries.getCount > 0) {
retries.countDown()
log.info(
s"Matched Neo4j transient exception next retry is ${options.transactionMetadata.retries - retries.getCount}"
s"encountered a transient exception while writing batch, retrying ${options.transactionSettings.retries - retries.getCount} time",
e
)
close()
LockSupport.parkNanos(Duration.ofMillis(options.transactionMetadata.retryTimeout).toNanos)
LockSupport.parkNanos(Duration.ofMillis(options.transactionSettings.retryTimeout).toNanos)
writeBatch()
} else {
logAndThrowException(neo4jTransientException)
logAndThrowException(e)
}
case e: Exception => logAndThrowException(e)
}
()
}

/**
* df: we check if the thrown exception is STOPPED_THREAD_EXCEPTION. This is the
* exception that is thrown when the streaming query is interrupted, we don't want to cause
* any error in this case. The transaction are rolled back automatically.
*/
private def logAndThrowException(e: Exception): Unit = {
private def logAndThrowException(e: Throwable): Unit = {
if (e.isInstanceOf[ServiceUnavailableException] && e.getMessage == STOPPED_THREAD_EXCEPTION_MESSAGE) {
logWarning(e.getMessage)
} else {
if (e.isInstanceOf[ClientException]) {
log.error(s"Cannot commit the transaction because: ${e.getMessage}")
} else {
log.error("Cannot commit the transaction because the following exception", e)
}

throw e
logError("unable to write batch", e)
}

throw e
}

def commit(): Null = {
Expand All @@ -176,7 +171,6 @@ abstract class BaseDataWriter(
}
}
close()
()
}

def close(): Unit = {
Expand Down

0 comments on commit 2d1247e

Please sign in to comment.