Fix getting columns and add more tests
jduo committed Feb 17, 2024
1 parent c20e67f commit c90881d
Showing 2 changed files with 530 additions and 482 deletions.
@@ -32,6 +32,7 @@
import java.util.regex.Pattern;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.core.StandardSchemas;
import org.apache.arrow.driver.jdbc.utils.SqlTypes;
import org.apache.arrow.flight.FlightEndpoint;
@@ -126,8 +127,11 @@ static ArrowReader CreateGetObjectsReader(

private abstract static class GetObjectMetadataReader extends BaseFlightReader {
private final VectorSchemaRoot aggregateRoot;
private boolean hasLoaded = false;
protected final Text buffer = new Text();

@SuppressWarnings(
"method.invocation") // Checker Framework does not like the ensureInitialized call
protected GetObjectMetadataReader(
BufferAllocator allocator,
FlightSqlClientWithCallOptions client,
@@ -137,6 +141,17 @@ protected GetObjectMetadataReader(
super(allocator, client, clientCache, rpcCall);
aggregateRoot = VectorSchemaRoot.create(readSchema(), allocator);
populateEndpointData();

try {
this.ensureInitialized();
} catch (IOException e) {
throw new AdbcException(
FlightSqlDriverUtil.prefixExceptionMessage(e.getMessage()),
e,
AdbcStatusCode.IO,
null,
0);
}
}

@Override
@@ -155,16 +170,23 @@ public void close() throws IOException {

@Override
public boolean loadNextBatch() throws IOException {
while (super.loadNextBatch()) {
// Do nothing. Just iterate through all partitions, processing the data.
}
try {
finish();
} catch (AdbcException e) {
throw new RuntimeException(e);
if (!hasLoaded) {
while (super.loadNextBatch()) {
// Do nothing. Just iterate through all partitions, processing the data.
}
try {
finish();
} catch (AdbcException e) {
throw new RuntimeException(e);
}

hasLoaded = true;
if (aggregateRoot.getRowCount() > 0) {
loadRoot(aggregateRoot);
return true;
}
}
loadRoot(aggregateRoot);
return true;
return false;
}
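
Note: before this change, `loadNextBatch()` re-loaded the aggregate root and returned `true` on every call, so a caller looping until `false` would never terminate. The new `hasLoaded` flag makes the reader a one-shot aggregator: drain all partitions once, serve the combined batch if it is non-empty, then report exhaustion. A minimal standalone sketch of the pattern (the simulated upstream iterator and all names are hypothetical stand-ins for `super.loadNextBatch()` and the reader internals):

```java
import java.util.Iterator;

// One-shot aggregation sketch; the upstream iterator simulates the partitions.
public class OneShotReader {
  private final Iterator<Integer> upstreamBatchSizes;
  private boolean hasLoaded = false;
  private int aggregateRowCount = 0;

  OneShotReader(Iterator<Integer> upstreamBatchSizes) {
    this.upstreamBatchSizes = upstreamBatchSizes;
  }

  public boolean loadNextBatch() {
    if (!hasLoaded) {
      while (upstreamBatchSizes.hasNext()) {
        // Fold every upstream batch into the aggregate, as the subclasses
        // do in processRootFromStream().
        aggregateRowCount += upstreamBatchSizes.next();
      }
      hasLoaded = true;
      if (aggregateRowCount > 0) {
        return true; // exactly one "true" per reader lifetime
      }
    }
    return false; // empty result set, or the batch was already served
  }
}
```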

@Override
@@ -196,12 +218,14 @@ protected GetCatalogsMetadataReader(
protected void processRootFromStream(VectorSchemaRoot root) {
VarCharVector catalogVector = (VarCharVector) root.getVector(0);
VarCharVector adbcCatalogNames = (VarCharVector) getAggregateRoot().getVector(0);
for (int srcIndex = 0, dstIndex = 0; srcIndex < root.getRowCount(); ++srcIndex) {
int srcIndex = 0, dstIndex = getAggregateRoot().getRowCount();
for (; srcIndex < root.getRowCount(); ++srcIndex) {
catalogVector.read(srcIndex, buffer);
if (catalogPattern.matcher(buffer.toString()).matches()) {
if (catalogPattern == null || catalogPattern.matcher(buffer.toString()).matches()) {
catalogVector.makeTransferPair(adbcCatalogNames).copyValueSafe(srcIndex, dstIndex++);
}
}
getAggregateRoot().setRowCount(dstIndex);
}
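
Note: three fixes land in this loop: `dstIndex` now starts at the aggregate root's current row count so rows from later endpoints append rather than overwrite, a `null` `catalogPattern` is treated as match-all, and the aggregate row count is updated after filtering. A standalone sketch of the same append-and-filter idiom, using the same `VarCharVector#read(int, Text)` overload the diff uses (class and method names are hypothetical):

```java
import java.util.regex.Pattern;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.util.Text;
import org.apache.arrow.vector.util.TransferPair;

public final class VectorFilters {
  /** Appends values of src matching pattern (null = match all) onto dst. */
  static int appendMatching(VarCharVector src, VarCharVector dst, Pattern pattern) {
    Text buffer = new Text(); // reused across rows to avoid per-row allocation
    int dstIndex = dst.getValueCount(); // append after any existing rows
    TransferPair pair = src.makeTransferPair(dst);
    for (int srcIndex = 0; srcIndex < src.getValueCount(); ++srcIndex) {
      src.read(srcIndex, buffer);
      if (pattern == null || pattern.matcher(buffer.toString()).matches()) {
        pair.copyValueSafe(srcIndex, dstIndex++);
      }
    }
    dst.setValueCount(dstIndex);
    return dstIndex;
  }
}
```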

private static List<FlightEndpoint> doRequest(FlightSqlClientWithCallOptions client) {
Expand Down Expand Up @@ -242,7 +266,7 @@ protected void processRootFromStream(VectorSchemaRoot root) {
catalog,
(k, v) -> {
if (v == null) {
return Arrays.asList(schema);
v = new ArrayList<>();
}
v.add(schema);
return v;
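
Note: the replaced `Arrays.asList(schema)` was a real bug, not a style fix: `Arrays.asList` returns a fixed-size list, so `v.add(schema)` threw `UnsupportedOperationException` as soon as a catalog reported its second schema. A minimal demonstration of the failure mode and the fix:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class FixedSizeListDemo {
  public static void main(String[] args) {
    Map<String, List<String>> catalogToSchemas = new HashMap<>();
    for (String schema : Arrays.asList("schema1", "schema2")) {
      catalogToSchemas.compute("db", (k, v) -> {
        if (v == null) {
          v = new ArrayList<>(); // mutable, unlike Arrays.asList(schema)
        }
        // With the old "return Arrays.asList(schema)" seed, this add() would
        // throw UnsupportedOperationException on the second schema for "db".
        v.add(schema);
        return v;
      });
    }
    System.out.println(catalogToSchemas); // {db=[schema1, schema2]}
  }
}
```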
@@ -257,6 +281,9 @@ protected void finish() throws AdbcException, IOException {
VarCharVector outputCatalogColumn = (VarCharVector) getAggregateRoot().getVector(0);
try (GetCatalogsMetadataReader catalogReader =
new GetCatalogsMetadataReader(allocator, client, clientCache, catalog)) {
if (!catalogReader.loadNextBatch()) {
return;
}
getAggregateRoot().setRowCount(catalogReader.getAggregateRoot().getRowCount());
VarCharVector catalogColumn = (VarCharVector) catalogReader.getAggregateRoot().getVector(0);
catalogColumn.makeTransferPair(outputCatalogColumn).transfer();
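
Note: the new early return skips the column transfer when the catalog reader produced no batch at all. The transfer itself is Arrow's zero-copy move: `makeTransferPair(...).transfer()` hands the source vector's buffers to the destination and leaves the source empty, unlike `copyValueSafe`, which copies one value at a time. A small self-contained sketch of that distinction:

```java
import java.nio.charset.StandardCharsets;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.util.TransferPair;

public class TransferDemo {
  public static void main(String[] args) {
    try (BufferAllocator allocator = new RootAllocator();
        VarCharVector source = new VarCharVector("src", allocator);
        VarCharVector dest = new VarCharVector("dst", allocator)) {
      source.allocateNew();
      source.setSafe(0, "main".getBytes(StandardCharsets.UTF_8));
      source.setValueCount(1);

      TransferPair pair = source.makeTransferPair(dest);
      pair.transfer(); // zero-copy: dest owns the buffers, source is now empty
      System.out.println(dest.getValueCount());   // 1
      System.out.println(source.getValueCount()); // 0
    }
  }
}
```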
@@ -267,22 +294,24 @@ protected void finish() throws AdbcException, IOException {
((ListVector) getAggregateRoot().getVector(1)).getWriter();
BaseWriter.StructWriter adbcCatalogDbSchemasStructWriter =
adbcCatalogDbSchemasWriter.struct();
VarCharWriter adbcCatalogDbSchemaNameWriter =
adbcCatalogDbSchemasStructWriter.varChar("db_schema_name");
for (int i = 0; i < getAggregateRoot().getRowCount(); ++i) {
outputCatalogColumn.read(i, buffer);
String catalog = buffer.toString();
List<String> schemas = catalogToSchemaMap.get(catalog);
adbcCatalogDbSchemasWriter.setPosition(i);
adbcCatalogDbSchemasWriter.startList();
if (schemas != null) {
adbcCatalogDbSchemasWriter.startList();
for (String schema : schemas) {
adbcCatalogDbSchemasStructWriter.start();
VarCharWriter adbcCatalogDbSchemaNameWriter =
adbcCatalogDbSchemasStructWriter.varChar("db_schema_name");
adbcCatalogDbSchemaNameWriter.writeVarChar(schema);
adbcCatalogDbSchemasStructWriter.end();
}
adbcCatalogDbSchemasWriter.endList();
}
adbcCatalogDbSchemasWriter.endList();
}
adbcCatalogDbSchemasWriter.setValueCount(getAggregateRoot().getRowCount());
}

private static List<FlightEndpoint> doRequest(
Expand Down Expand Up @@ -356,7 +385,8 @@ protected GetTablesMetadataReader(
() -> doRequest(client, catalogPattern, schemaPattern, tablePattern, tableTypes, true));
this.catalogPattern = catalogPattern;
this.dbSchemaPattern = schemaPattern;
compiledColumnNamePattern = columnPattern != null ? Pattern.compile(sqlToRegexLike(columnPattern)) : null;
compiledColumnNamePattern =
columnPattern != null ? Pattern.compile(sqlToRegexLike(columnPattern)) : null;
shouldGetColumns = true;
}
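
Note: the constructor now compiles the column pattern once up front instead of matching against an uncompiled pattern per row. `sqlToRegexLike` is the driver's own LIKE-to-regex helper; as a rough sketch of what such a translation usually does ('%' to ".*", '_' to ".", everything else quoted literally; the real helper's escape handling may differ):

```java
import java.util.regex.Pattern;

public final class LikePatternSketch {
  // Hypothetical stand-in for the driver's sqlToRegexLike helper.
  static String sqlToRegexLike(String sqlPattern) {
    StringBuilder regex = new StringBuilder();
    for (char c : sqlPattern.toCharArray()) {
      if (c == '%') {
        regex.append(".*");   // LIKE '%' matches any run of characters
      } else if (c == '_') {
        regex.append('.');    // LIKE '_' matches exactly one character
      } else {
        regex.append(Pattern.quote(String.valueOf(c)));
      }
    }
    return regex.toString();
  }

  public static void main(String[] args) {
    Pattern p = Pattern.compile(sqlToRegexLike("ID_%"));
    System.out.println(p.matcher("ID_NUMBER").matches()); // true
    System.out.println(p.matcher("ID").matches());        // false
  }
}
```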

@@ -391,7 +421,7 @@ protected void processRootFromStream(VectorSchemaRoot root) {
if (tableTypeVector.isNull(i)) {
tableType = null;
} else {
tableVector.read(i, buffer);
tableTypeVector.read(i, buffer);
tableType = buffer.toString();
}

@@ -428,70 +458,82 @@ protected void finish() throws AdbcException, IOException {
// Create a schema-only reader to get the catalog->schema hierarchy, including empty catalogs
// and schemas.
// Then transfer the contents of this to the current reader's root.
VarCharVector outputCatalogColumn = (VarCharVector) getAggregateRoot().getVector(0);
ListVector outputSchemaStructList = (ListVector) getAggregateRoot().getVector(1);
try (GetDbSchemasMetadataReader schemaReader =
new GetDbSchemasMetadataReader(
allocator, client, clientCache, catalogPattern, dbSchemaPattern)) {
if (!schemaReader.loadNextBatch()) {
return;
}
VarCharVector outputCatalogColumn = (VarCharVector) getAggregateRoot().getVector(0);
ListVector outputSchemaStructList = (ListVector) getAggregateRoot().getVector(1);
ListVector sourceSchemaStructList =
(ListVector) schemaReader.getAggregateRoot().getVector(1);
getAggregateRoot().setRowCount(schemaReader.getAggregateRoot().getRowCount());

VarCharVector catalogColumn = (VarCharVector) schemaReader.getAggregateRoot().getVector(0);
catalogColumn.makeTransferPair(outputCatalogColumn).transfer();
}

// Iterate over catalogs and schemas reported by the GetDbSchemasMetadataReader.
for (int i = 0; i < getAggregateRoot().getRowCount(); ++i) {
outputCatalogColumn.read(i, buffer);
final String catalog = buffer.toString();
// Iterate over catalogs and schemas reported by the GetDbSchemasMetadataReader.
final UnionListWriter schemaListWriter = outputSchemaStructList.getWriter();
for (Object schemaStructObj : outputSchemaStructList.getObject(i)) {
final Map<String, Object> schemaStructAsMap = (Map<String, Object>) schemaStructObj;
String schemaName = (String) schemaStructAsMap.get("db_schema_name");

// If either the catalog or the schema was not reported by the GetTables RPC call during
// processRootFromStream(),
// it means that this was an empty (table-less) catalog or schema pair and should be
// skipped.
final Map<String, Map<String, TableDefinition>> schemaToTableMap =
tablePathToColumnsMap.get(catalog);
if (schemaToTableMap == null) {
continue;
}

final Map<String, TableDefinition> tables = schemaToTableMap.get(schemaName);
if (tables == null) {
continue;
}

// Set up the schema list writer to write at the current position.
schemaListWriter.setPosition(i);
BaseWriter.StructWriter schemaStructWriter = schemaListWriter.struct();
schemaStructWriter.start();
schemaStructWriter.varChar("db_schema_name").writeVarChar(schemaName);
BaseWriter.ListWriter tableWriter = schemaStructWriter.list("db_schema_tables");
// Process each table.
for (Map.Entry<String, TableDefinition> table : tables.entrySet()) {
schemaListWriter.allocate();
for (int i = 0; i < getAggregateRoot().getRowCount(); ++i) {
outputCatalogColumn.read(i, buffer);
final String catalog = buffer.toString();

schemaListWriter.startList();
for (Object schemaStructObj : sourceSchemaStructList.getObject(i)) {
final Map<String, Object> schemaStructAsMap = (Map<String, Object>) schemaStructObj;
String schemaName = schemaStructAsMap.get("db_schema_name").toString();

// Set up the schema list writer to write at the current position.
schemaListWriter.setPosition(i);
BaseWriter.StructWriter schemaStructWriter = schemaListWriter.struct();
schemaStructWriter.start();
schemaStructWriter.varChar("db_schema_name").writeVarChar(schemaName);
BaseWriter.ListWriter tableWriter = schemaStructWriter.list("db_schema_tables");
// Process each table.
tableWriter.startList();
BaseWriter.StructWriter tableStructWriter = tableWriter.struct();
tableStructWriter.start();
tableStructWriter.varChar("table_name").writeVarChar(table.getKey());
tableStructWriter.varChar("table_type").writeVarChar(table.getValue().tableType);

// Process each column if columns are requested.
if (shouldGetColumns) {
BaseWriter.ListWriter columnListWriter = tableStructWriter.list("table_columns");
columnListWriter.startList();
for (ColumnDefinition columnDefinition : table.getValue().columnDefinitions) {
BaseWriter.StructWriter columnDefinitionWriter = columnListWriter.struct();
writeColumnDefinition(columnDefinition, columnDefinitionWriter);

// If either the catalog or the schema was not reported by the GetTables RPC call during
// processRootFromStream(),
// it means that this was an empty (table-less) catalog or schema pair and should be
// skipped.
final Map<String, Map<String, TableDefinition>> schemaToTableMap =
tablePathToColumnsMap.get(catalog);
if (schemaToTableMap != null) {
final Map<String, TableDefinition> tables = schemaToTableMap.get(schemaName);
if (tables != null) {
for (Map.Entry<String, TableDefinition> table : tables.entrySet()) {
BaseWriter.StructWriter tableStructWriter = tableWriter.struct();
tableStructWriter.start();
tableStructWriter.varChar("table_name").writeVarChar(table.getKey());
if (table.getValue().tableType != null) {
tableStructWriter
.varChar("table_type")
.writeVarChar(table.getValue().tableType);
}

// Process each column if columns are requested.
if (shouldGetColumns) {
BaseWriter.ListWriter columnListWriter =
tableStructWriter.list("table_columns");
columnListWriter.startList();
for (ColumnDefinition columnDefinition : table.getValue().columnDefinitions) {
BaseWriter.StructWriter columnDefinitionWriter = columnListWriter.struct();
writeColumnDefinition(columnDefinition, columnDefinitionWriter);
}
columnListWriter.endList();
}
tableStructWriter.end();
}
}
columnListWriter.endList();
}
tableStructWriter.end();
tableWriter.endList();
schemaStructWriter.end();
}
tableWriter.endList();
schemaStructWriter.end();
schemaListWriter.endList();
}
schemaListWriter.setValueCount(getAggregateRoot().getRowCount());
}
}
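
Note: the restructured `finish()` now opens and closes the schema list for every row (`startList()`/`endList()` outside any null checks), so empty catalogs produce an empty list instead of a null entry; it writes `table_type` only when non-null; and it ends with `setValueCount`, which list writers require before the vector is read. A trimmed, self-contained sketch of the list<struct> writer choreography, assuming an Arrow version that provides `VarCharWriter#writeVarChar(String)` as used in the diff:

```java
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.impl.UnionListWriter;
import org.apache.arrow.vector.complex.writer.BaseWriter;

public class ListOfStructWriterDemo {
  public static void main(String[] args) {
    int rowCount = 3;
    try (RootAllocator allocator = new RootAllocator();
        ListVector listVector = ListVector.empty("catalog_db_schemas", allocator)) {
      UnionListWriter listWriter = listVector.getWriter();
      listWriter.allocate();
      for (int row = 0; row < rowCount; ++row) {
        listWriter.setPosition(row); // write this row's list entry
        listWriter.startList();      // always open, even for an empty row
        BaseWriter.StructWriter structWriter = listWriter.struct();
        structWriter.start();
        structWriter.varChar("db_schema_name").writeVarChar("schema-" + row);
        structWriter.end();
        listWriter.endList();        // always close, even with no structs written
      }
      listWriter.setValueCount(rowCount); // required before reading the vector
      System.out.println(listVector.getObject(0)); // one struct per list entry
    }
  }
}
```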

@@ -531,6 +573,7 @@ private List<ColumnDefinition> getColumnDefinitions(

private void writeColumnDefinition(
ColumnDefinition columnDefinition, BaseWriter.StructWriter columnDefinitionWriter) {
columnDefinitionWriter.start();
// This code is based on the implementation of getColumns() in the Flight JDBC driver.
columnDefinitionWriter.varChar("column_name").writeVarChar(columnDefinition.field.getName());
columnDefinitionWriter.integer("ordinal_position").writeInt(columnDefinition.ordinal);
@@ -565,7 +608,7 @@ private void writeColumnDefinition(
if (decimalDigits != null) {
columnDefinitionWriter
.smallInt("xdbc_decimal_digits")
.writeSmallInt(Shorts.saturatedCast(columnDefinition.metadata.getScale()));
.writeSmallInt(Shorts.saturatedCast(decimalDigits));
}
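
Note: the write now uses `decimalDigits`, the same value the surrounding null check inspected, instead of re-reading `columnDefinition.metadata.getScale()`. `Shorts.saturatedCast` is Guava's clamping narrowing conversion: values outside the short range pin to the bounds rather than wrapping, for example:

```java
import com.google.common.primitives.Shorts;

public class SaturatedCastDemo {
  public static void main(String[] args) {
    System.out.println(Shorts.saturatedCast(38L));     // 38
    System.out.println(Shorts.saturatedCast(40000L));  // 32767  (Short.MAX_VALUE)
    System.out.println(Shorts.saturatedCast(-40000L)); // -32768 (Short.MIN_VALUE)
  }
}
```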

// This is taken from the JDBC driver, but seems wrong that all three branches write the same
@@ -594,10 +637,13 @@ private void writeColumnDefinition(
// columnDefinitionWriter.varChar("xdbc_scope_catalog").writeVarChar();
// columnDefinitionWriter.varChar("xdbc_scope_schema").writeVarChar();
// columnDefinitionWriter.varChar("xdbc_scope_table").writeVarChar();
columnDefinitionWriter
.bit("xdbc_auto_increment")
.writeBit(columnDefinition.metadata.isAutoIncrement() ? 1 : 0);
if (columnDefinition.metadata.isAutoIncrement() != null) {
columnDefinitionWriter
.bit("xdbc_auto_increment")
.writeBit(columnDefinition.metadata.isAutoIncrement() ? 1 : 0);
}
// columnDefinitionWriter.bit("xdbc_is_generatedcolumn").writeBit();
columnDefinitionWriter.end();
}

private static List<FlightEndpoint> doRequest(
@@ -608,7 +654,12 @@ private static List<FlightEndpoint> doRequest(
String[] tableTypes,
boolean shouldGetColumns) {
return client
.getTables(catalog, schemaPattern, table, null != tableTypes ? Arrays.asList(tableTypes) : null, shouldGetColumns)
.getTables(
catalog,
schemaPattern,
table,
null != tableTypes ? Arrays.asList(tableTypes) : null,
shouldGetColumns)
.getEndpoints();
}
}