|  | 
| 1 | 1 | #!/usr/bin/env python3 | 
| 2 | 2 | # -*- coding: utf-8 -*- | 
| 3 |  | -from fastapi import Request | 
|  | 3 | +from typing import Any | 
| 4 | 4 | 
 | 
| 5 |  | -from fastapi_oauth20.errors import RedirectURIError | 
|  | 5 | +import httpx | 
|  | 6 | + | 
|  | 7 | +from fastapi import HTTPException, Request | 
|  | 8 | + | 
|  | 9 | +from fastapi_oauth20.errors import AccessTokenError, HTTPXOAuth20Error, OAuth20BaseError | 
| 6 | 10 | from fastapi_oauth20.oauth20 import OAuth20Base | 
| 7 | 11 | 
 | 
| 8 | 12 | 
 | 
|  | 13 | +class OAuth20AuthorizeCallbackError(HTTPException, OAuth20BaseError): | 
|  | 14 | +    """The OAuth2 authorization callback error.""" | 
|  | 15 | + | 
|  | 16 | +    def __init__( | 
|  | 17 | +        self, | 
|  | 18 | +        status_code: int, | 
|  | 19 | +        detail: Any = None, | 
|  | 20 | +        headers: dict[str, str] | None = None, | 
|  | 21 | +        response: httpx.Response | None = None, | 
|  | 22 | +    ) -> None: | 
|  | 23 | +        self.response = response | 
|  | 24 | +        super().__init__(status_code=status_code, detail=detail, headers=headers) | 
|  | 25 | + | 
|  | 26 | + | 
| 9 | 27 | class FastAPIOAuth20: | 
| 10 | 28 |     def __init__( | 
| 11 | 29 |         self, | 
| 12 | 30 |         client: OAuth20Base, | 
| 13 | 31 |         redirect_uri: str | None = None, | 
| 14 |  | -        oauth_callback_route_name: str | None = None, | 
|  | 32 | +        oauth2_callback_route_name: str | None = None, | 
| 15 | 33 |     ): | 
|  | 34 | +        """ | 
|  | 35 | +        OAuth2 authorization callback dependency injection | 
|  | 36 | +
 | 
|  | 37 | +        :param client: A client base on OAuth20Base. | 
|  | 38 | +        :param redirect_uri: OAuth2 callback full URL. | 
|  | 39 | +        :param oauth2_callback_route_name: OAuth2 callback route name, as defined by the route decorator 'name' parameter. | 
|  | 40 | +        """ | 
|  | 41 | +        assert (redirect_uri is None and oauth2_callback_route_name is not None) or ( | 
|  | 42 | +            redirect_uri is not None and oauth2_callback_route_name is None | 
|  | 43 | +        ), 'FastAPIOAuth20 redirect_uri and oauth2_callback_route_name cannot be defined at the same time.' | 
| 16 | 44 |         self.client = client | 
| 17 |  | -        self.oauth_callback_route_name = oauth_callback_route_name | 
| 18 | 45 |         self.redirect_uri = redirect_uri | 
|  | 46 | +        self.oauth2_callback_route_name = oauth2_callback_route_name | 
| 19 | 47 | 
 | 
| 20 | 48 |     async def __call__( | 
| 21 | 49 |         self, | 
| 22 | 50 |         request: Request, | 
| 23 |  | -        code: str, | 
|  | 51 | +        code: str | None = None, | 
| 24 | 52 |         state: str | None = None, | 
| 25 | 53 |         code_verifier: str | None = None, | 
|  | 54 | +        error: str | None = None, | 
| 26 | 55 |     ) -> tuple[dict, str]: | 
| 27 |  | -        if self.redirect_uri is None: | 
| 28 |  | -            if self.oauth_callback_route_name is None: | 
| 29 |  | -                raise RedirectURIError('redirect_uri is required') | 
| 30 |  | -            self.redirect_uri = str(request.url_for(self.oauth_callback_route_name)) | 
| 31 |  | - | 
| 32 |  | -        access_token = await self.client.get_access_token( | 
| 33 |  | -            code=code, redirect_uri=self.redirect_uri, code_verifier=code_verifier | 
| 34 |  | -        ) | 
|  | 56 | +        if code is None or error is not None: | 
|  | 57 | +            raise OAuth20AuthorizeCallbackError( | 
|  | 58 | +                status_code=400, | 
|  | 59 | +                detail=error if error is not None else None, | 
|  | 60 | +            ) | 
|  | 61 | + | 
|  | 62 | +        if self.oauth2_callback_route_name: | 
|  | 63 | +            redirect_url = str(request.url_for(self.oauth2_callback_route_name)) | 
|  | 64 | +        else: | 
|  | 65 | +            redirect_url = self.redirect_uri | 
|  | 66 | + | 
|  | 67 | +        try: | 
|  | 68 | +            access_token = await self.client.get_access_token( | 
|  | 69 | +                code=code, | 
|  | 70 | +                redirect_uri=redirect_url, | 
|  | 71 | +                code_verifier=code_verifier, | 
|  | 72 | +            ) | 
|  | 73 | +        except (HTTPXOAuth20Error, AccessTokenError) as e: | 
|  | 74 | +            raise OAuth20AuthorizeCallbackError( | 
|  | 75 | +                status_code=500, | 
|  | 76 | +                detail=e.msg, | 
|  | 77 | +                response=e.response, | 
|  | 78 | +            ) from e | 
| 35 | 79 | 
 | 
| 36 | 80 |         return access_token, state | 
0 commit comments