Skip to content

Commit 5f28cf2

Browse files
committed
Respect pg_catalog for PSQL when matches overrides
1 parent 88b3cb3 commit 5f28cf2

File tree

3 files changed

+120
-5
lines changed

3 files changed

+120
-5
lines changed

internal/codegen/golang/go_type.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, o
1414
if oride.GoType.StructTags == nil {
1515
continue
1616
}
17-
if override.MatchesColumn(col) {
17+
if override.MatchesColumn(col, req.Settings.Engine) {
1818
for k, v := range oride.GoType.StructTags {
1919
tags[k] = v
2020
}
@@ -76,7 +76,8 @@ func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin
7676
if oride.GoType.TypeName == "" {
7777
continue
7878
}
79-
if override.MatchesColumn(col) {
79+
80+
if override.MatchesColumn(col, req.Settings.Engine) {
8081
return oride.GoType.TypeName
8182
}
8283
}

internal/codegen/golang/opts/override.go

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,38 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool {
7777
return true
7878
}
7979

80-
func (o *Override) MatchesColumn(col *plugin.Column) bool {
81-
columnType := sdk.DataType(col.Type)
80+
func typesMatches(dbType string, colType *plugin.Identifier, isPostgresql bool) bool {
81+
if dbType == "" {
82+
return false
83+
}
84+
columnType := sdk.DataType(colType)
85+
if dbType == columnType {
86+
return true
87+
}
88+
// For example, in PostgreSQL, built-in types are in the 'pg_catalog' schema.
89+
// colType Identifier might show them as:
90+
// - Schema: "pg_catalog", Name: "json"
91+
// - Or Name: "pg_catalog.json"
92+
// - Or just Name: "json
93+
// This checks both to match types.
94+
if isPostgresql {
95+
if strings.TrimPrefix(dbType, "pg_catalog.") == columnType {
96+
return true
97+
}
98+
if colType.Schema == "pg_catalog" && colType.Name == dbType {
99+
return true
100+
}
101+
if strings.HasPrefix(colType.Name, "pg_catalog.") {
102+
return colType.Name[len("pg_catalog."):] == dbType
103+
}
104+
}
105+
106+
return false
107+
}
108+
109+
func (o *Override) MatchesColumn(col *plugin.Column, engine string) bool {
82110
notNull := col.NotNull || col.IsArray
83-
return o.DBType != "" && o.DBType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned
111+
return typesMatches(o.DBType, col.Type, engine == "postgresql") && o.Nullable != notNull && o.Unsigned == col.Unsigned
84112
}
85113

86114
func (o *Override) parse(req *plugin.GenerateRequest) (err error) {

internal/codegen/golang/opts/override_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"testing"
55

66
"github.com/google/go-cmp/cmp"
7+
"github.com/sqlc-dev/sqlc/internal/plugin"
78
)
89

910
func TestTypeOverrides(t *testing.T) {
@@ -115,3 +116,88 @@ func FuzzOverride(f *testing.F) {
115116
o.parse(nil)
116117
})
117118
}
119+
120+
func TestOverride_MatchesColumn(t *testing.T) {
121+
t.Parallel()
122+
type testCase struct {
123+
specName string
124+
override Override
125+
Column *plugin.Column
126+
engine string
127+
expected bool
128+
}
129+
130+
testCases := []*testCase{
131+
{
132+
specName: "matches with pg_catalog in schema and name",
133+
override: Override{
134+
DBType: "json",
135+
Nullable: false,
136+
},
137+
Column: &plugin.Column{
138+
Name: "data",
139+
Type: &plugin.Identifier{
140+
Schema: "pg_catalog",
141+
Name: "json",
142+
},
143+
NotNull: true,
144+
IsArray: false,
145+
},
146+
engine: "postgresql",
147+
expected: true,
148+
},
149+
{
150+
specName: "matches only with name",
151+
override: Override{
152+
DBType: "json",
153+
Nullable: false,
154+
},
155+
Column: &plugin.Column{
156+
Name: "data",
157+
Type: &plugin.Identifier{
158+
Name: "json",
159+
},
160+
NotNull: true,
161+
IsArray: false,
162+
},
163+
engine: "postgresql",
164+
expected: true,
165+
},
166+
{
167+
specName: "matches with pg_catalog in name",
168+
override: Override{
169+
DBType: "json",
170+
Nullable: false,
171+
},
172+
Column: &plugin.Column{
173+
Name: "data",
174+
Type: &plugin.Identifier{
175+
Name: "pg_catalog.json",
176+
},
177+
NotNull: true,
178+
IsArray: false,
179+
},
180+
engine: "postgresql",
181+
expected: true,
182+
},
183+
}
184+
185+
for _, test := range testCases {
186+
tt := *test
187+
t.Run(tt.specName, func(t *testing.T) {
188+
result := tt.override.MatchesColumn(tt.Column, tt.engine)
189+
if result != tt.expected {
190+
t.Errorf("mismatch; got %v; want %v", result, tt.expected)
191+
}
192+
if tt.engine == "postgresql" && tt.expected == true {
193+
tt.override.DBType = "pg_catalog." + tt.override.DBType
194+
result = tt.override.MatchesColumn(test.Column, tt.engine)
195+
if !result {
196+
t.Errorf("mismatch; got %v; want %v", result, tt.expected)
197+
}
198+
}
199+
200+
})
201+
202+
}
203+
}

0 commit comments

Comments
 (0)