diff --git a/ballerina/tests/h2-native-tests.bal b/ballerina/tests/h2-native-tests.bal index 5d68b4b..7174efb 100644 --- a/ballerina/tests/h2-native-tests.bal +++ b/ballerina/tests/h2-native-tests.bal @@ -191,7 +191,7 @@ function h2NativeQueryComplexTest() returns error? { @test:Config { groups: ["transactions", "h2", "native"], dependsOn: [h2NativeExecuteTestNegative1, h2NativeQueryTest, h2NativeQueryTestNegative, h2NativeQueryComplexTest], - enable: false + enable: true } function h2NativeTransactionTest() returns error? { H2RainierClient rainierClient = check new (); @@ -254,17 +254,14 @@ function h2NativeTransactionTest2() returns error? { check buildingStream.close(); test:assertEquals(buildings, [building33]); - transaction { - _ = check rainierClient->executeNativeSQL(` - UPDATE "Building" - SET - "city" = ${building33Updated.city}, - "state" = ${building33Updated.state}, - "country" = ${building33Updated.country} - WHERE "buildingCode" = ${building33.buildingCode} - `); - check commit; - } + _ = check rainierClient->executeNativeSQL(` + UPDATE "Building" + SET + "city" = ${building33Updated.city}, + "state" = ${building33Updated.state}, + "country" = ${building33Updated.country} + WHERE "buildingCode" = ${building33.buildingCode} + `); stream buildingStream3 = rainierClient->queryNativeSQL(`SELECT * FROM "Building" WHERE "buildingCode" = ${building33.buildingCode}`); Building[] buildings3 = check from Building building in buildingStream3 diff --git a/ballerina/tests/mssql-native-tests.bal b/ballerina/tests/mssql-native-tests.bal index d4fd753..3d06f9e 100644 --- a/ballerina/tests/mssql-native-tests.bal +++ b/ballerina/tests/mssql-native-tests.bal @@ -191,7 +191,7 @@ function mssqlNativeQueryComplexTest() returns error? { @test:Config { groups: ["transactions", "mssql", "native"], dependsOn: [mssqlNativeExecuteTestNegative1, mssqlNativeQueryTest, mssqlNativeQueryTestNegative, mssqlNativeQueryComplexTest], - enable: false + enable: true } function mssqlNativeTransactionTest() returns error? { MSSQLRainierClient rainierClient = check new (); @@ -254,17 +254,14 @@ function mssqlNativeTransactionTest2() returns error? { check buildingStream.close(); test:assertEquals(buildings, [building33]); - transaction { - _ = check rainierClient->executeNativeSQL(` - UPDATE Building - SET - city = ${building33Updated.city}, - state = ${building33Updated.state}, - country = ${building33Updated.country} - WHERE buildingCode = ${building33.buildingCode} - `); - check commit; - } + _ = check rainierClient->executeNativeSQL(` + UPDATE Building + SET + city = ${building33Updated.city}, + state = ${building33Updated.state}, + country = ${building33Updated.country} + WHERE buildingCode = ${building33.buildingCode} + `); stream buildingStream3 = rainierClient->queryNativeSQL(`SELECT * FROM Building WHERE buildingCode = ${building33.buildingCode}`); Building[] buildings3 = check from Building building in buildingStream3 diff --git a/ballerina/tests/mysql-native-tests.bal b/ballerina/tests/mysql-native-tests.bal index c69d2f8..4a7e145 100644 --- a/ballerina/tests/mysql-native-tests.bal +++ b/ballerina/tests/mysql-native-tests.bal @@ -191,7 +191,7 @@ function mysqlNativeQueryComplexTest() returns error? { @test:Config { groups: ["transactions", "mysql", "native"], dependsOn: [mysqlNativeExecuteTestNegative1, mysqlNativeQueryTest, mysqlNativeQueryTestNegative, mysqlNativeQueryComplexTest], - enable: false + enable: true } function mysqlNativeTransactionTest() returns error? { MySQLRainierClient rainierClient = check new (); @@ -254,17 +254,14 @@ function mysqlNativeTransactionTest2() returns error? { check buildingStream.close(); test:assertEquals(buildings, [building33]); - transaction { - _ = check rainierClient->executeNativeSQL(` - UPDATE Building - SET - city = ${building33Updated.city}, - state = ${building33Updated.state}, - country = ${building33Updated.country} - WHERE buildingCode = ${building33.buildingCode} - `); - check commit; - } + _ = check rainierClient->executeNativeSQL(` + UPDATE Building + SET + city = ${building33Updated.city}, + state = ${building33Updated.state}, + country = ${building33Updated.country} + WHERE buildingCode = ${building33.buildingCode} + `); stream buildingStream3 = rainierClient->queryNativeSQL(`SELECT * FROM Building WHERE buildingCode = ${building33.buildingCode}`); Building[] buildings3 = check from Building building in buildingStream3 diff --git a/ballerina/tests/postgresql-native-tests.bal b/ballerina/tests/postgresql-native-tests.bal index 1a3f4e7..f309071 100644 --- a/ballerina/tests/postgresql-native-tests.bal +++ b/ballerina/tests/postgresql-native-tests.bal @@ -191,7 +191,7 @@ function postgresqlNativeQueryComplexTest() returns error? { @test:Config { groups: ["transactions", "postgresql", "native"], dependsOn: [postgresqlNativeExecuteTestNegative1, postgresqlNativeQueryTest, postgresqlNativeQueryTestNegative, postgresqlNativeQueryComplexTest], - enable: false + enable: true } function postgresqlNativeTransactionTest() returns error? { PostgreSQLRainierClient rainierClient = check new (); @@ -254,17 +254,14 @@ function postgresqlNativeTransactionTest2() returns error? { check buildingStream.close(); test:assertEquals(buildings, [building33]); - transaction { - _ = check rainierClient->executeNativeSQL(` - UPDATE "Building" - SET - "city" = ${building33Updated.city}, - "state" = ${building33Updated.state}, - "country" = ${building33Updated.country} - WHERE "buildingCode" = ${building33.buildingCode} - `); - check commit; - } + _ = check rainierClient->executeNativeSQL(` + UPDATE "Building" + SET + "city" = ${building33Updated.city}, + "state" = ${building33Updated.state}, + "country" = ${building33Updated.country} + WHERE "buildingCode" = ${building33.buildingCode} + `); stream buildingStream3 = rainierClient->queryNativeSQL(`SELECT * FROM "Building" WHERE "buildingCode" = ${building33.buildingCode}`); Building[] buildings3 = check from Building building in buildingStream3 diff --git a/native/src/main/java/io/ballerina/stdlib/persist/sql/datastore/SQLProcessor.java b/native/src/main/java/io/ballerina/stdlib/persist/sql/datastore/SQLProcessor.java index 11baf06..942441e 100644 --- a/native/src/main/java/io/ballerina/stdlib/persist/sql/datastore/SQLProcessor.java +++ b/native/src/main/java/io/ballerina/stdlib/persist/sql/datastore/SQLProcessor.java @@ -22,6 +22,7 @@ import io.ballerina.runtime.api.Future; import io.ballerina.runtime.api.PredefinedTypes; import io.ballerina.runtime.api.async.Callback; +import io.ballerina.runtime.api.constants.RuntimeConstants; import io.ballerina.runtime.api.creators.TypeCreator; import io.ballerina.runtime.api.creators.ValueCreator; import io.ballerina.runtime.api.types.ErrorType; @@ -36,12 +37,11 @@ import io.ballerina.runtime.api.values.BStream; import io.ballerina.runtime.api.values.BString; import io.ballerina.runtime.api.values.BTypedesc; +import io.ballerina.runtime.transactions.TransactionLocalContext; import io.ballerina.runtime.transactions.TransactionResourceManager; import io.ballerina.stdlib.persist.Constants; import io.ballerina.stdlib.persist.ModuleUtils; import io.ballerina.stdlib.persist.sql.Utils; -import io.ballerina.stdlib.sql.parameterprocessor.DefaultResultParameterProcessor; -import io.ballerina.stdlib.sql.parameterprocessor.DefaultStatementParameterProcessor; import java.util.Map; @@ -180,50 +180,12 @@ public void notifyFailure(BError bError) { static BStream queryNativeSQL(Environment env, BObject client, BObject paramSQLString, BTypedesc targetType) { // This method will return `stream` - - TransactionResourceManager trxResourceManager = TransactionResourceManager.getInstance(); - if (!io.ballerina.stdlib.sql.utils.Utils.isWithinTrxBlock(trxResourceManager)) { - return queryNativeSQLBal(env, client, paramSQLString, targetType); - } - - BObject dbClient = (BObject) client.get(DB_CLIENT); - BStream sqlStream = io.ballerina.stdlib.sql.nativeimpl.QueryProcessor.nativeQuery(env, dbClient, - paramSQLString, targetType, DefaultStatementParameterProcessor.getInstance(), - DefaultResultParameterProcessor.getInstance()); - - if (sqlStream != null) { - BObject persistNativeStream = createPersistNativeSQLStream(sqlStream, null); - RecordType streamConstraint = - (RecordType) TypeUtils.getReferredType(targetType.getDescribingType()); - return (ValueCreator.createStreamValue(TypeCreator.createStreamType(streamConstraint, - PredefinedTypes.TYPE_NULL), persistNativeStream) - ); - } - - return null; + return queryNativeSQLBal(env, client, paramSQLString, targetType); } static Object executeNativeSQL(Environment env, BObject client, BObject paramSQLString) { // This method will return `persist:ExecutionResult|persist:Error` - - TransactionResourceManager trxResourceManager = TransactionResourceManager.getInstance(); - if (!io.ballerina.stdlib.sql.utils.Utils.isWithinTrxBlock(trxResourceManager)) { - return executeNativeSQLBal(env, client, paramSQLString); - } - - BObject dbClient = (BObject) client.get(DB_CLIENT); - Object sqlExecutionResult = io.ballerina.stdlib.sql.nativeimpl.ExecuteProcessor.nativeExecute(env, dbClient, - paramSQLString, DefaultStatementParameterProcessor.getInstance()); - - if (sqlExecutionResult instanceof BMap) { // returned type is `sql:ExecutionResult` - return ValueCreator.createRecordValue(getModule(), - io.ballerina.stdlib.persist.sql.Constants.PERSIST_EXECUTION_RESULT, - (BMap) sqlExecutionResult); - } else if (sqlExecutionResult instanceof BError) { // returned type is `sql:Error` - return wrapSQLError((BError) sqlExecutionResult); - } - - return null; + return executeNativeSQLBal(env, client, paramSQLString); } private static BStream queryNativeSQLBal(Environment env, BObject client, BObject paramSQLString, @@ -233,6 +195,12 @@ private static BStream queryNativeSQLBal(Environment env, BObject client, BObjec BObject dbClient = (BObject) client.get(DB_CLIENT); RecordType recordType = (RecordType) targetType.getDescribingType(); StreamType streamType = TypeCreator.createStreamType(recordType, PredefinedTypes.TYPE_NULL); + TransactionResourceManager trxResourceManager = TransactionResourceManager.getInstance(); + TransactionLocalContext currentTrxContext = trxResourceManager.getCurrentTransactionContext(); + Map properties = null; + if (currentTrxContext != null) { + properties = Map.of(RuntimeConstants.CURRENT_TRANSACTION_CONTEXT_PROPERTY, currentTrxContext); + } Future balFuture = env.markAsync(); env.getRuntime().invokeMethodAsyncSequentially( @@ -257,7 +225,7 @@ public void notifyFailure(BError bError) { // can only be hit on a panic BObject errorStream = Utils.createPersistNativeSQLStream(null, bError); balFuture.complete(errorStream); } - }, null, streamType, paramSQLString, true, targetType, true + }, properties, streamType, paramSQLString, true, targetType, true ); return null; @@ -267,6 +235,12 @@ private static Object executeNativeSQLBal(Environment env, BObject client, BObje BObject dbClient = (BObject) client.get(DB_CLIENT); RecordType persistExecutionResultType = TypeCreator.createRecordType( io.ballerina.stdlib.persist.sql.Constants.PERSIST_EXECUTION_RESULT, getModule(), 0, true, 0); + TransactionResourceManager trxResourceManager = TransactionResourceManager.getInstance(); + TransactionLocalContext currentTrxContext = trxResourceManager.getCurrentTransactionContext(); + Map properties = null; + if (currentTrxContext != null) { + properties = Map.of(RuntimeConstants.CURRENT_TRANSACTION_CONTEXT_PROPERTY, currentTrxContext); + } Future balFuture = env.markAsync(); env.getRuntime().invokeMethodAsyncSequentially( @@ -292,8 +266,7 @@ public void notifyFailure(BError bError) { // can only be hit on a panic BError persistError = wrapError(bError); balFuture.complete(persistError); } - }, null, persistExecutionResultType, paramSQLString, true - ); + }, properties, persistExecutionResultType, paramSQLString, true); return null; }