package sqltabler import ( "errors" "io" "strings" sqlparser "github.com/rqlite/sql" ) // Modify parses each statement in stmt, adds a prefix and suffix to every // table, view, trigger, and index name, and then returns the result as a string. func Modify(stmt, prefix, suffix string) (string, error) { parser := sqlparser.NewParser(strings.NewReader(stmt)) sb := strings.Builder{} for { s, err := parser.ParseStatement() if errors.Is(err, io.EOF) { break } else if err != nil { return "", err } modify(s, prefix, suffix) sb.WriteString(s.String()) sb.WriteByte(';') } return sb.String(), nil } // ModifyStmt clones stmt, adds a prefix and suffix to every table, view, trigger, and // index name, and returns the modified statement. func ModifyStmt(stmt sqlparser.Statement, prefix, suffix string) sqlparser.Statement { stmt = sqlparser.CloneStatement(stmt) modify(stmt, prefix, suffix) return stmt } // modify recursively changes all the table, view, trigger, and index names in a single statement. func modify(stmt any, prefix, suffix string) { switch stmt := stmt.(type) { case *sqlparser.SelectStatement: modifySource(stmt.Source, prefix, suffix) modify(stmt.WhereExpr, prefix, suffix) case *sqlparser.InsertStatement: stmt.Table.Name = prefix + stmt.Table.Name + suffix if stmt.Select != nil { modify(stmt.Select, prefix, suffix) } case *sqlparser.UpdateStatement: stmt.Table.Name.Name = prefix + stmt.Table.Name.Name + suffix for _, assignment := range stmt.Assignments { modify(assignment, prefix, suffix) } case *sqlparser.CreateTableStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix if stmt.Select != nil { modify(stmt.Select, prefix, suffix) } for _, col := range stmt.Columns { modify(col, prefix, suffix) } for _, constraint := range stmt.Constraints { modify(constraint, prefix, suffix) } case *sqlparser.CreateViewStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix if stmt.Select != nil { modify(stmt.Select, prefix, suffix) } case *sqlparser.AlterTableStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix if stmt.NewName != nil { stmt.NewName.Name = prefix + stmt.NewName.Name + suffix } if stmt.ColumnDef != nil { modify(stmt.ColumnDef, prefix, suffix) } case *sqlparser.Call: for _, arg := range stmt.Args { modify(arg, prefix, suffix) } case *sqlparser.FilterClause: modify(stmt.X, prefix, suffix) case *sqlparser.DeleteStatement: stmt.Table.Name.Name = prefix + stmt.Table.Name.Name + suffix if stmt.WhereExpr != nil { modify(stmt.WhereExpr, prefix, suffix) } case *sqlparser.AnalyzeStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix case *sqlparser.ExplainStatement: modify(stmt.Stmt, prefix, suffix) case *sqlparser.CreateIndexStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix stmt.Table.Name = prefix + stmt.Table.Name + suffix case *sqlparser.CreateTriggerStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix stmt.Table.Name = prefix + stmt.Table.Name + suffix if stmt.WhenExpr != nil { modify(stmt.WhenExpr, prefix, suffix) } for _, istmt := range stmt.Body { modify(istmt, prefix, suffix) } case *sqlparser.CTE: stmt.TableName.Name = prefix + stmt.TableName.Name + suffix if stmt.Select != nil { modify(stmt.Select, prefix, suffix) } case *sqlparser.DropTableStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix case *sqlparser.DropViewStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix case *sqlparser.DropIndexStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix case *sqlparser.DropTriggerStatement: stmt.Name.Name = prefix + stmt.Name.Name + suffix case *sqlparser.ForeignKeyConstraint: stmt.ForeignTable.Name = prefix + stmt.ForeignTable.Name + suffix case *sqlparser.OnConstraint: modify(stmt.X, prefix, suffix) case *sqlparser.ExprList: for _, expr := range stmt.Exprs { modify(expr, prefix, suffix) } case *sqlparser.UnaryExpr: modify(stmt.X, prefix, suffix) case *sqlparser.BinaryExpr: modify(stmt.X, prefix, suffix) modify(stmt.Y, prefix, suffix) case *sqlparser.ParenExpr: modify(stmt.X, prefix, suffix) case *sqlparser.CastExpr: modify(stmt.X, prefix, suffix) case *sqlparser.OrderingTerm: modify(stmt.X, prefix, suffix) case *sqlparser.Assignment: modify(stmt.Expr, prefix, suffix) case *sqlparser.ColumnDefinition: for _, constraint := range stmt.Constraints { modify(constraint, prefix, suffix) } case *sqlparser.QualifiedRef: if stmt.Table != nil { stmt.Table.Name = prefix + stmt.Table.Name + suffix } case *sqlparser.CaseExpr: modify(stmt.ElseExpr, prefix, suffix) for _, block := range stmt.Blocks { modify(block.Condition, prefix, suffix) modify(block.Body, prefix, suffix) } } } func modifySource(source sqlparser.Source, prefix, suffix string) { switch source := source.(type) { case *sqlparser.QualifiedTableName: source.Name.Name = prefix + source.Name.Name + suffix case *sqlparser.JoinClause: modifySource(source.X, prefix, suffix) modifySource(source.Y, prefix, suffix) modify(source.Constraint, prefix, suffix) } }