1
+ # pyright: reportMissingImports=false
2
+ import os
3
+ from urllib .parse import urljoin
4
+ from dotenv import load_dotenv # type: ignore
5
+ from typing import Any , cast
6
+ import base64 , json , time
7
+ from starlette .requests import Request # type: ignore
8
+
9
+ from mcp .server .fastmcp .server import Context
10
+ from mcp .server .auth .proxy .server import build_proxy_server # noqa: E402
11
+
12
+ # Load environment variables from .env if present
13
+ load_dotenv ()
14
+
15
+ # ----------------------------------------------------------------------------
16
+ # Environment configuration
17
+ # ----------------------------------------------------------------------------
18
+ _upstream_base = os .getenv ("PROXY_UPSTREAM_BASE" , "https://auth.example.com/" )
19
+ # Sanitize trailing slash
20
+ if _upstream_base .endswith ("/" ):
21
+ _upstream_base = _upstream_base [:- 1 ]
22
+ UPSTREAM_BASE : str = _upstream_base
23
+
24
+ print ("[Proxy Config] UPSTREAM_BASE:" , UPSTREAM_BASE )
25
+ print ("[Proxy Config] CLIENT_ID:" , os .getenv ("PROXY_CLIENT_ID" ) or os .getenv ("UPSTREAM_CLIENT_ID" ))
26
+
27
+ CLIENT_ID = os .getenv ("PROXY_CLIENT_ID" ) or os .getenv ("UPSTREAM_CLIENT_ID" ) or "demo-client-id"
28
+ CLIENT_SECRET = os .getenv ("PROXY_CLIENT_SECRET" ) or os .getenv ("UPSTREAM_CLIENT_SECRET" ) # may be None
29
+ DEFAULT_SCOPE = os .getenv ("PROXY_DEFAULT_SCOPE" , "openid profile email" )
30
+ AUDIENCE = os .getenv ("PROXY_AUDIENCE" ) # optional
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Resolve upstream endpoints – prefer explicit *_ENDPOINT variables (matches
34
+ # naming used in fastmcp example) and fall back to BASE + path.
35
+ # ---------------------------------------------------------------------------
36
+
37
+ UPSTREAM_AUTHORIZE = os .getenv ("UPSTREAM_AUTHORIZATION_ENDPOINT" ) or f"{ UPSTREAM_BASE } /authorize"
38
+ UPSTREAM_TOKEN = os .getenv ("UPSTREAM_TOKEN_ENDPOINT" ) or f"{ UPSTREAM_BASE } /token"
39
+ UPSTREAM_JWKS_URI = os .getenv ("UPSTREAM_JWKS_URI" )
40
+ UPSTREAM_REVOCATION = os .getenv ("UPSTREAM_REVOCATION_ENDPOINT" ) or f"{ UPSTREAM_BASE } /revoke"
41
+
42
+ # Metadata URL (only used if we need to fetch from upstream)
43
+ UPSTREAM_METADATA = f"{ UPSTREAM_BASE } /.well-known/oauth-authorization-server"
44
+
45
+ print ("[Proxy Config] UPSTREAM_AUTHORIZE:" , UPSTREAM_AUTHORIZE )
46
+ print ("[Proxy Config] UPSTREAM_TOKEN:" , UPSTREAM_TOKEN )
47
+
48
+ # Server host/port
49
+ PROXY_PORT = int (os .getenv ("PROXY_PORT" , "8000" ))
50
+
51
+ # ----------------------------------------------------------------------------
52
+ # FastMCP server (now created via library helper)
53
+ # ----------------------------------------------------------------------------
54
+
55
+ ISSUER_URL = os .getenv ("PROXY_ISSUER_URL" , "http://localhost:8000" )
56
+
57
+ # Create FastMCP instance using the reusable proxy builder
58
+ mcp = build_proxy_server (port = PROXY_PORT , issuer_url = ISSUER_URL )
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # Minimal demo tool
62
+ # ---------------------------------------------------------------------------
63
+
64
+ @mcp .tool ()
65
+ def echo (message : str ) -> str :
66
+ return f"Echo: { message } "
67
+
68
+
69
+ @mcp .tool ()
70
+ async def user_info (ctx : Context [Any , Any , Request ]) -> dict [str , Any ]:
71
+ """
72
+ Get information about the authenticated user.
73
+
74
+ This tool demonstrates accessing user information from the OAuth access token.
75
+ The user must be authenticated via OAuth to access this tool.
76
+
77
+ Returns:
78
+ Dictionary containing user information from the access token
79
+ """
80
+ from mcp .server .auth .middleware .auth_context import get_access_token
81
+
82
+ # Get the access token from the authentication context
83
+ access_token = get_access_token ()
84
+
85
+ if not access_token :
86
+ return {
87
+ "error" : "No access token found - user not authenticated" ,
88
+ "authenticated" : False
89
+ }
90
+
91
+ # Attempt to decode the access token as JWT to extract useful user claims.
92
+ # Many OAuth providers issue JWT access tokens (or ID tokens) that contain
93
+ # the user's subject (sub) and preferred username. We parse the token
94
+ # *without* signature verification – we only need the public claims for
95
+ # display purposes. If the token is opaque or the decode fails, we simply
96
+ # skip this step.
97
+
98
+ def _try_decode_jwt (token_str : str ) -> dict [str , Any ] | None : # noqa: D401
99
+ """Best-effort JWT decode without verification.
100
+
101
+ Returns the payload dictionary if the token *looks* like a JWT and can
102
+ be base64-decoded. If anything fails we return None.
103
+ """
104
+
105
+ try :
106
+ parts = token_str .split ("." )
107
+ if len (parts ) != 3 :
108
+ return None # Not a JWT
109
+
110
+ # JWT parts are URL-safe base64 without padding
111
+ def _b64decode (segment : str ) -> bytes :
112
+ padding = "=" * (- len (segment ) % 4 )
113
+ return base64 .urlsafe_b64decode (segment + padding )
114
+
115
+ payload_bytes = _b64decode (parts [1 ])
116
+ return json .loads (payload_bytes )
117
+ except Exception : # noqa: BLE001
118
+ return None
119
+
120
+ jwt_claims = _try_decode_jwt (access_token .token )
121
+
122
+ # Build response with token information plus any extracted claims
123
+ response : dict [str , Any ] = {
124
+ "authenticated" : True ,
125
+ "client_id" : access_token .client_id ,
126
+ "scopes" : access_token .scopes ,
127
+ "token_type" : "Bearer" ,
128
+ "expires_at" : access_token .expires_at ,
129
+ "resource" : access_token .resource ,
130
+ }
131
+
132
+ if jwt_claims :
133
+ # Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if absent.
134
+ uid = jwt_claims .get ("userid" ) or jwt_claims .get ("sub" )
135
+ if uid is not None :
136
+ response ["userid" ] = uid # camelCase variant used in FastMCP reference
137
+ response ["user_id" ] = uid # snake_case variant
138
+ response ["username" ] = (
139
+ jwt_claims .get ("preferred_username" )
140
+ or jwt_claims .get ("nickname" )
141
+ or jwt_claims .get ("name" )
142
+ )
143
+ response ["issuer" ] = jwt_claims .get ("iss" )
144
+ response ["audience" ] = jwt_claims .get ("aud" )
145
+ response ["issued_at" ] = jwt_claims .get ("iat" )
146
+
147
+ # Calculate expiration helpers
148
+ if access_token .expires_at :
149
+ response ["expires_at_iso" ] = time .strftime ('%Y-%m-%dT%H:%M:%S' , time .localtime (access_token .expires_at ))
150
+ response ["expires_in_seconds" ] = max (0 , access_token .expires_at - int (time .time ()))
151
+
152
+ return response
153
+
154
+
155
+ @mcp .tool ()
156
+ async def test_endpoint (message : str = "Hello from proxy server!" ) -> dict [str , Any ]:
157
+ """
158
+ Test endpoint for debugging OAuth proxy functionality.
159
+
160
+ Args:
161
+ message: Optional message to echo back
162
+
163
+ Returns:
164
+ Test response with server information
165
+ """
166
+ return {
167
+ "message" : message ,
168
+ "server" : "Transparent OAuth Proxy Server" ,
169
+ "status" : "active" ,
170
+ "oauth_configured" : True
171
+ }
172
+
173
+
174
+ if __name__ == "__main__" :
175
+ mcp .run (transport = "streamable-http" )
0 commit comments