diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index eb45003b41..5fc347639b 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/proto" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" + connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/generated/protos" ) @@ -31,6 +32,23 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin return &pgPeerConfig, nil } +func (h *FlowRequestHandler) getSFPeerConfig(ctx context.Context, peerName string) (*protos.SnowflakeConfig, error) { + var sfPeerOptions sql.RawBytes + var sfPeerConfig protos.SnowflakeConfig + err := h.pool.QueryRow(ctx, + "SELECT options FROM peers WHERE name = $1 AND type=1", peerName).Scan(&sfPeerOptions) + if err != nil { + return nil, err + } + + unmarshalErr := proto.Unmarshal(sfPeerOptions, &sfPeerConfig) + if err != nil { + return nil, unmarshalErr + } + + return &sfPeerConfig, nil +} + func (h *FlowRequestHandler) getConnForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHTunnel, *pgx.Conn, error) { pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName) if err != nil { @@ -52,6 +70,21 @@ func (h *FlowRequestHandler) getConnForPGPeer(ctx context.Context, peerName stri return tunnel, conn, nil } +func (h *FlowRequestHandler) getConnForSFPeer(ctx context.Context, peerName string) (*connsnowflake.SnowflakeConnector, error) { + sfPeerConfig, err := h.getSFPeerConfig(ctx, peerName) + if err != nil { + return nil, err + } + + sfConn, err := connsnowflake.NewSnowflakeConnector(ctx, sfPeerConfig) + if err != nil { + slog.Error("Failed to create snowflake client", slog.Any("error", err)) + return nil, err + } + + return sfConn, nil +} + func (h *FlowRequestHandler) GetSchemas( ctx context.Context, req *protos.PostgresPeerActivityInfoRequest, @@ -138,33 +171,50 @@ func (h *FlowRequestHandler) GetAllTables( ctx context.Context, req *protos.PostgresPeerActivityInfoRequest, ) (*protos.AllTablesResponse, error) { - tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName) - if err != nil { - return &protos.AllTablesResponse{Tables: nil}, err - } - defer tunnel.Close() - defer peerConn.Close(ctx) + switch req.PeerType { + case protos.DBType_SNOWFLAKE: + sfConn, err := h.getConnForSFPeer(ctx, req.PeerName) + if err != nil { + slog.Error("Failed to get snowflake client", slog.Any("error", err)) + return &protos.AllTablesResponse{Tables: nil}, err + } + defer sfConn.Close() + sfTables, err := sfConn.GetAllTables(ctx) + if err != nil { + slog.Error("Failed to get all Snowflake tables", slog.Any("error", err)) + return &protos.AllTablesResponse{Tables: nil}, err + } + return &protos.AllTablesResponse{Tables: sfTables}, nil - rows, err := peerConn.Query(ctx, "SELECT n.nspname || '.' || c.relname AS schema_table "+ - "FROM pg_class c "+ - "JOIN pg_namespace n ON c.relnamespace = n.oid "+ - "WHERE n.nspname !~ '^pg_' AND n.nspname <> 'information_schema' AND c.relkind = 'r';") - if err != nil { - return &protos.AllTablesResponse{Tables: nil}, err - } + default: + tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName) + if err != nil { + return &protos.AllTablesResponse{Tables: nil}, err + } + defer tunnel.Close() + defer peerConn.Close(ctx) - defer rows.Close() - var tables []string - for rows.Next() { - var table pgtype.Text - err := rows.Scan(&table) + rows, err := peerConn.Query(ctx, "SELECT n.nspname || '.' || c.relname AS schema_table "+ + "FROM pg_class c "+ + "JOIN pg_namespace n ON c.relnamespace = n.oid "+ + "WHERE n.nspname !~ '^pg_' AND n.nspname <> 'information_schema' AND c.relkind = 'r';") if err != nil { return &protos.AllTablesResponse{Tables: nil}, err } - tables = append(tables, table.String) + defer rows.Close() + var tables []string + for rows.Next() { + var table pgtype.Text + err := rows.Scan(&table) + if err != nil { + return &protos.AllTablesResponse{Tables: nil}, err + } + + tables = append(tables, table.String) + } + return &protos.AllTablesResponse{Tables: tables}, nil } - return &protos.AllTablesResponse{Tables: tables}, nil } func (h *FlowRequestHandler) GetColumns( diff --git a/flow/connectors/snowflake/client.go b/flow/connectors/snowflake/client.go index 3ee20362c7..59814b8512 100644 --- a/flow/connectors/snowflake/client.go +++ b/flow/connectors/snowflake/client.go @@ -84,6 +84,33 @@ func (c *SnowflakeConnector) getTableCounts(ctx context.Context, tables []string return totalRecords, nil } +func (c *SnowflakeConnector) GetAllTables(ctx context.Context) ([]string, error) { + // return all tables in database in schema.table form + // get it from information schema columns + rows, err := c.database.QueryContext(ctx, ` + SELECT table_schema, table_name + FROM information_schema.tables + WHERE table_type = 'BASE TABLE';`) + if err != nil { + return nil, fmt.Errorf("failed to get tables from Snowflake: %w", err) + } + defer rows.Close() + if rows.Err() != nil { + return nil, fmt.Errorf("failed to get tables from Snowflake: %w", rows.Err()) + } + var tables []string + for rows.Next() { + var schema, table string + err := rows.Scan(&schema, &table) + if err != nil { + return nil, fmt.Errorf("failed to scan table from Snowflake: %w", err) + } + tables = append(tables, fmt.Sprintf(`%s.%s`, schema, table)) + } + + return tables, nil +} + func SnowflakeIdentifierNormalize(identifier string) string { // https://www.alberton.info/dbms_identifiers_and_case_sensitivity.html // Snowflake follows the SQL standard, but Postgres does the opposite. diff --git a/protos/route.proto b/protos/route.proto index 316459d78f..85dc0666d9 100644 --- a/protos/route.proto +++ b/protos/route.proto @@ -139,6 +139,7 @@ message TableColumnsResponse { message PostgresPeerActivityInfoRequest { string peer_name = 1; + peerdb_peers.DBType peer_type = 2; } message SlotInfo { diff --git a/ui/app/api/peers/tables/all/route.ts b/ui/app/api/peers/tables/all/route.ts index 0281cc7067..14f2594c1f 100644 --- a/ui/app/api/peers/tables/all/route.ts +++ b/ui/app/api/peers/tables/all/route.ts @@ -4,11 +4,11 @@ import { GetFlowHttpAddressFromEnv } from '@/rpc/http'; export async function POST(request: Request) { const body = await request.json(); - const { peerName } = body; + const { peerName, peerType } = body; const flowServiceAddr = GetFlowHttpAddressFromEnv(); try { const tableList: AllTablesResponse = await fetch( - `${flowServiceAddr}/v1/peers/tables/all?peer_name=${peerName}` + `${flowServiceAddr}/v1/peers/tables/all?peer_name=${peerName}&peer_type=${peerType}` ).then((res) => { return res.json(); }); diff --git a/ui/app/mirrors/create/handlers.ts b/ui/app/mirrors/create/handlers.ts index b0f186f0e9..6e1ee21d76 100644 --- a/ui/app/mirrors/create/handlers.ts +++ b/ui/app/mirrors/create/handlers.ts @@ -337,12 +337,13 @@ export const fetchColumns = async ( return columnsRes.columns; }; -export const fetchAllTables = async (peerName: string) => { +export const fetchAllTables = async (peerName: string, peerType?: DBType) => { if (peerName?.length === 0) return []; const tablesRes: UTablesAllResponse = await fetch('/api/peers/tables/all', { method: 'POST', body: JSON.stringify({ peerName, + peerType, }), cache: 'no-store', }).then((res) => res.json()); diff --git a/ui/app/mirrors/create/helpers/qrep.ts b/ui/app/mirrors/create/helpers/qrep.ts index 4fc193cd72..ad9e1a0f44 100644 --- a/ui/app/mirrors/create/helpers/qrep.ts +++ b/ui/app/mirrors/create/helpers/qrep.ts @@ -147,7 +147,7 @@ export const snowflakeQRepSettings: MirrorSetting[] = [ ...curr, watermarkTable: (value as string) || '', })), - type: 'text', + type: 'select', tips: 'The source table of the replication and the table to which the watermark column belongs.', required: true, }, diff --git a/ui/app/mirrors/create/qrep/qrep.tsx b/ui/app/mirrors/create/qrep/qrep.tsx index 0ee4107a14..3aeafc6e3a 100644 --- a/ui/app/mirrors/create/qrep/qrep.tsx +++ b/ui/app/mirrors/create/qrep/qrep.tsx @@ -123,7 +123,10 @@ export default function QRepConfigForm({ }; useEffect(() => { - fetchAllTables(mirrorConfig.sourcePeer?.name ?? '').then((tables) => + fetchAllTables( + mirrorConfig.sourcePeer?.name ?? '', + mirrorConfig.sourcePeer?.type + ).then((tables) => setSourceTables(tables?.map((table) => ({ value: table, label: table }))) ); }, [mirrorConfig.sourcePeer]); diff --git a/ui/app/mirrors/create/qrep/snowflakeQrep.tsx b/ui/app/mirrors/create/qrep/snowflakeQrep.tsx index 4137f86019..daf812c2a6 100644 --- a/ui/app/mirrors/create/qrep/snowflakeQrep.tsx +++ b/ui/app/mirrors/create/qrep/snowflakeQrep.tsx @@ -6,10 +6,11 @@ import { RowWithSelect, RowWithSwitch, RowWithTextField } from '@/lib/Layout'; import { Switch } from '@/lib/Switch'; import { TextField } from '@/lib/TextField'; import { Tooltip } from '@/lib/Tooltip'; -import { useEffect } from 'react'; +import { useEffect, useState } from 'react'; import ReactSelect from 'react-select'; import { InfoPopover } from '../../../../components/InfoPopover'; import { MirrorSetter } from '../../types'; +import { fetchAllTables } from '../handlers'; import { MirrorSetting, blankSnowflakeQRepSetting } from '../helpers/common'; import { snowflakeQRepSettings } from '../helpers/qrep'; import QRepQuery from './query'; @@ -23,11 +24,10 @@ export default function SnowflakeQRepForm({ mirrorConfig, setter, }: SnowflakeQRepProps) { - const WriteModes = ['Overwrite'].map((value) => ({ - label: value, - value, - })); - + const [sourceTables, setSourceTables] = useState< + { value: string; label: string }[] + >([]); + const [loading, setLoading] = useState(false); const handleChange = (val: string | boolean, setting: MirrorSetting) => { let stateVal: string | boolean | QRepWriteType | string[] = val; if (setting.label.includes('Write Type')) { @@ -46,6 +46,14 @@ export default function SnowflakeQRepForm({ setting.stateHandler(stateVal, setter); }; + const handleSourceChange = (val: string, setting: MirrorSetting) => { + setter((curr) => ({ + ...curr, + destinationTableIdentifier: val.toLowerCase(), + })); + handleChange(val, setting); + }; + const paramDisplayCondition = (setting: MirrorSetting) => { const label = setting.label.toLowerCase(); if ( @@ -57,6 +65,17 @@ export default function SnowflakeQRepForm({ return true; }; + useEffect(() => { + setLoading(true); + fetchAllTables( + mirrorConfig.sourcePeer?.name ?? '', + mirrorConfig.sourcePeer?.type + ).then((tables) => { + setSourceTables(tables?.map((table) => ({ value: table, label: table }))); + setLoading(false); + }); + }, [mirrorConfig.sourcePeer]); + useEffect(() => { // set defaults setter((curr) => ({ ...curr, ...blankSnowflakeQRepSetting })); @@ -126,11 +145,22 @@ export default function SnowflakeQRepForm({ }} >
- + {setting.label.includes('Write') ? ( + + ) : ( + + val && handleSourceChange(val.value, setting) + } + isLoading={loading} + options={sourceTables} + /> + )}
{setting.tips && ( @@ -171,7 +201,11 @@ export default function SnowflakeQRepForm({ ) => handleChange(e.target.value, setting) }