Skip to content

feat: Add embedded structs based on result match #984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/codegen/golang/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

type Field struct {
Name string
Struct string
Type string
Tags map[string]string
Comment string
Expand Down
7 changes: 5 additions & 2 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,15 +243,18 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}

{{if .Arg.EmitStruct}}
type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
{{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
{{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
{{- end}}
}
{{end}}

{{if .Ret.EmitStruct}}
type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}}
type {{.Ret.Type}} struct { {{- range .Ret.Struct.Embedded}}
{{.Name}}
{{- end}} {{- range .Ret.Struct.Fields}} {{if not .Struct}}
{{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
{{- end}}
{{- end}}
}
{{end}}

Expand Down
9 changes: 7 additions & 2 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package golang

import (
"fmt"
"strings"

"github.com/kyleconroy/sqlc/internal/metadata"
Expand Down Expand Up @@ -79,10 +80,14 @@ func (v QueryValue) Scan() string {
}
} else {
for _, f := range v.Struct.Fields {
ref := fmt.Sprintf("%s.%s", v.Name, f.Name)
if f.Struct != "" {
ref = fmt.Sprintf("%s.%s.%s", v.Name, f.Struct, f.Name)
}
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" {
out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")")
out = append(out, "pq.Array(&"+ref+")")
} else {
out = append(out, "&"+v.Name+"."+f.Name)
out = append(out, "&"+ref)
}
}
}
Expand Down
121 changes: 94 additions & 27 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ func buildStructs(r *compiler.Result, settings config.CombinedSettings) []Struct
}

type goColumn struct {
id int
id int
Embed *Struct
*compiler.Column
}

Expand Down Expand Up @@ -183,37 +184,85 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
Typ: goType(r, c, settings),
}
} else if len(query.Columns) > 1 {
var gs *Struct
var emit bool
var (
columns []goColumn
embedded = map[string]interface{}{}
)

for _, s := range structs {
if len(s.Fields) != len(query.Columns) {
continue
}
same := true
for i, f := range s.Fields {
c := query.Columns[i]
sameName := f.Name == StructName(columnName(c, i), settings)
sameType := f.Type == goType(r, c, settings)
sameTable := sameTableName(c.Table, s.Table, r.Catalog.DefaultSchema)
if !sameName || !sameType || !sameTable {
same = false
for ci := 0; ci < len(query.Columns); {
c := query.Columns[ci]
var embed *Struct

// Checks for matching structs.
for _, s := range structs {
// Ensuring tables match the column selector.
if c.Table == nil || c.Table.Name != s.Table.Rel {
continue
}
// If the query doesn't have enough fields, it cannot
// fufill the struct.
if len(query.Columns)-ci < len(s.Fields) {
continue
}
// We can only embed one struct of each type.
if _, ok := embedded[s.Name]; ok {
continue
}
same := true
for fi, f := range s.Fields {
fieldOffset := ci + fi
// If the location of this field doesn't fit into our columns,
// we know the struct can't fit either.
if fieldOffset > len(query.Columns)-1 {
break
}
c := query.Columns[fieldOffset]
sameName := f.Name == StructName(columnName(c, fieldOffset), settings)
sameType := f.Type == goType(r, c, settings)
sameTable := sameTableName(c.Table, s.Table, r.Catalog.DefaultSchema)
if !sameName || !sameType || !sameTable {
same = false
}
}
if same {
embed = &s
break
}
}
if same {
gs = &s

// Used to track the amount of columns matched.
// A struct could be embedded, and in that case
// for performance we want to skip over those
// matched columns.
colsMatched := 1
if embed != nil {
colsMatched = len(embed.Fields)
embedded[embed.Name] = nil
}
for colID := ci; colID < ci+colsMatched; colID++ {
columns = append(columns, goColumn{
id: colID,
Embed: embed,
Column: query.Columns[colID],
})
}
ci += colsMatched
}

var emit bool
gs := columns[0].Embed
// Check if all columns match a consistent embedded struct.
// If they do, we don't need to generate a new struct for the row.
for _, c := range columns {
// Cheaper to compare the pointer instead of the name.
if gs != c.Embed {
gs = nil
break
}
gs = c.Embed
}

if gs == nil {
var columns []goColumn
for i, c := range query.Columns {
columns = append(columns, goColumn{
id: i,
Column: c,
})
}
gs = columnsToStruct(r, gq.MethodName+"Row", columns, settings)
emit = true
}
Expand Down Expand Up @@ -241,9 +290,18 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
gs := Struct{
Name: name,
}
embedded := map[string]interface{}{}
seen := map[string]int{}
suffixes := map[int]int{}
for i, c := range columns {
if c.Embed != nil {
if _, ok := embedded[c.Embed.Name]; !ok {
// We only want to include each embedded struct once.
gs.Embedded = append(gs.Embedded, *c.Embed)
embedded[c.Embed.Name] = nil
}
}

colName := columnName(c.Column, i)
tagName := colName
fieldName := StructName(colName, settings)
Expand All @@ -256,6 +314,9 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
suffix = v + 1
}
suffixes[c.id] = suffix
if c.Embed != nil {
suffix = 0
}
if suffix > 0 {
tagName = fmt.Sprintf("%s_%d", tagName, suffix)
fieldName = fmt.Sprintf("%s_%d", fieldName, suffix)
Expand All @@ -267,12 +328,18 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
if settings.Go.EmitJSONTags {
tags["json:"] = JSONTagName(tagName, settings)
}
gs.Fields = append(gs.Fields, Field{
f := Field{
Name: fieldName,
Type: goType(r, c.Column, settings),
Tags: tags,
})
seen[colName]++
}
if c.Embed != nil {
f.Struct = c.Embed.Name
}
gs.Fields = append(gs.Fields, f)
if c.Embed == nil {
seen[colName]++
}
}
return &gs
}
9 changes: 5 additions & 4 deletions internal/codegen/golang/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
)

type Struct struct {
Table core.FQN
Name string
Fields []Field
Comment string
Table core.FQN
Name string
Embedded []Struct
Fields []Field
Comment string
}

func StructName(name string, settings config.CombinedSettings) string {
Expand Down