Skip to content

Commit cab12e7

Browse files
committed
Add support for client registration + group enforcement
If configured, dynamically register and renew a client. Add support for filtering groups to allow access.
1 parent 3f98666 commit cab12e7

File tree

7 files changed

+544
-9
lines changed

7 files changed

+544
-9
lines changed

config.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ type upstream struct {
5454
OIDCClientID string `json:"oidcClientID"`
5555
// OIDCClientSecret sets the OIDC client secret
5656
OIDCClientSecret string `json:"oidcClientSecret"`
57+
58+
// OIDCRegisterClient enables auto-registration of the OIDC client the
59+
// issuer. If used, the client id and client secret are ignored.
60+
OIDCRegisterClient bool `json:"oidcRegisterClient"`
61+
// OIDCRequireGroups requires the user to be in one of the groups listed in
62+
// the OIDC groups claim. This will automatically add the `groups` scope.
63+
OIDCRequireGroups []string `json:"oidcRequireGroups"`
5764
}
5865

5966
type kubernetesConfig struct {
@@ -103,10 +110,10 @@ func parseAndValidateConfig(cfg []byte) (config, error) {
103110
}
104111

105112
if u.OIDCIssuer != "" {
106-
if u.OIDCClientID == "" {
113+
if u.OIDCClientID == "" && !u.OIDCRegisterClient {
107114
verr = errors.Join(verr, fmt.Errorf("upstream %s oidcClientID required", u.Name))
108115
}
109-
if u.OIDCClientSecret == "" {
116+
if u.OIDCClientSecret == "" && !u.OIDCRegisterClient {
110117
verr = errors.Join(verr, fmt.Errorf("upstream %s oidcClientSecret required", u.Name))
111118
}
112119
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ require (
1010
k8s.io/api v0.34.1
1111
k8s.io/apimachinery v0.34.1
1212
k8s.io/client-go v0.34.1
13-
lds.li/oauth2ext v0.0.0-20250914000806-c3b2c2b5b83a
13+
lds.li/oauth2ext v0.0.0-20250914133403-f15f6850f142
1414
tailscale.com v1.88.1
1515
)
1616

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,8 @@ k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d h1:wAhiDyZ4Tdtt7e46e9M5ZSAJ/MnPG
362362
k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
363363
lds.li/oauth2ext v0.0.0-20250914000806-c3b2c2b5b83a h1:zfJ6WGu4U0UJmuajE/TCgMgR8PaCSMluazfUI8lsopc=
364364
lds.li/oauth2ext v0.0.0-20250914000806-c3b2c2b5b83a/go.mod h1:hM8whxxUy2hC0nsgxAYIbCMF+W2mzhRokz15zCkYFwA=
365+
lds.li/oauth2ext v0.0.0-20250914133403-f15f6850f142 h1:ZcWyDcAjS4TAjjEZQIGaKim1rKDBhGNYCuisYme7ia8=
366+
lds.li/oauth2ext v0.0.0-20250914133403-f15f6850f142/go.mod h1:hM8whxxUy2hC0nsgxAYIbCMF+W2mzhRokz15zCkYFwA=
365367
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg=
366368
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
367369
sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU=

go.work

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
go 1.25.1
2+
3+
use (
4+
.
5+
../oauth2ext
6+
)

go.work.sum

Lines changed: 341 additions & 0 deletions
Large diffs are not rendered by default.

main.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -329,20 +329,19 @@ func tsproxy(ctx context.Context) error {
329329
// TODO pass public paths direct to the proxy
330330
mux := http.NewServeMux()
331331

332+
// process these first, so they take precedence over the OIDC
333+
// auth.
332334
for _, p := range upstream.FunnelPublicPatterns {
333335
mux.Handle(p, rp)
334336
}
335337

336338
if upstream.OIDCIssuer != "" {
337-
baseURL := "https://" + strings.TrimSuffix(st.Self.DNSName, ".")
338-
339-
oidcm, err := oidcmiddleware.NewFromDiscovery(ctx, nil, upstream.OIDCIssuer, upstream.OIDCClientID, upstream.OIDCClientSecret, baseURL+"/.tsproxy/oidc-callback")
339+
mw, err := buildMiddlewareForUpstream(ctx, st, upstream)
340340
if err != nil {
341-
return fmt.Errorf("oidc: new middleware: %w", err)
341+
return fmt.Errorf("oidc: build middleware: %w", err)
342342
}
343-
oidcm.OAuth2Config.Scopes = append(oidcm.OAuth2Config.Scopes, "profile", "email")
344343

345-
mux.Handle("/", oidcm.Wrap(rp)) // fallback to authed path.
344+
mux.Handle("/", mw(rp)) // fallback to authed path.
346345
} else if !slices.Contains(upstream.FunnelPublicPatterns, "/") {
347346
// no OIDC auth, no root pattern, default behaviour is to block.
348347
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {

oidc.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
"net/http"
8+
"slices"
9+
"strings"
10+
"time"
11+
12+
"golang.org/x/oauth2"
13+
"lds.li/oauth2ext/oidc"
14+
"lds.li/oauth2ext/oidcclientreg"
15+
"lds.li/oauth2ext/oidcmiddleware"
16+
"tailscale.com/ipn/ipnstate"
17+
)
18+
19+
type groupClaims struct {
20+
Groups []string `json:"groups"`
21+
}
22+
23+
func buildMiddlewareForUpstream(ctx context.Context, st *ipnstate.Status, upstream upstream) (func(http.Handler) http.Handler, error) {
24+
baseURL := "https://" + strings.TrimSuffix(st.Self.DNSName, ".")
25+
26+
p, err := oidc.DiscoverProvider(ctx, upstream.OIDCIssuer)
27+
if err != nil {
28+
return nil, fmt.Errorf("oidc: discover: %w", err)
29+
}
30+
31+
oidcOAuth2ConfigFn, err := oidcOAuth2ConfigFn(ctx, upstream, baseURL+"/.tsproxy/oidc-callback")
32+
if err != nil {
33+
return nil, fmt.Errorf("oidc: oauth2 config: %w", err)
34+
}
35+
36+
gmw := func(h http.Handler) http.Handler {
37+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38+
if len(upstream.OIDCRequireGroups) == 0 {
39+
h.ServeHTTP(w, r)
40+
return
41+
}
42+
43+
cl, ok := oidcmiddleware.IDClaimsFromContext(r.Context())
44+
if !ok {
45+
slog.Error("oidc: missing id claims")
46+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
47+
return
48+
}
49+
50+
var gc groupClaims
51+
if err := cl.UnmarshalClaims(&gc); err != nil {
52+
slog.Error("oidc: unmarshalling group claims", "error", err)
53+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
54+
return
55+
}
56+
57+
allowed := false
58+
for _, required := range upstream.OIDCRequireGroups {
59+
if slices.Contains(gc.Groups, required) {
60+
allowed = true
61+
break
62+
}
63+
}
64+
if allowed {
65+
h.ServeHTTP(w, r)
66+
return
67+
}
68+
69+
slog.WarnContext(r.Context(), "oidc: user not in required groups", "upstream", upstream.Name, "groups", gc.Groups)
70+
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
71+
})
72+
}
73+
74+
omw := &oidcmiddleware.Handler{
75+
Provider: p,
76+
OAuth2ConfigSource: oidcOAuth2ConfigFn,
77+
SessionStore: &oidcmiddleware.Cookiestore{},
78+
}
79+
80+
return func(h http.Handler) http.Handler {
81+
return omw.Wrap(gmw(h))
82+
}, nil
83+
}
84+
85+
func oidcOAuth2ConfigFn(ctx context.Context, upstream upstream, redirURL string) (func(context.Context) (oauth2.Config, error), error) {
86+
p, err := oidc.DiscoverProvider(ctx, upstream.OIDCIssuer)
87+
if err != nil {
88+
return nil, fmt.Errorf("oidc: discover: %w", err)
89+
}
90+
91+
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}
92+
if len(upstream.OIDCRequireGroups) > 0 {
93+
scopes = append(scopes, "groups")
94+
}
95+
96+
o2cfg := &oauth2.Config{
97+
Endpoint: p.Endpoint(),
98+
Scopes: scopes,
99+
RedirectURL: redirURL,
100+
}
101+
102+
if !upstream.OIDCRegisterClient {
103+
// just return something that uses the current config.
104+
return func(ctx context.Context) (oauth2.Config, error) {
105+
c := *o2cfg
106+
c.ClientID = upstream.OIDCClientID
107+
c.ClientSecret = upstream.OIDCClientSecret
108+
return c, nil
109+
}, nil
110+
}
111+
112+
// otherwise register a client, and run a routine to update the config.
113+
regResp, err := registerOIDCClient(ctx, upstream, p)
114+
if err != nil {
115+
return nil, fmt.Errorf("oidc: register client: %w", err)
116+
}
117+
118+
o2cfg.ClientID = regResp.ClientID
119+
o2cfg.ClientSecret = regResp.ClientSecret
120+
121+
slog.Info("oidc: registered client", "upstream", upstream.Name, "client_id", o2cfg.ClientID)
122+
123+
reRegisterDelay := time.Hour // for now, to exercise it.
124+
// if regResp.ClientSecretExpiresAt != nil {
125+
// ttl := time.Duration(*regResp.ClientSecretExpiresAt-time.Now().Unix()) * time.Second
126+
// reRegisterDelay = time.Duration(float64(ttl) * 0.75)
127+
// } else {
128+
// reRegisterDelay = time.Hour // TODO set a better default
129+
// }
130+
131+
go func() {
132+
for {
133+
select {
134+
case <-ctx.Done():
135+
return
136+
case <-time.After(reRegisterDelay):
137+
newRegResp, err := registerOIDCClient(ctx, upstream, p)
138+
if err != nil {
139+
slog.Error("oidc: failed to register client", "upstream", upstream.Name, "error", err)
140+
// If registration fails, try again sooner
141+
reRegisterDelay = time.Minute
142+
continue
143+
}
144+
slog.Info("oidc: registered client", "upstream", upstream.Name, "client_id", o2cfg.ClientID)
145+
o2cfg.ClientID = newRegResp.ClientID
146+
o2cfg.ClientSecret = newRegResp.ClientSecret
147+
// On success, wait longer before next registration
148+
// TODO - same calculation as above
149+
reRegisterDelay = time.Hour
150+
}
151+
}
152+
}()
153+
154+
return func(ctx context.Context) (oauth2.Config, error) {
155+
return *o2cfg, nil
156+
}, nil
157+
}
158+
159+
// registerOIDCClient performs dynamic client registration with the OIDC provider
160+
func registerOIDCClient(ctx context.Context, upstream upstream, provider *oidc.Provider) (*oidcclientreg.ClientRegistrationResponse, error) {
161+
// Create registration request
162+
request := &oidcclientreg.ClientRegistrationRequest{
163+
ClientName: fmt.Sprintf("tsproxy-%s", upstream.Name),
164+
RedirectURIs: []string{"http://127.0.0.1/callback"},
165+
ApplicationType: "web",
166+
ResponseTypes: []string{"code"},
167+
GrantTypes: []string{"authorization_code"},
168+
}
169+
170+
if slices.Contains(provider.Metadata.IDTokenSigningAlgValuesSupported, "ES256") {
171+
request.IDTokenSignedResponseAlg = "ES256"
172+
}
173+
174+
response, err := oidcclientreg.RegisterWithProvider(ctx, provider, request)
175+
if err != nil {
176+
return nil, fmt.Errorf("failed to register client: %w", err)
177+
}
178+
179+
return response, nil
180+
}

0 commit comments

Comments
 (0)