Skip to content

Commit e261000

Browse files
committed
feat(engine): process foreign key constraint
1 parent 72e7bc2 commit e261000

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

engine/cmd/schema-diff/main.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ DROP INDEX title_idx;
1515
ALTER TABLE distributors
1616
ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5);
1717
18+
ALTER TABLE distributors
19+
ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address);
20+
1821
ALTER TABLE pgbench_accounts
1922
ADD COLUMN test integer NOT NULL DEFAULT 0;
2023
`
@@ -27,7 +30,10 @@ CREATE UNIQUE INDEX CONCURRENTLY title_idx ON films USING btree (title);
2730
DROP INDEX CONCURRENTLY title_idx;
2831
2932
ALTER TABLE distributors ADD CONSTRAINT zipchk CHECK (char_length(zipcode) = 5) NOT VALID;
30-
ALTER TABLE distributors VALIDATE CONSTRAINT zipchk;
33+
BEGIN; ALTER TABLE distributors VALIDATE CONSTRAINT zipchk; COMMIT;
34+
35+
ALTER TABLE distributors ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address) NOT VALID;
36+
BEGIN; ALTER TABLE distributors VALIDATE CONSTRAINT distfk; COMMIT;
3137
3238
ALTER TABLE pgbench_accounts ADD COLUMN test int;
3339
ALTER TABLE pgbench_accounts ALTER COLUMN test SET DEFAULT 0;
@@ -146,11 +152,11 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
146152

147153
alterStmts = append(alterStmts, node)
148154

149-
defaultDefinitionTemp := fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %v;`,
155+
defaultDefinitionTemp := fmt.Sprintf(`alter table %s alter column %s set default %v;`,
150156
alterTableStmt.GetRelation().GetRelname(), def.Colname,
151157
constraints[index].GetConstraint().GetRawExpr().GetAConst().GetVal().GetInteger().GetIval())
152158

153-
alterStmts = append(alterStmts, generateNode(defaultDefinitionTemp))
159+
alterStmts = append(alterStmts, generateNodes(defaultDefinitionTemp)...)
154160

155161
// TODO: Update rows
156162

@@ -165,10 +171,10 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
165171

166172
alterStmts = append(alterStmts, node)
167173

168-
validationTemp := fmt.Sprintf(`ALTER TABLE %s VALIDATE CONSTRAINT %s;`,
174+
validationTemp := fmt.Sprintf(`begin; alter table %s validate constraint %s; commit;`,
169175
alterTableStmt.GetRelation().GetRelname(), constraint.GetConname())
170176

171-
alterStmts = append(alterStmts, generateNode(validationTemp))
177+
alterStmts = append(alterStmts, generateNodes(validationTemp)...)
172178

173179
default:
174180
alterStmts = append(alterStmts, node)
@@ -184,11 +190,16 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
184190
return alterStmts
185191
}
186192

187-
func generateNode(nodeTemplate string) *pg_query.Node {
193+
func generateNodes(nodeTemplate string) []*pg_query.Node {
188194
defDefinition, err := pg_query.Parse(nodeTemplate)
189195
if err != nil {
190196
log.Fatal(err)
191197
}
192198

193-
return defDefinition.Stmts[0].Stmt
199+
nodes := []*pg_query.Node{}
200+
for _, rawStmt := range defDefinition.Stmts {
201+
nodes = append(nodes, rawStmt.Stmt)
202+
}
203+
204+
return nodes
194205
}

0 commit comments

Comments
 (0)