Skip to content

Commit

Permalink
Merge pull request #405 from bruin-data/bq-patch
Browse files Browse the repository at this point in the history
Bq patch
  • Loading branch information
terzioglub authored Jan 16, 2025
2 parents 546f6d2 + 6f0e927 commit 212e944
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 188 deletions.
16 changes: 10 additions & 6 deletions pkg/bigquery/checks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func (m *mockQuerierWithResult) SelectWithSchema(ctx context.Context, q *query.Q
return result, args.Error(1)
}

func (m *mockQuerierWithResult) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error {
args := m.Called(ctx, tableName, asset)
return args.Error(0)
// IsPartitioningOrClusteringMismatch is the mock counterpart of the TableManager
// method: it hands (ctx, meta, asset) to m.Called and returns the first configured
// return value as a bool.
func (m *mockQuerierWithResult) IsPartitioningOrClusteringMismatch(ctx context.Context, meta *bigquery.TableMetadata, asset *pipeline.Asset) bool {
	return m.Called(ctx, meta, asset).Bool(0)
}

func (m *mockQuerierWithResult) IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool {
Expand All @@ -75,15 +75,19 @@ func (m *mockQuerierWithResult) IsSameClustering(meta *bigquery.TableMetadata, a
return args.Bool(0)
}

func (m *mockQuerierWithResult) DeleteTableIfMaterializationTypeMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error {
args := m.Called(ctx, tableName, asset)
return args.Error(0)
// IsMaterializationTypeMismatch is the mock counterpart of the TableManager method:
// it hands (ctx, meta, asset) to m.Called and returns the first configured return
// value as a bool.
func (m *mockQuerierWithResult) IsMaterializationTypeMismatch(ctx context.Context, meta *bigquery.TableMetadata, asset *pipeline.Asset) bool {
	return m.Called(ctx, meta, asset).Bool(0)
}

// CreateDataSetIfNotExist is the mock counterpart of the TableManager method: it
// hands (asset, ctx) to m.Called, in the same order as the signature, and returns
// the first configured return value as an error.
func (m *mockQuerierWithResult) CreateDataSetIfNotExist(asset *pipeline.Asset, ctx context.Context) error {
	return m.Called(asset, ctx).Error(0)
}
// DropTableOnMismatch is the mock counterpart of the TableManager method.
//
// The arguments are handed to m.Called in the same order as the method signature
// (ctx, tableName, asset), matching the other mocks on this type, so expectations
// written in signature order line up positionally with the recorded call.
// It returns the first configured return value as an error.
func (m *mockQuerierWithResult) DropTableOnMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error {
	args := m.Called(ctx, tableName, asset)
	return args.Error(0)
}

type mockConnectionFetcher struct {
mock.Mock
Expand Down
138 changes: 58 additions & 80 deletions pkg/bigquery/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ type MetadataUpdater interface {
}

// TableManager groups the dataset- and table-level maintenance operations that the
// BigQuery client exposes.
//
// NOTE(review): this listing shows both the DeleteTableIf* methods and the newer
// Is*Mismatch / DropTableOnMismatch methods side by side; confirm against the
// repository which of them are still part of the interface.
type TableManager interface {
// DeleteTableIfPartitioningOrClusteringMismatch deletes the table named tableName
// when its partitioning or clustering differs from the asset's materialization
// settings. A table that does not exist (HTTP 404) is not treated as an error.
DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error
// IsPartitioningOrClusteringMismatch reports whether the partitioning or clustering
// described by meta differs from the asset's materialization settings.
IsPartitioningOrClusteringMismatch(ctx context.Context, meta *bigquery.TableMetadata, asset *pipeline.Asset) bool
// CreateDataSetIfNotExist makes sure the dataset referenced by asset.Name exists,
// creating it when it is missing. Note the parameter order: asset first, ctx second.
CreateDataSetIfNotExist(asset *pipeline.Asset, ctx context.Context) error
// DeleteTableIfMaterializationTypeMismatch presumably deletes the table when its
// type differs from the asset's materialization type — TODO confirm, the
// implementation is not fully visible here.
DeleteTableIfMaterializationTypeMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error
// IsMaterializationTypeMismatch reports whether meta.Type differs (ignoring case)
// from the asset's materialization type; it is always false when the asset has
// materialization type "none".
IsMaterializationTypeMismatch(ctx context.Context, meta *bigquery.TableMetadata, asset *pipeline.Asset) bool
// DropTableOnMismatch fetches the metadata of tableName and deletes the table when
// either the materialization type or the partitioning/clustering does not match
// the asset.
DropTableOnMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error
}

type DB interface {
Expand All @@ -47,10 +48,14 @@ type DB interface {
TableManager
}

// datasetNameCache remembers, per "project.dataset" key, that a dataset has already
// been created. It is package-level state, so every Client in the process shares it.
var datasetNameCache sync.Map

// datasetLocks holds one *sync.Mutex per "project.dataset" key, used to serialize
// the existence check and creation of a given dataset.
var datasetLocks sync.Map

// Client is the BigQuery-backed implementation used by this package.
//
// Each field is declared exactly once; a struct cannot declare the same field name
// twice.
type Client struct {
	// client is the underlying BigQuery API client used for dataset and table calls.
	client *bigquery.Client
	// config holds the connection settings; its ProjectID is the default project
	// for two-part (dataset.table) asset names.
	config *Config
	// datasetNameCache is a per-client cache of datasets known to exist.
	// NOTE(review): a package-level datasetNameCache is also declared in this file;
	// confirm whether this field is still needed.
	datasetNameCache *sync.Map
}

func NewDB(c *Config) (*Client, error) {
Expand Down Expand Up @@ -83,9 +88,8 @@ func NewDB(c *Config) (*Client, error) {
}

return &Client{
client: client,
config: c,
datasetNameCache: &sync.Map{},
client: client,
config: c,
}, nil
}

Expand Down Expand Up @@ -230,11 +234,15 @@ func (d *Client) UpdateTableMetadataIfNotExist(ctx context.Context, asset *pipel
if err != nil {
return err
}

meta, err := tableRef.Metadata(ctx)
if err != nil {
var apiErr *googleapi.Error
if errors.As(err, &apiErr) && apiErr.Code == 404 {
return nil
}
return err
}

schema := meta.Schema
colsChanged := false
for _, field := range schema {
Expand Down Expand Up @@ -296,42 +304,19 @@ func (d *Client) Ping(ctx context.Context) error {
return nil // Return nil if the query runs successfully
}

func (d *Client) DeleteTableIfPartitioningOrClusteringMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error {
tableRef, err := d.getTableRef(tableName)
if err != nil {
return err
}
// Fetch table metadata
meta, err := tableRef.Metadata(ctx)
if err != nil {
var apiErr *googleapi.Error
if errors.As(err, &apiErr) && apiErr.Code == 404 {
return nil
}
return fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err)
}
// IsPartitioningOrClusteringMismatch reports whether the partitioning or clustering
// of an existing table (meta) differs from what the asset's materialization asks for.
//
// The comparison only runs when the table is time- or range-partitioned, or when the
// asset sets PartitionBy or ClusterBy; otherwise the result is false. The function is
// a pure predicate: it performs no API calls and never deletes anything, and ctx is
// accepted only to satisfy the TableManager interface.
//
// NOTE(review): a table that is clustered but not partitioned, paired with an asset
// that requests neither, skips the comparison and reports no mismatch — confirm this
// is intended.
func (d *Client) IsPartitioningOrClusteringMismatch(ctx context.Context, meta *bigquery.TableMetadata, asset *pipeline.Asset) bool {
	configured := meta.TimePartitioning != nil ||
		meta.RangePartitioning != nil ||
		asset.Materialization.PartitionBy != "" ||
		len(asset.Materialization.ClusterBy) > 0
	if !configured {
		return false
	}

	// Clustering is only compared when partitioning already matches.
	return !IsSamePartitioning(meta, asset) || !IsSameClustering(meta, asset)
}

func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool {
if asset.Materialization.PartitionBy != "" &&
meta.TimePartitioning == nil &&
meta.RangePartitioning == nil {
fmt.Printf(
"Mismatch detected: Your table has no partitioning, but you are attempting to partition by '%s'.\n",
asset.Materialization.PartitionBy,
)
return false
}

Expand All @@ -341,22 +326,11 @@ func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) boo

if meta.TimePartitioning != nil {
if meta.TimePartitioning.Field != asset.Materialization.PartitionBy {
fmt.Printf(
"Mismatch detected: Your table has a time partitioning strategy with the field '%s', "+
"but you are attempting to use the field '%s'\n",
meta.TimePartitioning.Field,
asset.Materialization.PartitionBy,
)
return false
}
}
if meta.RangePartitioning != nil {
if meta.RangePartitioning.Field != asset.Materialization.PartitionBy {
fmt.Printf(
"Mismatch detected: Your table has a range partitioning strategy with the field '%s',"+
"but you are attempting to use the field '%s'.\n", meta.RangePartitioning.Field,
asset.Materialization.PartitionBy,
)
return false
}
}
Expand All @@ -366,10 +340,6 @@ func IsSamePartitioning(meta *bigquery.TableMetadata, asset *pipeline.Asset) boo
func IsSameClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool {
if len(asset.Materialization.ClusterBy) > 0 &&
(meta.Clustering == nil || len(meta.Clustering.Fields) == 0) {
fmt.Printf(
"Mismatch detected: Your table has no clustering, but you are attempting to cluster by %v.\n",
asset.Materialization.ClusterBy,
)
return false
}
if meta.Clustering == nil {
Expand All @@ -380,20 +350,11 @@ func IsSameClustering(meta *bigquery.TableMetadata, asset *pipeline.Asset) bool
userFields := asset.Materialization.ClusterBy

if len(bigQueryFields) != len(userFields) {
fmt.Printf(
"Mismatch detected: Your table has the clustering fields (%v), but you are trying to use the fields (%v).\n",
bigQueryFields, userFields,
)
return false
}

for i := range bigQueryFields {
if bigQueryFields[i] != userFields[i] {
fmt.Printf(
"Mismatch detected: Your table is clustered by '%s' at position %d, "+
"but you are trying to cluster by '%s'.\n",
bigQueryFields[i], i+1, userFields[i],
)
return false
}
}
Expand All @@ -405,44 +366,62 @@ func (d *Client) CreateDataSetIfNotExist(asset *pipeline.Asset, ctx context.Cont
tableName := asset.Name
tableComponents := strings.Split(tableName, ".")
var datasetName string
var projectID string

switch len(tableComponents) {
case 2:
projectID = d.config.ProjectID
datasetName = tableComponents[0]
case 3:
datasetName = tableComponents[1]
projectID = tableComponents[0]
default:
return nil
}

if _, exists := d.datasetNameCache.Load(datasetName); exists {
cacheKey := fmt.Sprintf("%s.%s", projectID, datasetName)

if _, exists := datasetNameCache.Load(cacheKey); exists {
return nil
}
datasets := d.client.Datasets(ctx)
for {
dataset, err := datasets.Next()
if errors.Is(err, iterator.Done) {
break
}
if err != nil {
return err
}
if datasetName == dataset.DatasetID {
d.datasetNameCache.Store(datasetName, true)
return nil
}

lock, _ := datasetLocks.LoadOrStore(cacheKey, &sync.Mutex{})
mutex := lock.(*sync.Mutex)

mutex.Lock()
defer mutex.Unlock()

if _, exists := datasetNameCache.Load(cacheKey); exists {
return nil
}
if err := d.client.Dataset(datasetName).Create(ctx, &bigquery.DatasetMetadata{}); err != nil {
return err

dataset := d.client.DatasetInProject(projectID, datasetName)
_, err := dataset.Metadata(ctx)
if err != nil {
var apiErr *googleapi.Error
if errors.As(err, &apiErr) && apiErr.Code == 404 {
if err := dataset.Create(ctx, &bigquery.DatasetMetadata{}); err != nil {
return fmt.Errorf("failed to create dataset '%s': %w", cacheKey, err)
}
datasetNameCache.Store(cacheKey, true)
} else {
return fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err)
}
}
d.datasetNameCache.Store(datasetName, true)

return nil
}

func (d *Client) DeleteTableIfMaterializationTypeMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error {
// IsMaterializationTypeMismatch reports whether the type of an existing table
// (meta.Type) differs from the asset's materialization type.
//
// Assets with materialization type "none" never count as a mismatch. The comparison
// ignores case. The function is a pure predicate; ctx is accepted only to satisfy
// the TableManager interface.
func (d *Client) IsMaterializationTypeMismatch(ctx context.Context, meta *bigquery.TableMetadata, asset *pipeline.Asset) bool {
	if asset.Materialization.Type == pipeline.MaterializationTypeNone {
		return false
	}

	return !strings.EqualFold(string(meta.Type), string(asset.Materialization.Type))
}

func (d *Client) DropTableOnMismatch(ctx context.Context, tableName string, asset *pipeline.Asset) error {
tableRef, err := d.getTableRef(tableName)
if err != nil {
return err
Expand All @@ -455,8 +434,7 @@ func (d *Client) DeleteTableIfMaterializationTypeMismatch(ctx context.Context, t
}
return fmt.Errorf("failed to fetch metadata for table '%s': %w", tableName, err)
}
tableType := meta.Type
if !strings.EqualFold(string(tableType), string(asset.Materialization.Type)) {
if d.IsMaterializationTypeMismatch(ctx, meta, asset) || d.IsPartitioningOrClusteringMismatch(ctx, meta, asset) {
if err := tableRef.Delete(ctx); err != nil {
return fmt.Errorf("failed to delete table '%s': %w", tableName, err)
}
Expand Down
Loading

0 comments on commit 212e944

Please sign in to comment.