|
1 | 1 | """OpenAPI utils.""" |
2 | | -from typing import Any, Dict, List, Optional, Sequence, Union |
| 2 | +from typing import Any, Dict, Sequence, Tuple |
3 | 3 |
|
4 | | -from fastapi import routing |
5 | 4 | from fastapi.encoders import jsonable_encoder |
6 | 5 | from fastapi.openapi.models import OpenAPI |
7 | | -from fastapi.openapi.utils import get_flat_models_from_routes, get_openapi_path |
8 | | -from fastapi.utils import get_model_definitions |
| 6 | +from fastapi.openapi.utils import get_openapi |
| 7 | +from fastapi.security.base import SecurityBase |
9 | 8 | from pybotx_smartapp_rpc import RPCRouter |
10 | 9 | from pybotx_smartapp_rpc.openapi_utils import ( |
11 | 10 | get_rpc_flat_models_from_routes, |
| 11 | + get_rpc_model_definitions, |
12 | 12 | get_rpc_openapi_path, |
13 | 13 | ) |
14 | 14 | from pydantic.schema import get_model_name_map |
15 | 15 | from starlette.routing import BaseRoute |
16 | 16 |
|
| 17 | +from app.services.execute_rpc import security |
| 18 | + |
| 19 | + |
| 20 | +def get_openapi_security_definitions( |
| 21 | + security_component: SecurityBase, |
| 22 | +) -> Tuple[Dict[str, Any], Dict[str, Any]]: |
| 23 | + security_definition = jsonable_encoder( |
| 24 | + security_component.model, |
| 25 | + by_alias=True, |
| 26 | + exclude_none=True, |
| 27 | + ) |
| 28 | + security_name = security_component.scheme_name |
| 29 | + security_definitions = {security_name: security_definition} |
| 30 | + operation_security = {security_name: []} # type: ignore |
| 31 | + return security_definitions, operation_security |
| 32 | + |
17 | 33 |
|
18 | 34 | def custom_openapi( |
19 | 35 | *, |
20 | 36 | title: str, |
21 | 37 | version: str, |
22 | | - openapi_version: str = "3.0.2", |
23 | | - description: Optional[str] = None, |
24 | 38 | fastapi_routes: Sequence[BaseRoute], |
25 | 39 | rpc_router: RPCRouter, |
26 | | - tags: Optional[List[Dict[str, Any]]] = None, |
27 | | - servers: Optional[List[Dict[str, Union[str, Any]]]] = None, |
28 | | - terms_of_service: Optional[str] = None, |
29 | | - contact: Optional[Dict[str, Union[str, Any]]] = None, |
30 | | - license_info: Optional[Dict[str, Union[str, Any]]] = None, |
| 40 | + **kwargs: Any, |
31 | 41 | ) -> Dict[str, Any]: |
32 | | - info: Dict[str, Any] = {"title": title, "version": version} |
33 | | - if description: |
34 | | - info["description"] = description |
35 | | - if terms_of_service: |
36 | | - info["termsOfService"] = terms_of_service |
37 | | - if contact: |
38 | | - info["contact"] = contact |
39 | | - if license_info: |
40 | | - info["license"] = license_info |
41 | | - output: Dict[str, Any] = {"openapi": openapi_version, "info": info} |
42 | | - if servers: |
43 | | - output["servers"] = servers |
44 | | - components: Dict[str, Dict[str, Any]] = {} |
45 | | - paths: Dict[str, Dict[str, Any]] = {} |
46 | | - # FastAPI |
47 | | - flat_fastapi_models = get_flat_models_from_routes(fastapi_routes) |
48 | | - fastapi_model_name_map = get_model_name_map(flat_fastapi_models) |
49 | | - fast_api_definitions = get_model_definitions( |
50 | | - flat_models=flat_fastapi_models, model_name_map=fastapi_model_name_map |
| 42 | + openapi_dict = get_openapi( |
| 43 | + title=title, |
| 44 | + version=version, |
| 45 | + routes=fastapi_routes, |
| 46 | + **kwargs, |
51 | 47 | ) |
52 | 48 |
|
53 | | - # pybotx RPC |
| 49 | + paths: Dict[str, Dict[str, Any]] = {} |
| 50 | + |
54 | 51 | flat_rpc_models = get_rpc_flat_models_from_routes(rpc_router) |
55 | 52 | rpc_model_name_map = get_model_name_map(flat_rpc_models) |
56 | | - rpc_definitions = get_model_definitions( |
| 53 | + rpc_definitions = get_rpc_model_definitions( |
57 | 54 | flat_models=flat_rpc_models, model_name_map=rpc_model_name_map |
58 | 55 | ) |
59 | | - |
60 | | - for route in fastapi_routes: |
61 | | - if isinstance(route, routing.APIRoute): |
62 | | - result = get_openapi_path( |
63 | | - route=route, model_name_map=fastapi_model_name_map |
64 | | - ) |
65 | | - if result: |
66 | | - path, security_schemes, path_definitions = result |
67 | | - if path: |
68 | | - paths.setdefault(route.path_format, {}).update(path) |
69 | | - if security_schemes: |
70 | | - components.setdefault("securitySchemes", {}).update( |
71 | | - security_schemes |
72 | | - ) |
73 | | - if path_definitions: |
74 | | - fast_api_definitions.update(path_definitions) |
| 56 | + security_definitions, operation_security = get_openapi_security_definitions( |
| 57 | + security_component=security |
| 58 | + ) |
75 | 59 |
|
76 | 60 | for method_name in rpc_router.rpc_methods.keys(): |
77 | 61 | if not rpc_router.rpc_methods[method_name].include_in_schema: |
78 | 62 | continue |
79 | 63 |
|
80 | | - result = get_rpc_openapi_path( # type: ignore |
| 64 | + path = get_rpc_openapi_path( # type: ignore |
81 | 65 | method_name=method_name, |
82 | 66 | route=rpc_router.rpc_methods[method_name], |
83 | 67 | model_name_map=rpc_model_name_map, |
| 68 | + security_scheme=operation_security, |
84 | 69 | ) |
85 | | - if result: |
86 | | - path, path_definitions = result # type: ignore |
87 | | - if path: |
88 | | - paths.setdefault(method_name, {}).update(path) |
| 70 | + if path: |
| 71 | + paths.setdefault(f"/{method_name}", {}).update(path) |
89 | 72 |
|
90 | | - if path_definitions: |
91 | | - rpc_definitions.update(path_definitions) |
92 | | - |
93 | | - if fast_api_definitions: |
94 | | - components["schemas"] = { |
95 | | - k: fast_api_definitions[k] for k in sorted(fast_api_definitions) |
96 | | - } |
97 | 73 | if rpc_definitions: |
98 | | - components.setdefault("schemas", {}).update( |
| 74 | + openapi_dict.setdefault("components", {}).setdefault("schemas", {}).update( |
99 | 75 | {k: rpc_definitions[k] for k in sorted(rpc_definitions)} |
100 | 76 | ) |
101 | | - if components: |
102 | | - output["components"] = components |
103 | | - output["paths"] = paths |
104 | | - if tags: |
105 | | - output["tags"] = tags |
106 | 77 |
|
107 | | - return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) |
| 78 | + openapi_dict.setdefault("components", {}).setdefault("securitySchemes", {}).update( |
| 79 | + security_definitions |
| 80 | + ) |
| 81 | + openapi_dict.setdefault("paths", {}).update(paths) |
| 82 | + |
| 83 | + return jsonable_encoder(OpenAPI(**openapi_dict), by_alias=True, exclude_none=True) |
0 commit comments