Skip to content

Commit

Permalink
Merge pull request #38 from cloudsufi/snowflake-fixes
Browse files Browse the repository at this point in the history
[cherrypick][PLUGIN-1816] Added fix for decimal issue not having rounding mode and made escape character configurable.
  • Loading branch information
vikasrathee-cs authored Nov 15, 2024
2 parents 3b340b9 + f7d8d0a commit 2b7f72e
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

<groupId>io.cdap.plugin</groupId>
<artifactId>snowflake-plugins</artifactId>
<version>1.1.3</version>
<version>1.1.4-SNAPSHOT</version>
<packaging>jar</packaging>
<name>Snowflake plugins</name>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.cdap.plugin.snowflake.common.client.SnowflakeFieldDescriptor;
import io.cdap.plugin.snowflake.common.exception.SchemaParseException;
import io.cdap.plugin.snowflake.source.batch.SnowflakeBatchSourceConfig;
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
import java.io.IOException;
import java.sql.Types;
Expand Down Expand Up @@ -62,7 +63,8 @@ public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollect
return getParsedSchema(config.getSchema());
}

SnowflakeSourceAccessor snowflakeSourceAccessor = new SnowflakeSourceAccessor(config);
SnowflakeSourceAccessor snowflakeSourceAccessor =
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.cdap.plugin.snowflake.source.batch;

import com.google.common.base.Strings;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
Expand All @@ -33,6 +34,7 @@
import io.cdap.plugin.snowflake.common.util.SchemaHelper;
import org.apache.hadoop.io.NullWritable;

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -68,7 +70,11 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) {
public void prepareRun(BatchSourceContext context) {
FailureCollector failureCollector = context.getFailureCollector();
config.validate(failureCollector);

Map<String, String> arguments = new HashMap<>(context.getArguments().asMap());
String escapeChar = arguments.containsKey(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR) &&
!Strings.isNullOrEmpty(arguments.get(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR))
? arguments.get(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR)
: SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR;
Schema schema = SchemaHelper.getSchema(config, failureCollector);
failureCollector.getOrThrowException();

Expand All @@ -81,7 +87,7 @@ public void prepareRun(BatchSourceContext context) {
.collect(Collectors.toList()));
}

context.setInput(Input.of(config.getReferenceName(), new SnowflakeInputFormatProvider(config)));
context.setInput(Input.of(config.getReferenceName(), new SnowflakeInputFormatProvider(config, escapeChar)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ private SnowflakeSourceAccessor getSnowflakeAccessor(Configuration configuration
SnowflakeInputFormatProvider.PROPERTY_CONFIG_JSON);
SnowflakeBatchSourceConfig config = GSON.fromJson(
configJson, SnowflakeBatchSourceConfig.class);
return new SnowflakeSourceAccessor(config);
String escapeChar = configuration.get(SnowflakeInputFormatProvider.PROPERTY_ESCAPE_CHAR,
SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
return new SnowflakeSourceAccessor(config, escapeChar);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@
public class SnowflakeInputFormatProvider implements InputFormatProvider {

public static final String PROPERTY_CONFIG_JSON = "cdap.snowflake.source.config";
public static final String PROPERTY_ESCAPE_CHAR = "cdap.snowflake.source.escape";

public static final String PROPERTY_DEFAULT_ESCAPE_CHAR = "\\";

private static final Gson GSON = new Gson();
private final Map<String, String> conf;

public SnowflakeInputFormatProvider(SnowflakeBatchSourceConfig config) {
public SnowflakeInputFormatProvider(SnowflakeBatchSourceConfig config, String escapeChar) {
this.conf = new ImmutableMap.Builder<String, String>()
.put(PROPERTY_CONFIG_JSON, GSON.toJson(config))
.put(PROPERTY_ESCAPE_CHAR, escapeChar)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalTime;
Expand Down Expand Up @@ -83,7 +84,8 @@ private Object convertValue(String fieldName, String value, Schema fieldSchema)
case TIME_MICROS:
return TimeUnit.NANOSECONDS.toMicros(LocalTime.parse(value).toNanoOfDay());
case DECIMAL:
return new BigDecimal(value).setScale(fieldSchema.getScale()).unscaledValue().toByteArray();
return new BigDecimal(value).setScale(fieldSchema.getScale(),
RoundingMode.HALF_EVEN).unscaledValue().toByteArray();
default:
throw new IllegalArgumentException(
String.format("Field '%s' is of unsupported type '%s'", fieldSchema.getDisplayName(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ public class SnowflakeSourceAccessor extends SnowflakeAccessor {
"OVERWRITE=TRUE HEADER=TRUE SINGLE=FALSE";
private static final String COMMAND_MAX_FILE_SIZE = " MAX_FILE_SIZE=%s";
private final SnowflakeBatchSourceConfig config;
private final char escapeChar;

public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config) {
public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeChar) {
super(config);
this.config = config;
this.escapeChar = escapeChar.charAt(0);
}

/**
Expand Down Expand Up @@ -116,7 +118,7 @@ public CSVReader buildCsvReader(String stageSplit) throws IOException {
InputStream downloadStream = connection.unwrap(SnowflakeConnection.class)
.downloadStream("@~", stageSplit, true);
InputStreamReader inputStreamReader = new InputStreamReader(downloadStream);
return new CSVReader(inputStreamReader);
return new CSVReader(inputStreamReader, ',', '"', escapeChar);
} catch (SQLException e) {
throw new IOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.cdap.plugin.snowflake.Constants;
import io.cdap.plugin.snowflake.common.BaseSnowflakeTest;
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -44,7 +45,8 @@
*/
public class SnowflakeAccessorTest extends BaseSnowflakeTest {

private SnowflakeSourceAccessor snowflakeAccessor = new SnowflakeSourceAccessor(CONFIG);
private SnowflakeSourceAccessor snowflakeAccessor =
new SnowflakeSourceAccessor(CONFIG, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);

@Test
public void testDescribeQuery() throws Exception {
Expand Down

0 comments on commit 2b7f72e

Please sign in to comment.