Skip to content

Commit

Permalink
feat: add db2 Queries task (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgabelle authored Oct 25, 2024
1 parent eaa8a5d commit 522f78e
Show file tree
Hide file tree
Showing 5 changed files with 329 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.kestra.plugin.jdbc.db2;

import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.jdbc.AbstractCellConverter;
import io.kestra.plugin.jdbc.AbstractJdbcQueries;
import io.kestra.plugin.jdbc.AbstractJdbcQuery;
import io.kestra.plugin.jdbc.AutoCommitInterface;
import io.micronaut.http.uri.UriBuilder;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;

import java.net.URI;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.time.ZoneId;
import java.util.Properties;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Schema(
title = "Perform multiple queries on a DB2 database."
)
@Plugin(
examples = {
@Example(
title = "Send a SQL query to a DB2 Database and fetch a row as output.",
full = true,
code = """
id: db2_query
namespace: company.team
tasks:
- id: queries
type: io.kestra.plugin.jdbc.db2.Queries
url: jdbc:db2://127.0.0.1:50000/
username: db2inst
password: db2_password
sql: select * from employee; select * from laptop;
fetchType: FETCH
"""
)
}
)
public class Queries extends AbstractJdbcQueries implements RunnableTask<AbstractJdbcQueries.MultiQueryOutput> {

@Override
protected AbstractCellConverter getCellConverter(ZoneId zoneId) {
return new Db2CellConverter(zoneId);
}

@Override
public void registerDriver() throws SQLException {
DriverManager.registerDriver(new com.ibm.db2.jcc.DB2Driver());
}

@Override
public Properties connectionProperties(RunContext runContext) throws Exception {
return super.connectionProperties(runContext, "jdbc:db2");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,6 @@ public void registerDriver() throws SQLException {

@Override
public Properties connectionProperties(RunContext runContext) throws Exception {
Properties props = super.connectionProperties(runContext);

URI url = URI.create((String) props.get("jdbc.url"));
url = URI.create(url.getSchemeSpecificPart());

UriBuilder builder = UriBuilder.of(url).scheme("jdbc:db2");

props.put("jdbc.url", builder.build().toString());

return props;
return super.connectionProperties(runContext, "jdbc:db2");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package io.kestra.plugin.jdbc.db2;

import io.kestra.core.junit.annotations.KestraTest;
import io.kestra.core.models.property.Property;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.jdbc.AbstractJdbcQueries;
import io.kestra.plugin.jdbc.AbstractRdbmsTest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.sql.SQLException;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static io.kestra.core.models.tasks.common.FetchType.FETCH;
import static io.kestra.core.models.tasks.common.FetchType.FETCH_ONE;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;

@KestraTest
@Disabled("Disabled for CI")
public class DB2QueriesTest extends AbstractRdbmsTest {

@Test
void testMultiSelectWithParameters() throws Exception {
RunContext runContext = runContextFactory.of(Collections.emptyMap());

Map<String, Object> parameters = Map.of(
"age", 40,
"brand", "Apple",
"cpu_frequency", 1.5
);

Queries taskGet = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(FETCH)
.timeZoneId("Europe/Paris")
.sql("""
SELECT firstName, lastName, age FROM employee where age > :age and age < :age + 10;
SELECT brand, model FROM laptop where brand = :brand and cpu_frequency > :cpu_frequency;
SELECT * FROM employee;
""")
.parameters(Property.of(parameters))
.build();

AbstractJdbcQueries.MultiQueryOutput runOutput = taskGet.run(runContext);
assertThat(runOutput.getOutputs().size(), is(3));

List<Map<String, Object>> employees = runOutput.getOutputs().getFirst().getRows();
assertThat("employees", employees, notNullValue());
assertThat("employees", employees.size(), is(1));
assertThat("employee selected", employees.getFirst().get("AGE"), is(45));
assertThat("employee selected", employees.getFirst().get("FIRSTNAME"), is("John"));
assertThat("employee selected", employees.getFirst().get("LASTNAME"), is("Doe"));

List<Map<String, Object>> laptops = runOutput.getOutputs().get(1).getRows();
assertThat("laptops", laptops, notNullValue());
assertThat("laptops", laptops.size(), is(1));
assertThat("selected laptop", laptops.getFirst().get("BRAND"), is("Apple"));

List<Map<String, Object>> allEmployees = runOutput.getOutputs().getLast().getRows();
assertThat("All employees", allEmployees, notNullValue());
assertThat("All employees size", allEmployees.size(), is(4));
}

@Test
void testRollback() throws Exception {
RunContext runContext = runContextFactory.of(Collections.emptyMap());

//Queries should pass in a transaction
Queries queriesPass = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(FETCH_ONE)
.timeZoneId("Europe/Paris")
.sql("""
DROP TABLE IF EXISTS DB2INST1.test_transaction;
CREATE TABLE DB2INST1.test_transaction(id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, name VARCHAR(230));
INSERT INTO DB2INST1.test_transaction (name) VALUES ('test_insert_1');
SELECT COUNT(id) as TRANSACTION_COUNT FROM DB2INST1.test_transaction;
""")
.build();

AbstractJdbcQueries.MultiQueryOutput runOutput = queriesPass.run(runContext);
assertThat(runOutput.getOutputs().size(), is(1));
assertThat(runOutput.getOutputs().getFirst().getRow().get("TRANSACTION_COUNT"), is(1));

//Queries should fail due to bad sql
Queries insertsFail = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(FETCH_ONE)
.timeZoneId("Europe/Paris")
.sql("""
INSERT INTO DB2INST1.test_transaction (name) VALUES ('test_insert_2');
INSERT INTO DB2INST1.test_transaction (name) VALUES (3f);
""") //Try inserting before failing
.build();

assertThrows(Exception.class, () -> insertsFail.run(runContext));

//Final query to verify the amount of updated rows
Queries verifyQuery = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(FETCH_ONE)
.timeZoneId("Europe/Paris")
.sql("""
SELECT COUNT(id) as TRANSACTION_COUNT FROM DB2INST1.test_transaction;
""") //Try inserting before failing
.build();

AbstractJdbcQueries.MultiQueryOutput verifyOutput = verifyQuery.run(runContext);
assertThat(verifyOutput.getOutputs().size(), is(1));
assertThat(verifyOutput.getOutputs().getFirst().getRow().get("TRANSACTION_COUNT"), is(1));
}

@Test
void testNonTransactionalShouldNotRollback() throws Exception {
RunContext runContext = runContextFactory.of(Collections.emptyMap());

//Queries should pass in a transaction
Queries insertOneAndFail = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(FETCH_ONE)
.transaction(Property.of(false))
.timeZoneId("Europe/Paris")
.sql("""
DROP TABLE IF EXISTS DB2INST1.test_transaction;
CREATE TABLE DB2INST1.test_transaction(id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, name VARCHAR(230));
INSERT INTO DB2INST1.test_transaction (name) VALUES ('test_insert_1');
INSERT INTO DB2INST1.test_transaction (id) VALUES (1f);
INSERT INTO DB2INST1.test_transaction (id) VALUES ('test_insert_2);
""")
.build();

assertThrows(Exception.class, () -> insertOneAndFail.run(runContext));

//Final query to verify the amount of updated rows
Queries verifyQuery = Queries.builder()
.url(getUrl())
.username(getUsername())
.password(getPassword())
.fetchType(FETCH_ONE)
.timeZoneId("Europe/Paris")
.sql("""
SELECT COUNT(id) as TRANSACTION_COUNT FROM DB2INST1.test_transaction;
""") //Try inserting before failing
.build();

AbstractJdbcQueries.MultiQueryOutput verifyOutput = verifyQuery.run(runContext);
assertThat(verifyOutput.getOutputs().size(), is(1));
assertThat(verifyOutput.getOutputs().getFirst().getRow().get("TRANSACTION_COUNT"), is(1));
}

@Override
protected String getUrl() {
return "jdbc:db2://localhost:5023/testdb";
}

@Override
protected String getUsername() {
return "db2inst1";
}

@Override
protected String getPassword() {
return "password";
}

@Override
protected void initDatabase() throws SQLException, FileNotFoundException, URISyntaxException {
executeSqlScript("scripts/db2_queries.sql");
}

@Override
@BeforeEach
public void init() throws IOException, URISyntaxException, SQLException {
initDatabase();
}
}
35 changes: 35 additions & 0 deletions plugin-jdbc-db2/src/test/resources/scripts/db2_queries.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
-- Create table employee
DROP TABLE IF EXISTS DB2INST1.employee;

CREATE TABLE DB2INST1.employee (
id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
firstName VARCHAR(200),
lastName VARCHAR(200),
age INT
);

INSERT INTO DB2INST1.employee (firstName, lastName, age)
VALUES
('John', 'Doe', 45),
('Bryan', 'Grant', 33),
('Jude', 'Philips', 25),
('Michael', 'Page', 62);


-- Create table laptop
DROP TABLE IF EXISTS DB2INST1.laptop;

CREATE TABLE DB2INST1.laptop
(
id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
brand VARCHAR(200),
model VARCHAR(200),
cpu_frequency DOUBLE
);

INSERT INTO DB2INST1.laptop (brand, model, cpu_frequency)
VALUES
('Apple', 'MacBookPro M1 13', 2.2),
('Apple', 'MacBookPro M3 16', 1.5),
('LG', 'Gram', 1.95),
('Lenovo', 'ThinkPad', 1.05);
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package io.kestra.plugin.jdbc;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.runners.RunContext;
import io.micronaut.http.uri.UriBuilder;
import io.swagger.v3.oas.annotations.media.Schema;

import java.net.URI;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
Expand Down Expand Up @@ -39,16 +42,18 @@ public interface JdbcConnectionInterface {
void registerDriver() throws SQLException;

default Properties connectionProperties(RunContext runContext) throws Exception {
Properties props = new Properties();
props.put("jdbc.url", runContext.render(this.getUrl()));
return createConnectionProperties(runContext);
}

if (this.getUsername() != null) {
props.put("user", runContext.render(this.getUsername()));
}
default Properties connectionProperties(RunContext runContext, String urlScheme) throws Exception {
Properties props = createConnectionProperties(runContext);

if (this.getPassword() != null) {
props.put("password", runContext.render(this.getPassword()));
}
URI url = URI.create((String) props.get("jdbc.url"));
url = URI.create(url.getSchemeSpecificPart());

UriBuilder builder = UriBuilder.of(url).scheme(urlScheme);

props.put("jdbc.url", builder.build().toString());

return props;
}
Expand All @@ -62,4 +67,19 @@ default Connection connection(RunContext runContext) throws Exception {

return DriverManager.getConnection(jdbcUrl, props);
}

private Properties createConnectionProperties(RunContext runContext) throws IllegalVariableEvaluationException {
Properties props = new Properties();
props.put("jdbc.url", runContext.render(this.getUrl()));

if (this.getUsername() != null) {
props.put("user", runContext.render(this.getUsername()));
}

if (this.getPassword() != null) {
props.put("password", runContext.render(this.getPassword()));
}

return props;
}
}

0 comments on commit 522f78e

Please sign in to comment.