|
| 1 | +import glob |
| 2 | +import logging |
| 3 | +import os |
| 4 | +import platform |
| 5 | +import re |
| 6 | +import stat |
| 7 | +from pathlib import Path |
| 8 | +from typing import Optional |
| 9 | + |
| 10 | +from vectorcode.cli_utils import GLOBAL_CONFIG_PATH, Config, find_project_root |
| 11 | + |
| 12 | +logger = logging.getLogger(name=__name__) |
| 13 | +__GLOBAL_HOOKS_PATH = Path(GLOBAL_CONFIG_PATH).parent / "hooks" |
| 14 | + |
| 15 | + |
| 16 | +# Keys: name of the hooks, ie. `pre-commit` |
| 17 | +# Values: lines of the hooks. |
| 18 | +__HOOK_CONTENTS: dict[str, list[str]] = { |
| 19 | + "pre-commit": [ |
| 20 | + "diff_files=$(git diff --cached --name-only)", |
| 21 | + '[ -z "$diff_files" ] || vectorcode vectorise $diff_files', |
| 22 | + ], |
| 23 | + "post-checkout": [ |
| 24 | + 'files=$(git diff --name-only "$1" "$2")', |
| 25 | + '[ -z "$files" ] || vectorcode vectorise $files', |
| 26 | + ], |
| 27 | +} |
| 28 | + |
| 29 | + |
| 30 | +def __lines_are_empty(lines: list[str]) -> bool: |
| 31 | + pattern = re.compile(r"^\s*$") |
| 32 | + if len(lines) == 0: |
| 33 | + return True |
| 34 | + return all(map(lambda line: pattern.match(line) is not None, lines)) |
| 35 | + |
| 36 | + |
| 37 | +def load_hooks(): |
| 38 | + global __HOOK_CONTENTS |
| 39 | + for file in glob.glob(str(__GLOBAL_HOOKS_PATH / "*")): |
| 40 | + hook_name = Path(file).stem |
| 41 | + with open(file) as fin: |
| 42 | + lines = fin.readlines() |
| 43 | + if not __lines_are_empty(lines): |
| 44 | + __HOOK_CONTENTS[hook_name] = lines |
| 45 | + |
| 46 | + |
| 47 | +class HookFile: |
| 48 | + prefix = "# VECTORCODE_HOOK_START" |
| 49 | + suffix = "# VECTORCODE_HOOK_END" |
| 50 | + prefix_pattern = re.compile(r"^\s*#\s*VECTORCODE_HOOK_START\s*") |
| 51 | + suffix_pattern = re.compile(r"^\s*#\s*VECTORCODE_HOOK_END\s*") |
| 52 | + |
| 53 | + def __init__(self, path: str | Path, git_dir: Optional[str | Path] = None): |
| 54 | + self.path = path |
| 55 | + self.lines: list[str] = [] |
| 56 | + if os.path.isfile(self.path): |
| 57 | + with open(self.path) as fin: |
| 58 | + self.lines.extend(fin.readlines()) |
| 59 | + |
| 60 | + def has_vectorcode_hooks(self, force: bool = False) -> bool: |
| 61 | + for start, start_line in enumerate(self.lines): |
| 62 | + if self.prefix_pattern.match(start_line) is None: |
| 63 | + continue |
| 64 | + |
| 65 | + for end in range(start + 1, len(self.lines)): |
| 66 | + if self.suffix_pattern.match(self.lines[end]) is not None: |
| 67 | + if force: |
| 68 | + logger.debug("`force` cleaning existing VectorCode hooks...") |
| 69 | + new_lines = self.lines[:start] + self.lines[end + 1 :] |
| 70 | + self.lines[:] = new_lines |
| 71 | + return False |
| 72 | + logger.debug( |
| 73 | + f"Found vectorcode hook block between line {start} and {end} in {self.path}:\n{''.join(self.lines[start + 1 : end])}" |
| 74 | + ) |
| 75 | + return True |
| 76 | + |
| 77 | + return False |
| 78 | + |
| 79 | + def inject_hook(self, content: list[str], force: bool = False): |
| 80 | + if len(self.lines) == 0 or not self.has_vectorcode_hooks(force): |
| 81 | + self.lines.append(self.prefix + "\n") |
| 82 | + self.lines.extend(i if i.endswith("\n") else i + "\n" for i in content) |
| 83 | + self.lines.append(self.suffix + "\n") |
| 84 | + with open(self.path, "w") as fin: |
| 85 | + fin.writelines(self.lines) |
| 86 | + if platform.system() != "Windows": |
| 87 | + # for unix systems, set the executable bit. |
| 88 | + curr_mode = os.stat(self.path).st_mode |
| 89 | + os.chmod(self.path, mode=curr_mode | stat.S_IXUSR) |
| 90 | + |
| 91 | + |
| 92 | +async def hooks(configs: Config) -> int: |
| 93 | + project_root = configs.project_root or "." |
| 94 | + git_root = find_project_root(project_root, ".git") |
| 95 | + if git_root is None: |
| 96 | + logger.error(f"{project_root} is not inside a git repo directory!") |
| 97 | + return 1 |
| 98 | + load_hooks() |
| 99 | + for hook in __HOOK_CONTENTS.keys(): |
| 100 | + hook_file_path = os.path.join(git_root, ".git", "hooks", hook) |
| 101 | + logger.info(f"Writing {hook} hook into {hook_file_path}.") |
| 102 | + print(f"Processing {hook} hook...") |
| 103 | + hook_obj = HookFile(hook_file_path, git_dir=git_root) |
| 104 | + hook_obj.inject_hook(__HOOK_CONTENTS[hook], configs.force) |
| 105 | + return 0 |
0 commit comments