@@ -50,7 +50,6 @@ func detectNodeType(node *pg_query.Node) []*pg_query.Node {
50
50
DropStmt (stmt )
51
51
52
52
case * pg_query.Node_AlterTableStmt :
53
- fmt .Println ("Alter Type" )
54
53
return AlterStmt (node )
55
54
56
55
case * pg_query.Node_SelectStmt :
@@ -88,10 +87,6 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
88
87
for _ , cmd := range initialCommands {
89
88
switch v := cmd .Node .(type ) {
90
89
case * pg_query.Node_AlterTableCmd :
91
- fmt .Printf ("%#v\n " , v )
92
- fmt .Printf ("%#v\n " , v .AlterTableCmd .Def .Node )
93
- fmt .Println (v .AlterTableCmd .Subtype .Enum ())
94
-
95
90
switch v .AlterTableCmd .Subtype {
96
91
case pg_query .AlterTableType_AT_AddColumn :
97
92
def := v .AlterTableCmd .Def .GetColumnDef ()
@@ -106,13 +101,13 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
106
101
if index , ok := constraintsMap [pg_query .ConstrType_CONSTR_DEFAULT ]; ok {
107
102
def .Constraints = make ([]* pg_query.Node , 0 )
108
103
109
- alterStmts = append (alterStmts , node )
104
+ alterStmts = append (alterStmts , wrapTransaction ([] * pg_query. Node { node }) ... )
110
105
111
106
defaultDefinitionTemp := fmt .Sprintf (`alter table %s alter column %s set default %v;` ,
112
107
alterTableStmt .GetRelation ().GetRelname (), def .Colname ,
113
108
constraints [index ].GetConstraint ().GetRawExpr ().GetAConst ().GetVal ().GetInteger ().GetIval ())
114
109
115
- alterStmts = append (alterStmts , generateNodes (defaultDefinitionTemp )... )
110
+ alterStmts = append (alterStmts , wrapTransaction ( generateNodes (defaultDefinitionTemp ) )... )
116
111
117
112
// TODO: Update rows
118
113
@@ -125,12 +120,12 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
125
120
constraint := v .AlterTableCmd .Def .GetConstraint ()
126
121
constraint .SkipValidation = true
127
122
128
- alterStmts = append (alterStmts , node )
123
+ alterStmts = append (alterStmts , wrapTransaction ([] * pg_query. Node { node }) ... )
129
124
130
- validationTemp := fmt .Sprintf (`begin; alter table %s validate constraint %s; commit ;` ,
125
+ validationTemp := fmt .Sprintf (`alter table %s validate constraint %s;` ,
131
126
alterTableStmt .GetRelation ().GetRelname (), constraint .GetConname ())
132
127
133
- alterStmts = append (alterStmts , generateNodes (validationTemp )... )
128
+ alterStmts = append (alterStmts , wrapTransaction ( generateNodes (validationTemp ) )... )
134
129
135
130
default :
136
131
alterStmts = append (alterStmts , node )
@@ -160,3 +155,14 @@ func generateNodes(nodeTemplate string) []*pg_query.Node {
160
155
161
156
return nodes
162
157
}
158
+
159
+ // wrapTransaction wraps nodes into transaction statements.
160
+ func wrapTransaction (nodes []* pg_query.Node ) []* pg_query.Node {
161
+ begin := makeBeginTransactionStmt ()
162
+ commit := makeCommitTransactionStmt ()
163
+
164
+ nodes = append ([]* pg_query.Node {begin }, nodes ... )
165
+ nodes = append (nodes , commit )
166
+
167
+ return nodes
168
+ }
0 commit comments