Skip to content

Commit 72e7bc2

Browse files
committed
feat(schema): optimize queries: drop index, alter table add constraint
1 parent f3eda22 commit 72e7bc2

File tree

2 files changed

+194
-169
lines changed

2 files changed

+194
-169
lines changed

engine/cmd/pgquery/main.go

Lines changed: 0 additions & 169 deletions
This file was deleted.

engine/cmd/schema-diff/main.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
7+
pg_query "github.com/pganalyze/pg_query_go/v2"
8+
)
9+
10+
const idxExample = `
11+
CREATE UNIQUE INDEX title_idx ON films (title);
12+
13+
DROP INDEX title_idx;
14+
15+
ALTER TABLE distributors
16+
ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5);
17+
18+
ALTER TABLE pgbench_accounts
19+
ADD COLUMN test integer NOT NULL DEFAULT 0;
20+
`
21+
22+
/*
23+
Optimized queries:
24+
25+
CREATE UNIQUE INDEX CONCURRENTLY title_idx ON films USING btree (title);
26+
27+
DROP INDEX CONCURRENTLY title_idx;
28+
29+
ALTER TABLE distributors ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5) NOT VALID;
30+
ALTER TABLE distributors VALIDATE CONSTRAINT zipchk;
31+
32+
ALTER TABLE pgbench_accounts ADD COLUMN test int;
33+
ALTER TABLE pgbench_accounts ALTER COLUMN test SET DEFAULT 0;
34+
*/
35+
36+
func main() {
37+
scanTree, err := pg_query.ParseToJSON(idxExample)
38+
if err != nil {
39+
log.Fatal(err)
40+
}
41+
42+
fmt.Printf("JSON: %s\n", scanTree)
43+
44+
idxTree, err := pg_query.Parse(idxExample)
45+
if err != nil {
46+
log.Fatal(err)
47+
}
48+
49+
fmt.Printf("Original query:\n%v\n\n", idxExample)
50+
fmt.Printf("Parse Tree:\n%#v\n\n", idxTree)
51+
52+
stmts := idxTree.GetStmts()
53+
nodes := processStmts(stmts)
54+
idxTree.Stmts = nodes
55+
56+
fmt.Printf("Parse Tree after processing:\n%#v\n\n", idxTree.GetStmts())
57+
58+
resIdxStr, err := pg_query.Deparse(idxTree)
59+
if err != nil {
60+
log.Fatal(err)
61+
}
62+
63+
fmt.Printf("Optimized queries:\n%v\n", resIdxStr)
64+
}
65+
66+
func processStmts(stmts []*pg_query.RawStmt) []*pg_query.RawStmt {
67+
rawStmts := []*pg_query.RawStmt{}
68+
69+
for _, stmt := range stmts {
70+
for _, node := range detectNodeType(stmt.Stmt) {
71+
rawStmt := &pg_query.RawStmt{
72+
Stmt: node,
73+
}
74+
75+
rawStmts = append(rawStmts, rawStmt)
76+
}
77+
}
78+
79+
return rawStmts
80+
}
81+
82+
func detectNodeType(node *pg_query.Node) []*pg_query.Node {
83+
switch stmt := node.Node.(type) {
84+
case *pg_query.Node_IndexStmt:
85+
IndexStmt(stmt)
86+
87+
case *pg_query.Node_DropStmt:
88+
DropStmt(stmt)
89+
90+
case *pg_query.Node_AlterTableStmt:
91+
fmt.Println("Alter Type")
92+
return AlterStmt(node)
93+
94+
case *pg_query.Node_SelectStmt:
95+
fmt.Println("Select Type")
96+
}
97+
98+
return []*pg_query.Node{node}
99+
}
100+
101+
// IndexStmt processes index statement.
102+
func IndexStmt(stmt *pg_query.Node_IndexStmt) {
103+
stmt.IndexStmt.Concurrent = true
104+
}
105+
106+
// DropStmt processes drop statement.
107+
func DropStmt(stmt *pg_query.Node_DropStmt) {
108+
switch stmt.DropStmt.RemoveType {
109+
case pg_query.ObjectType_OBJECT_INDEX:
110+
stmt.DropStmt.Concurrent = true
111+
default:
112+
}
113+
}
114+
115+
// AlterStmt processes alter statement.
116+
func AlterStmt(node *pg_query.Node) []*pg_query.Node {
117+
alterTableStmt := node.GetAlterTableStmt()
118+
if alterTableStmt == nil {
119+
return []*pg_query.Node{node}
120+
}
121+
122+
var alterStmts []*pg_query.Node
123+
124+
initialCommands := alterTableStmt.GetCmds()
125+
126+
for _, cmd := range initialCommands {
127+
switch v := cmd.Node.(type) {
128+
case *pg_query.Node_AlterTableCmd:
129+
fmt.Printf("%#v\n", v)
130+
fmt.Printf("%#v\n", v.AlterTableCmd.Def.Node)
131+
fmt.Println(v.AlterTableCmd.Subtype.Enum())
132+
133+
switch v.AlterTableCmd.Subtype {
134+
case pg_query.AlterTableType_AT_AddColumn:
135+
def := v.AlterTableCmd.Def.GetColumnDef()
136+
137+
constraints := def.GetConstraints()
138+
constraintsMap := make(map[pg_query.ConstrType]int)
139+
140+
for i, constr := range constraints {
141+
constraintsMap[constr.GetConstraint().Contype] = i
142+
}
143+
144+
if index, ok := constraintsMap[pg_query.ConstrType_CONSTR_DEFAULT]; ok {
145+
def.Constraints = make([]*pg_query.Node, 0)
146+
147+
alterStmts = append(alterStmts, node)
148+
149+
defaultDefinitionTemp := fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %v;`,
150+
alterTableStmt.GetRelation().GetRelname(), def.Colname,
151+
constraints[index].GetConstraint().GetRawExpr().GetAConst().GetVal().GetInteger().GetIval())
152+
153+
alterStmts = append(alterStmts, generateNode(defaultDefinitionTemp))
154+
155+
// TODO: Update rows
156+
157+
// TODO: apply the rest constraints
158+
constraints = append(constraints[:index], constraints[index+1:]...)
159+
fmt.Println(constraints)
160+
}
161+
162+
case pg_query.AlterTableType_AT_AddConstraint:
163+
constraint := v.AlterTableCmd.Def.GetConstraint()
164+
constraint.SkipValidation = true
165+
166+
alterStmts = append(alterStmts, node)
167+
168+
validationTemp := fmt.Sprintf(`ALTER TABLE %s VALIDATE CONSTRAINT %s;`,
169+
alterTableStmt.GetRelation().GetRelname(), constraint.GetConname())
170+
171+
alterStmts = append(alterStmts, generateNode(validationTemp))
172+
173+
default:
174+
alterStmts = append(alterStmts, node)
175+
}
176+
177+
default:
178+
alterStmts = append(alterStmts, node)
179+
180+
fmt.Printf("%T\n", v)
181+
}
182+
}
183+
184+
return alterStmts
185+
}
186+
187+
func generateNode(nodeTemplate string) *pg_query.Node {
188+
defDefinition, err := pg_query.Parse(nodeTemplate)
189+
if err != nil {
190+
log.Fatal(err)
191+
}
192+
193+
return defDefinition.Stmts[0].Stmt
194+
}

0 commit comments

Comments
 (0)