Skip to content

Commit

Permalink
Merge pull request #2 from xeger/insertPositional
Browse files Browse the repository at this point in the history
Fix bugs
  • Loading branch information
xeger authored Sep 29, 2023
2 parents 8139192 + 417c2a6 commit ea05dc0
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 61 deletions.
2 changes: 1 addition & 1 deletion cmd/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func init() {

func extract(cmd *cobra.Command, args []string) {
if len(args) != 1 {
ui.Fatalf("Must pass exactly one directory for model storage")
ui.Fatalf("Must pass exactly one field name to extract")
ui.Exit('-')
}

Expand Down
23 changes: 10 additions & 13 deletions format/mysql/extract_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ type extractVisitor struct {

// ExtractStatement pulls interesting field values from INSERT statements.
func (v *extractVisitor) ExtractStatement(stmt ast.StmtNode) []string {
switch stmt.(type) {
switch typed := stmt.(type) {
case *ast.InsertStmt:
v.insert = &insertState{}
v.insert = newInsertState(typed)
v.values = []string{}
stmt.Accept(v)
v.insert = nil
Expand All @@ -25,31 +25,28 @@ func (v *extractVisitor) ExtractStatement(stmt ast.StmtNode) []string {
}

func (v *extractVisitor) Enter(in ast.Node) (ast.Node, bool) {
switch st := in.(type) {
switch typed := in.(type) {
case *ast.TableName:
if v.insert != nil {
v.insert.tableName = st.Name.L
v.insert.tableName = typed.Name.L
}
case *ast.ColumnName:
// insert column names present in SQL source; accumulate them
if v.insert != nil {
v.insert.columnNames = append(v.insert.columnNames, st.Name.L)
v.insert.columnNames = append(v.insert.columnNames, typed.Name.L)
}
case *test_driver.ValueExpr:
if v.insert != nil {
// column names omitted from SQL source; infer from table schema
if v.insert.valueIndex == 0 && len(v.insert.columnNames) == 0 {
v.insert.columnNames = v.ctx.TableColumns[v.insert.tableName]
}
v.insert.ObserveContext(v.ctx)
defer func() {
v.insert.valueIndex++
v.insert.Advance()
}()
switch st.Kind() {
switch typed.Kind() {
case test_driver.KindString:
if v.MatchFieldName(v.insert.Names()) {
v.values = append(v.values, st.Datum.GetString())
v.values = append(v.values, typed.Datum.GetString())
}
return st, true
return typed, true
}
}
}
Expand Down
23 changes: 10 additions & 13 deletions format/mysql/learn_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,42 @@ type learnVisitor struct {

// LearnStatement trains models based on values in a SQL insert AST.
func (v *learnVisitor) LearnStatement(stmt ast.StmtNode) {
switch stmt.(type) {
switch typed := stmt.(type) {
case *ast.InsertStmt:
v.insert = &insertState{}
v.insert = newInsertState(typed)
stmt.Accept(v)
v.insert = nil
}
}

func (v *learnVisitor) Enter(in ast.Node) (ast.Node, bool) {
switch st := in.(type) {
switch typed := in.(type) {
case *ast.TableName:
if v.insert != nil {
v.insert.tableName = st.Name.L
v.insert.tableName = typed.Name.L
}
case *ast.ColumnName:
// insert column names present in SQL source; accumulate them
if v.insert != nil {
v.insert.columnNames = append(v.insert.columnNames, st.Name.L)
v.insert.columnNames = append(v.insert.columnNames, typed.Name.L)
}
case *test_driver.ValueExpr:
if v.insert != nil {
// column names omitted from SQL source; infer from table schema
if v.insert.valueIndex == 0 && len(v.insert.columnNames) == 0 {
v.insert.columnNames = v.ctx.TableColumns[v.insert.tableName]
}
v.insert.ObserveContext(v.ctx)
defer func() {
v.insert.valueIndex++
v.insert.Advance()
}()
switch st.Kind() {
switch typed.Kind() {
case test_driver.KindString:
disposition, _ := v.policy.MatchFieldName(v.insert.Names())
switch disposition.Action() {
case "generate":
model := v.models[disposition.Parameter()]
if model != nil {
model.Train(st.Datum.GetString())
model.Train(typed.Datum.GetString())
}
}
return st, true
return typed, true
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions format/mysql/schema_info_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ func (v *schemaInfoVisitor) ScanStatement(stmt ast.StmtNode) {
}

func (v *schemaInfoVisitor) Enter(in ast.Node) (ast.Node, bool) {
switch st := in.(type) {
switch typed := in.(type) {
case *ast.TableName:
v.tableName = st.Name.L
v.tableName = typed.Name.L
if v.info.TableColumns[v.tableName] == nil {
v.info.TableColumns[v.tableName] = make([]string, 0, 32)
}
case *ast.ColumnDef:
v.columnDef = true
case *ast.ColumnName:
if v.columnDef {
v.info.TableColumns[v.tableName] = append(v.info.TableColumns[v.tableName], st.Name.L)
v.info.TableColumns[v.tableName] = append(v.info.TableColumns[v.tableName], typed.Name.L)
}
}
return in, false
Expand Down
20 changes: 17 additions & 3 deletions format/mysql/scrub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,37 @@ package mysql_test
import (
"bufio"
"bytes"
"io/ioutil"
"os"
"strings"
"testing"

"github.com/xeger/pipeclean/format/mysql"
"github.com/xeger/pipeclean/scrubbing"
)

var nullPolicy = &scrubbing.Policy{}

func read(t *testing.T, name string) string {
data, err := ioutil.ReadFile("testdata/" + name)
data, err := os.ReadFile("testdata/" + name)
if err != nil {
t.Fatalf("Failed to read test file %s: %s", name, err)
}
return string(data)
}

func scrub(ctx *mysql.Context, input string) string {
return scrubPolicy(ctx, input, scrubbing.DefaultPolicy())
}

func scrubPolicy(ctx *mysql.Context, input string, policy *scrubbing.Policy) string {
reader := bufio.NewReader(bytes.NewBufferString(input))
in := make(chan string)

out := make(chan string)
output := bytes.NewBuffer(make([]byte, 0, len(input)))
writer := bufio.NewWriter(output)

scrubber := scrubbing.NewScrubber("", false, scrubbing.DefaultPolicy(), nil)
scrubber := scrubbing.NewScrubber("", false, policy, nil)
go mysql.ScrubChan(ctx, scrubber, in, out)

for {
Expand Down Expand Up @@ -96,3 +102,11 @@ func TestInsertPositional(t *testing.T) {
t.Errorf("UNLOCK TABLES statement is missing")
}
}

func TestInsertPositionalNoScan(t *testing.T) {
input := read(t, "insert-positional.sql")

ctx := mysql.NewContext()
// output may not be useful, but it shouldn't crash if there are no column names to work with!
scrub(ctx, input)
}
21 changes: 9 additions & 12 deletions format/mysql/scrub_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ type scrubVisitor struct {
// May modify the AST in-place (and return it), or may return a derived AST.
// Returns nil if the entire statement should be omitted from output.
func (v *scrubVisitor) ScrubStatement(stmt ast.StmtNode) (ast.StmtNode, bool) {
switch stmt.(type) {
switch typed := stmt.(type) {
case *ast.InsertStmt:
if doInserts {
v.insert = &insertState{}
v.insert = newInsertState(typed)
stmt.Accept(v)
v.insert = nil
return stmt, true
Expand All @@ -35,29 +35,26 @@ func (v *scrubVisitor) ScrubStatement(stmt ast.StmtNode) (ast.StmtNode, bool) {
}

func (v *scrubVisitor) Enter(in ast.Node) (ast.Node, bool) {
switch st := in.(type) {
switch typed := in.(type) {
case *ast.TableName:
if v.insert != nil {
v.insert.tableName = st.Name.L
v.insert.tableName = typed.Name.L
}
case *ast.ColumnName:
// insert column names present in SQL source; accumulate them
if v.insert != nil {
v.insert.columnNames = append(v.insert.columnNames, st.Name.L)
v.insert.columnNames = append(v.insert.columnNames, typed.Name.L)
}
case *test_driver.ValueExpr:
if v.insert != nil {
// column names omitted from SQL source; infer from table schema
if v.insert.valueIndex == 0 && len(v.insert.columnNames) == 0 {
v.insert.columnNames = v.ctx.TableColumns[v.insert.tableName]
}
v.insert.ObserveContext(v.ctx)
defer func() {
v.insert.valueIndex++
v.insert.Advance()
}()
switch st.Kind() {
switch typed.Kind() {
case test_driver.KindString:
datum := test_driver.Datum{}
s := st.Datum.GetString()
s := typed.Datum.GetString()
names := v.insert.Names()
if v.scrubber.EraseString(s, names) {
datum.SetNull()
Expand Down
42 changes: 38 additions & 4 deletions format/mysql/state.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,50 @@
package mysql

import "fmt"
import (
"fmt"

"github.com/pingcap/tidb/parser/ast"
)

type insertState struct {
// Name of the table being inserted into.
tableName string
// List of column names (explicitly specified in current statement, or inferred from table schema).
columnNames []string
// Value tuple sizes of the current statement.
rowLength int
// Number of ValueExpr seen so far across all rows of current statement.
valueIndex int
}

// Names returns a list of column names to which the Next ValueExpr will apply.
func newInsertState(stmt *ast.InsertStmt) *insertState {
rowLength := 0
for _, list := range stmt.Lists {
if rowLength == 0 {
rowLength = len(list)
} else if len(list) != rowLength {
// TODO: handle this case by storing an array of row-tuple lengths & iterating through it
panic(fmt.Sprintf("inconsistent INSERT row lengths: %d prior vs %d next", rowLength, len(list)))
}
}
return &insertState{rowLength: rowLength}
}

// Advance increments the column-value index so that Names() remains accurate.
func (is *insertState) Advance() {
is.valueIndex += 1
}

// Names returns a list of column names to which the next ValueExpr will apply.
// The list contains 0-3 elements depending on the completeness of the schema
// information provided in context.
func (is insertState) Names() []string {
func (is *insertState) Names() []string {
colIdx := is.valueIndex
if is.rowLength > 0 {
colIdx = colIdx % is.rowLength
}
names := make([]string, 0, 3)
if len(is.tableName) > 0 {
colIdx := is.valueIndex % len(is.columnNames)
if len(is.columnNames) > 0 {
colName := is.columnNames[colIdx]
names = append(names, colName)
Expand All @@ -28,3 +55,10 @@ func (is insertState) Names() []string {

return names
}

// If column names were omitted from the SQL INSERT statement, infer them from the previously-scanned table schema.
func (is *insertState) ObserveContext(ctx *Context) {
if is.valueIndex == 0 && len(is.columnNames) == 0 {
is.columnNames = ctx.TableColumns[is.tableName]
}
}
39 changes: 27 additions & 12 deletions scrubbing/scrubber.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,21 @@ import (
"gopkg.in/yaml.v3"
)

// ReShortExtension identifies filename-like extensions at the end of strings.
var reShortExtension = regexp.MustCompile(`[.][a-z]{2,5}$`)

func isJsonData(s string) bool {
if len(s) >= 2 {
f, l := s[0], s[len(s)-1]
return (f == '{' && l == '}') || (f == '[' && l == ']')
}
return false
}

func isYamlData(s string) bool {
return strings.Index(s, "---\n") == 0
}

type Scrubber struct {
maskAll bool
models map[string]nlp.Model
Expand Down Expand Up @@ -146,23 +159,25 @@ func (sc *Scrubber) ScrubString(s string, names []string) string {
if !sc.shallow {
var data any

if err := json.Unmarshal([]byte(s), &data); err == nil {
scrubbed, err := json.Marshal(sc.ScrubData(data, nil))
if err != nil {
ui.Fatal(err)
}
return string(scrubbed)
}

if err := yaml.Unmarshal([]byte(s), &data); err == nil {
switch v := data.(type) {
case []any, map[string]any:
scrubbed, err := yaml.Marshal(sc.ScrubData(v, nil))
if isJsonData(s) {
if err := json.Unmarshal([]byte(s), &data); err == nil {
scrubbed, err := json.Marshal(sc.ScrubData(data, nil))
if err != nil {
ui.Fatal(err)
}
return string(scrubbed)
}
} else if isYamlData(s) {
if err := yaml.Unmarshal([]byte(s), &data); err == nil {
switch v := data.(type) {
case []any, map[string]any:
scrubbed, err := yaml.Marshal(sc.ScrubData(v, nil))
if err != nil {
ui.Fatal(err)
}
return string(scrubbed)
}
}
}

// Empty serialized Ruby YAML hashes.
Expand Down
Loading

0 comments on commit ea05dc0

Please sign in to comment.