|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "log/slog" |
| 7 | + "net/http" |
| 8 | + "slices" |
| 9 | + "strings" |
| 10 | + "time" |
| 11 | + |
| 12 | + "golang.org/x/oauth2" |
| 13 | + "lds.li/oauth2ext/oidc" |
| 14 | + "lds.li/oauth2ext/oidcclientreg" |
| 15 | + "lds.li/oauth2ext/oidcmiddleware" |
| 16 | + "tailscale.com/ipn/ipnstate" |
| 17 | +) |
| 18 | + |
| 19 | +type groupClaims struct { |
| 20 | + Groups []string `json:"groups"` |
| 21 | +} |
| 22 | + |
| 23 | +func buildMiddlewareForUpstream(ctx context.Context, st *ipnstate.Status, upstream upstream) (func(http.Handler) http.Handler, error) { |
| 24 | + baseURL := "https://" + strings.TrimSuffix(st.Self.DNSName, ".") |
| 25 | + |
| 26 | + p, err := oidc.DiscoverProvider(ctx, upstream.OIDCIssuer) |
| 27 | + if err != nil { |
| 28 | + return nil, fmt.Errorf("oidc: discover: %w", err) |
| 29 | + } |
| 30 | + |
| 31 | + oidcOAuth2ConfigFn, err := oidcOAuth2ConfigFn(ctx, upstream, baseURL+"/.tsproxy/oidc-callback") |
| 32 | + if err != nil { |
| 33 | + return nil, fmt.Errorf("oidc: oauth2 config: %w", err) |
| 34 | + } |
| 35 | + |
| 36 | + gmw := func(h http.Handler) http.Handler { |
| 37 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 38 | + if len(upstream.OIDCRequireGroups) == 0 { |
| 39 | + h.ServeHTTP(w, r) |
| 40 | + return |
| 41 | + } |
| 42 | + |
| 43 | + cl, ok := oidcmiddleware.IDClaimsFromContext(r.Context()) |
| 44 | + if !ok { |
| 45 | + slog.Error("oidc: missing id claims") |
| 46 | + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) |
| 47 | + return |
| 48 | + } |
| 49 | + |
| 50 | + var gc groupClaims |
| 51 | + if err := cl.UnmarshalClaims(&gc); err != nil { |
| 52 | + slog.Error("oidc: unmarshalling group claims", "error", err) |
| 53 | + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) |
| 54 | + return |
| 55 | + } |
| 56 | + |
| 57 | + allowed := false |
| 58 | + for _, required := range upstream.OIDCRequireGroups { |
| 59 | + if slices.Contains(gc.Groups, required) { |
| 60 | + allowed = true |
| 61 | + break |
| 62 | + } |
| 63 | + } |
| 64 | + if allowed { |
| 65 | + h.ServeHTTP(w, r) |
| 66 | + return |
| 67 | + } |
| 68 | + |
| 69 | + slog.WarnContext(r.Context(), "oidc: user not in required groups", "upstream", upstream.Name, "groups", gc.Groups) |
| 70 | + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) |
| 71 | + }) |
| 72 | + } |
| 73 | + |
| 74 | + omw := &oidcmiddleware.Handler{ |
| 75 | + Provider: p, |
| 76 | + OAuth2ConfigSource: oidcOAuth2ConfigFn, |
| 77 | + SessionStore: &oidcmiddleware.Cookiestore{}, |
| 78 | + } |
| 79 | + |
| 80 | + return func(h http.Handler) http.Handler { |
| 81 | + return omw.Wrap(gmw(h)) |
| 82 | + }, nil |
| 83 | +} |
| 84 | + |
| 85 | +func oidcOAuth2ConfigFn(ctx context.Context, upstream upstream, redirURL string) (func(context.Context) (oauth2.Config, error), error) { |
| 86 | + p, err := oidc.DiscoverProvider(ctx, upstream.OIDCIssuer) |
| 87 | + if err != nil { |
| 88 | + return nil, fmt.Errorf("oidc: discover: %w", err) |
| 89 | + } |
| 90 | + |
| 91 | + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail} |
| 92 | + if len(upstream.OIDCRequireGroups) > 0 { |
| 93 | + scopes = append(scopes, "groups") |
| 94 | + } |
| 95 | + |
| 96 | + o2cfg := &oauth2.Config{ |
| 97 | + Endpoint: p.Endpoint(), |
| 98 | + Scopes: scopes, |
| 99 | + RedirectURL: redirURL, |
| 100 | + } |
| 101 | + |
| 102 | + if !upstream.OIDCRegisterClient { |
| 103 | + // just return something that uses the current config. |
| 104 | + return func(ctx context.Context) (oauth2.Config, error) { |
| 105 | + c := *o2cfg |
| 106 | + c.ClientID = upstream.OIDCClientID |
| 107 | + c.ClientSecret = upstream.OIDCClientSecret |
| 108 | + return c, nil |
| 109 | + }, nil |
| 110 | + } |
| 111 | + |
| 112 | + // otherwise register a client, and run a routine to update the config. |
| 113 | + regResp, err := registerOIDCClient(ctx, upstream, p) |
| 114 | + if err != nil { |
| 115 | + return nil, fmt.Errorf("oidc: register client: %w", err) |
| 116 | + } |
| 117 | + |
| 118 | + o2cfg.ClientID = regResp.ClientID |
| 119 | + o2cfg.ClientSecret = regResp.ClientSecret |
| 120 | + |
| 121 | + slog.Info("oidc: registered client", "upstream", upstream.Name, "client_id", o2cfg.ClientID) |
| 122 | + |
| 123 | + reRegisterDelay := time.Hour // for now, to exercise it. |
| 124 | + // if regResp.ClientSecretExpiresAt != nil { |
| 125 | + // ttl := time.Duration(*regResp.ClientSecretExpiresAt-time.Now().Unix()) * time.Second |
| 126 | + // reRegisterDelay = time.Duration(float64(ttl) * 0.75) |
| 127 | + // } else { |
| 128 | + // reRegisterDelay = time.Hour // TODO set a better default |
| 129 | + // } |
| 130 | + |
| 131 | + go func() { |
| 132 | + for { |
| 133 | + select { |
| 134 | + case <-ctx.Done(): |
| 135 | + return |
| 136 | + case <-time.After(reRegisterDelay): |
| 137 | + newRegResp, err := registerOIDCClient(ctx, upstream, p) |
| 138 | + if err != nil { |
| 139 | + slog.Error("oidc: failed to register client", "upstream", upstream.Name, "error", err) |
| 140 | + // If registration fails, try again sooner |
| 141 | + reRegisterDelay = time.Minute |
| 142 | + continue |
| 143 | + } |
| 144 | + slog.Info("oidc: registered client", "upstream", upstream.Name, "client_id", o2cfg.ClientID) |
| 145 | + o2cfg.ClientID = newRegResp.ClientID |
| 146 | + o2cfg.ClientSecret = newRegResp.ClientSecret |
| 147 | + // On success, wait longer before next registration |
| 148 | + // TODO - same calculation as above |
| 149 | + reRegisterDelay = time.Hour |
| 150 | + } |
| 151 | + } |
| 152 | + }() |
| 153 | + |
| 154 | + return func(ctx context.Context) (oauth2.Config, error) { |
| 155 | + return *o2cfg, nil |
| 156 | + }, nil |
| 157 | +} |
| 158 | + |
| 159 | +// registerOIDCClient performs dynamic client registration with the OIDC provider |
| 160 | +func registerOIDCClient(ctx context.Context, upstream upstream, provider *oidc.Provider) (*oidcclientreg.ClientRegistrationResponse, error) { |
| 161 | + // Create registration request |
| 162 | + request := &oidcclientreg.ClientRegistrationRequest{ |
| 163 | + ClientName: fmt.Sprintf("tsproxy-%s", upstream.Name), |
| 164 | + RedirectURIs: []string{"http://127.0.0.1/callback"}, |
| 165 | + ApplicationType: "web", |
| 166 | + ResponseTypes: []string{"code"}, |
| 167 | + GrantTypes: []string{"authorization_code"}, |
| 168 | + } |
| 169 | + |
| 170 | + if slices.Contains(provider.Metadata.IDTokenSigningAlgValuesSupported, "ES256") { |
| 171 | + request.IDTokenSignedResponseAlg = "ES256" |
| 172 | + } |
| 173 | + |
| 174 | + response, err := oidcclientreg.RegisterWithProvider(ctx, provider, request) |
| 175 | + if err != nil { |
| 176 | + return nil, fmt.Errorf("failed to register client: %w", err) |
| 177 | + } |
| 178 | + |
| 179 | + return response, nil |
| 180 | +} |
0 commit comments