@@ -52,21 +52,21 @@ def get_bundled_script_path(shell: str) -> Path:
52
52
shell_dir = Path (__file__ ).parent / "shell"
53
53
if shell == "zsh" :
54
54
return shell_dir / "shelloracle.zsh"
55
- elif shell == "bash" :
55
+ else :
56
56
return shell_dir / "shelloracle.bash"
57
57
58
58
59
59
def get_script_path (shell : str ) -> Path :
60
60
if shell == "zsh" :
61
61
return Path .home () / ".shelloracle.zsh"
62
- elif shell == "bash" :
62
+ else :
63
63
return Path .home () / ".shelloracle.bash"
64
64
65
65
66
66
def get_rc_path (shell : str ) -> Path :
67
67
if shell == "zsh" :
68
68
return Path .home () / ".zshrc"
69
- elif shell == "bash" :
69
+ else :
70
70
return Path .home () / ".bashrc"
71
71
72
72
@@ -91,7 +91,7 @@ def update_rc(shell: str) -> None:
91
91
print_info (f"Successfully updated { replace_home_with_tilde (rc_path )} " )
92
92
93
93
94
- def get_settings (provider : Provider ) -> Iterator [tuple [str , Setting ]]:
94
+ def get_settings (provider : type [ Provider ] ) -> Iterator [tuple [str , Setting ]]:
95
95
settings = inspect .getmembers (provider , predicate = lambda p : isinstance (p , Setting ))
96
96
97
97
def correct_name_setting ():
@@ -103,7 +103,7 @@ def correct_name_setting():
103
103
yield from correct_name_setting ()
104
104
105
105
106
- def write_shelloracle_config (provider : Provider , settings : dict [str , Any ]) -> None :
106
+ def write_shelloracle_config (provider : type [ Provider ] , settings : dict [str , Any ]) -> None :
107
107
config = tomlkit .document ()
108
108
109
109
shor_table = tomlkit .table ()
@@ -134,11 +134,14 @@ def install_keybindings() -> None:
134
134
update_rc (shell )
135
135
136
136
137
- def user_configure_settings (provider : Provider ) -> dict [str , Any ]:
137
+ def user_configure_settings (provider : type [ Provider ] ) -> dict [str , Any ]:
138
138
settings = {}
139
139
for name , setting in get_settings (provider ):
140
140
user_input = prompt (f"{ name } : " , default = str (setting .default ))
141
- type_ = type (setting .default ) if setting .default else str
141
+ if setting .default :
142
+ type_ = type (setting .default )
143
+ else :
144
+ type_ = str
142
145
value = type_ (user_input )
143
146
settings [name ] = value
144
147
return settings
@@ -151,7 +154,7 @@ def case_correct_user_input(user_input: str, options: Sequence[str]) -> str | No
151
154
return None
152
155
153
156
154
- def user_select_provider () -> Provider :
157
+ def user_select_provider () -> type [ Provider ] :
155
158
providers = list_providers ()
156
159
completer = WordCompleter (providers , ignore_case = True )
157
160
user_selected_provider = prompt (f"Choose your LLM provider ({ ', ' .join (providers )} ): " , completer = completer )
0 commit comments