Skip to content

Commit 449bd24

Browse files
committed
Fix mypy errors
1 parent ecd1450 commit 449bd24

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

src/shelloracle/configure.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,21 @@ def get_bundled_script_path(shell: str) -> Path:
5252
shell_dir = Path(__file__).parent / "shell"
5353
if shell == "zsh":
5454
return shell_dir / "shelloracle.zsh"
55-
elif shell == "bash":
55+
else:
5656
return shell_dir / "shelloracle.bash"
5757

5858

5959
def get_script_path(shell: str) -> Path:
6060
if shell == "zsh":
6161
return Path.home() / ".shelloracle.zsh"
62-
elif shell == "bash":
62+
else:
6363
return Path.home() / ".shelloracle.bash"
6464

6565

6666
def get_rc_path(shell: str) -> Path:
6767
if shell == "zsh":
6868
return Path.home() / ".zshrc"
69-
elif shell == "bash":
69+
else:
7070
return Path.home() / ".bashrc"
7171

7272

@@ -91,7 +91,7 @@ def update_rc(shell: str) -> None:
9191
print_info(f"Successfully updated {replace_home_with_tilde(rc_path)}")
9292

9393

94-
def get_settings(provider: Provider) -> Iterator[tuple[str, Setting]]:
94+
def get_settings(provider: type[Provider]) -> Iterator[tuple[str, Setting]]:
9595
settings = inspect.getmembers(provider, predicate=lambda p: isinstance(p, Setting))
9696

9797
def correct_name_setting():
@@ -103,7 +103,7 @@ def correct_name_setting():
103103
yield from correct_name_setting()
104104

105105

106-
def write_shelloracle_config(provider: Provider, settings: dict[str, Any]) -> None:
106+
def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any]) -> None:
107107
config = tomlkit.document()
108108

109109
shor_table = tomlkit.table()
@@ -134,11 +134,14 @@ def install_keybindings() -> None:
134134
update_rc(shell)
135135

136136

137-
def user_configure_settings(provider: Provider) -> dict[str, Any]:
137+
def user_configure_settings(provider: type[Provider]) -> dict[str, Any]:
138138
settings = {}
139139
for name, setting in get_settings(provider):
140140
user_input = prompt(f"{name}: ", default=str(setting.default))
141-
type_ = type(setting.default) if setting.default else str
141+
if setting.default:
142+
type_ = type(setting.default)
143+
else:
144+
type_ = str
142145
value = type_(user_input)
143146
settings[name] = value
144147
return settings
@@ -151,7 +154,7 @@ def case_correct_user_input(user_input: str, options: Sequence[str]) -> str | No
151154
return None
152155

153156

154-
def user_select_provider() -> Provider:
157+
def user_select_provider() -> type[Provider]:
155158
providers = list_providers()
156159
completer = WordCompleter(providers, ignore_case=True)
157160
user_selected_provider = prompt(f"Choose your LLM provider ({', '.join(providers)}): ", completer=completer)

0 commit comments

Comments
 (0)