Skip to content

Commit b9553d0

Browse files
Fixups and improvements to git info retrieval (#991)
# Pull Request

## Title

Fixups and improvements to git info retrieval

______________________________________________________________________

## Description

Various fixups and improvements to git info retrieval: test and handle cases where the local git repo has no upstream or the upstream is not "origin".

______________________________________________________________________

## Type of Change

- 🛠️ Bug fix
- 🔄 Refactor
- 🧪 Tests

______________________________________________________________________

## Testing

- CI
- New unit tests

______________________________________________________________________

## Additional Notes (optional)

To be merged before #985.

______________________________________________________________________

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 04447c8 commit b9553d0

File tree

3 files changed

+216
-18
lines changed

3 files changed

+216
-18
lines changed

mlos_bench/mlos_bench/storage/base_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def __init__( # pylint: disable=too-many-arguments
194194
self._tunables = tunables.copy()
195195
self._trial_id = trial_id
196196
self._experiment_id = experiment_id
197-
(self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
197+
(self._git_repo, self._git_commit, self._root_env_config, _future_pr) = get_git_info(
198198
root_env_config
199199
)
200200
self._description = description

mlos_bench/mlos_bench/tests/util_git_test.py

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,95 @@
33
# Licensed under the MIT License.
44
#
55
"""Unit tests for get_git_info utility function."""
6+
import os
67
import re
8+
import tempfile
9+
from pathlib import Path
10+
from subprocess import CalledProcessError
11+
from subprocess import check_call as run
712

8-
from mlos_bench.util import get_git_info
13+
import pytest
14+
15+
from mlos_bench.util import get_git_info, get_git_root, path_join
916

1017

1118
def test_get_git_info() -> None:
    """Verify the git info reported for a file inside the current repository."""
    this_file = __file__
    repo_url, commit_hash, repo_rel_path, full_path = get_git_info(this_file)
    # The remote URL is expected to mention the project name (case-insensitive).
    assert "mlos" in repo_url.lower()
    # The commit id must start with 40 lowercase hexadecimal characters.
    assert re.match(r"[0-9a-f]{40}", commit_hash) is not None
    # The relative path is reported from the repository root with "/" separators.
    assert repo_rel_path == "mlos_bench/mlos_bench/tests/util_git_test.py"
    assert full_path == path_join(this_file, abs_path=True)
27+
28+
29+
def test_get_git_info_dir() -> None:
    """Verify the git info reported for a directory inside the current repository."""
    tests_dir = os.path.dirname(__file__)
    repo_url, commit_hash, repo_rel_path, full_path = get_git_info(tests_dir)
    # The remote URL is expected to mention the project name (case-insensitive).
    assert "mlos" in repo_url.lower()
    # The commit id must start with 40 lowercase hexadecimal characters.
    assert re.match(r"[0-9a-f]{40}", commit_hash) is not None
    # For a directory input, the relative path is the directory itself.
    assert repo_rel_path == "mlos_bench/mlos_bench/tests"
    assert full_path == path_join(tests_dir, abs_path=True)
39+
40+
41+
def test_non_git_dir() -> None:
    """Verify that looking up the git root of a non-git directory fails."""
    # A fresh temporary directory is expected not to be a git repository,
    # so the underlying git command should exit with an error.
    with tempfile.TemporaryDirectory() as scratch_dir, pytest.raises(CalledProcessError):
        get_git_root(scratch_dir)
47+
48+
49+
def test_non_upstream_git() -> None:
    """
    Check that we can handle a git directory without an upstream.

    A throw-away repository with a single commit and no remotes is created, so
    get_git_info is expected to report a ``file://`` URL for the repository.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Normalize the path the same way get_git_info does so that the
        # expected values below compare equal.
        local_git_dir = path_join(tmp_dir, abs_path=True)
        git = ["git", "-C", local_git_dir]
        # Initialize a new git repository.
        run(["git", "init", local_git_dir, "-b", "main"])
        # Use repo-local identity so the commit works without any global git config.
        run([*git, "config", "--local", "user.email", "pytest@example.com"])
        run([*git, "config", "--local", "user.name", "PyTest User"])
        # Don't let a developer's global commit-signing setup break (or block) the commit.
        run([*git, "config", "--local", "commit.gpgsign", "false"])
        Path(local_git_dir).joinpath("README.md").touch()
        run([*git, "add", "README.md"])
        run([*git, "commit", "-m", "Initial commit"])
        # This should have slightly different behavior when there is no upstream.
        (git_repo, _git_commit, rel_path, abs_path) = get_git_info(local_git_dir)
        assert git_repo == f"file://{local_git_dir}"
        assert abs_path == local_git_dir
        assert rel_path == "."
65+
66+
67+
@pytest.mark.skipif(
    os.environ.get("GITHUB_ACTIONS") != "true",
    reason="Not running in GitHub Actions CI.",
)
def test_github_actions_git_info() -> None:
    """
    Compare get_git_info output against the GitHub Actions environment variables.

    The test is skipped unless ``GITHUB_ACTIONS=true`` is set in the environment.

    Examples
    --------
    Test locally with the following command:

    .. code-block:: shell

        export GITHUB_ACTIONS=true
        export GITHUB_SHA=$(git rev-parse HEAD)
        # GITHUB_REPOSITORY should be in "owner/repo" format.
        # e.g., GITHUB_REPOSITORY="bpkroth/MLOS" or "microsoft/MLOS"
        export GITHUB_REPOSITORY=$(git rev-parse --abbrev-ref --symbolic-full-name HEAD@{u} | cut -d/ -f1 | xargs git remote get-url | grep https://github.com | cut -d/ -f4-)
        pytest -n0 mlos_bench/mlos_bench/tests/util_git_test.py
    """  # pylint: disable=line-too-long # noqa: E501
    # Expected values come from the environment; "owner/repo" format for the repo.
    repo_env = os.environ.get("GITHUB_REPOSITORY")
    sha_env = os.environ.get("GITHUB_SHA")
    assert repo_env, "GITHUB_REPOSITORY not set in environment."
    assert sha_env, "GITHUB_SHA not set in environment."
    # Only the repository URL and commit hash are compared; the paths are ignored.
    git_repo, git_commit = get_git_info(__file__)[:2]
    assert git_repo.endswith(repo_env), f"git_repo '{git_repo}' does not end with '{repo_env}'"
    assert (
        git_commit == sha_env
    ), f"git_commit '{git_commit}' does not match GITHUB_SHA '{sha_env}'"

mlos_bench/mlos_bench/util.py

Lines changed: 131 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def path_join(*args: str, abs_path: bool = False) -> str:
153153
"""
154154
path = os.path.join(*args)
155155
if abs_path:
156-
path = os.path.abspath(path)
156+
path = os.path.realpath(path)
157157
return os.path.normpath(path).replace("\\", "/")
158158

159159

@@ -274,33 +274,150 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s
274274
)
275275

276276

277-
def get_git_info(path: str = __file__) -> tuple[str, str, str]:
277+
def get_git_root(path: str = __file__) -> str:
    """
    Find the top-level directory of the git repository containing a path.

    Parameters
    ----------
    path : str
        Path to a file or directory inside a git repository.
        Defaults to this module's own file.

    Raises
    ------
    subprocess.CalledProcessError
        If the git command fails (e.g., the path is not inside a git repository).

    Returns
    -------
    str
        The absolute path to the root directory of the git repository.
    """
    resolved = path_join(path, abs_path=True)
    # `git -C` needs a directory: use the path itself when it is an existing
    # directory, otherwise (a file or a missing path) use its parent.
    workdir = resolved if os.path.isdir(resolved) else os.path.dirname(resolved)
    cmd = ["git", "-C", workdir, "rev-parse", "--show-toplevel"]
    toplevel = subprocess.check_output(cmd, text=True).strip()
    return path_join(toplevel, abs_path=True)
305+
306+
307+
def get_git_remote_info(path: str, remote: str) -> str:
    """
    Look up the URL configured for a named remote of a git repository.

    Parameters
    ----------
    path : str
        Directory inside the git repository (passed to ``git -C``).
    remote : str
        The name of the remote (e.g., "origin").

    Raises
    ------
    subprocess.CalledProcessError
        If the git command fails (e.g., the remote does not exist).

    Returns
    -------
    str
        The URL of the remote repository, with surrounding whitespace removed.
    """
    cmd = ["git", "-C", path, "remote", "get-url", remote]
    remote_url = subprocess.check_output(cmd, text=True)
    return remote_url.strip()
331+
332+
333+
def get_git_repo_info(path: str) -> str:
    """
    Get the git repository URL for the given git repo.

    Tries to get the URL of the remote that the current branch tracks, falling
    back to the "origin" remote if the upstream branch is not set or its remote
    cannot be resolved. If that also fails, it returns a ``file://`` URL
    pointing to the local path.

    Parameters
    ----------
    path : str
        Path to the git repository.

    Returns
    -------
    str
        The URL of the upstream remote, else the URL of the "origin" remote,
        else ``"file://" + <absolute path>``.

    Notes
    -----
    ``subprocess.CalledProcessError`` from the git commands is handled
    internally by the fallbacks described above; other exceptions are not caught.
    """
    # In case "origin" remote is not set, or this branch has a different
    # upstream, we should handle it gracefully.
    # (e.g., fallback to the first one we find?)
    path = path_join(path, abs_path=True)
    cmd = ["git", "-C", path, "rev-parse", "--abbrev-ref", "--symbolic-full-name", "HEAD@{u}"]
    try:
        upstream = subprocess.check_output(cmd, text=True).strip()
        # The upstream is reported as "<remote>/<branch>"; keep the remote name only.
        git_remote = upstream.split("/", 1)[0]
        return get_git_remote_info(path, git_remote)
    except subprocess.CalledProcessError:
        git_remote = "origin"
        _LOG.warning(
            "Failed to get the upstream branch for %s. Falling back to '%s' remote.",
            path,
            git_remote,
        )
    try:
        return get_git_remote_info(path, git_remote)
    except subprocess.CalledProcessError:
        git_repo = "file://" + path
        # Report the failure that actually happened here: the fallback remote lookup.
        _LOG.warning(
            "Failed to get the URL of the '%s' remote for %s. Falling back to '%s'.",
            git_remote,
            path,
            git_repo,
        )
        return git_repo
382+
383+
384+
def get_git_info(path: str = __file__) -> tuple[str, str, str, str]:
    """
    Collect the repository URL, commit hash, and paths for a file or directory.

    Parameters
    ----------
    path : str
        Path to a file or directory in a git repository.
        Defaults to this module's own file.

    Raises
    ------
    subprocess.CalledProcessError
        If the path is not a git repository or the command fails.

    Returns
    -------
    (git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str]
        Git repository URL, hash of the current ``HEAD`` commit, path relative
        to the repository root (using "/" separators), and the absolute path.
    """
    abspath = path_join(path, abs_path=True)
    # `git -C` needs a directory: use the path itself when it is an existing
    # directory, otherwise fall back to its parent.
    workdir = abspath if os.path.isdir(abspath) else os.path.dirname(abspath)
    git_root = get_git_root(path=abspath)
    git_repo = get_git_repo_info(git_root)
    head_cmd = ["git", "-C", workdir, "rev-parse", "HEAD"]
    git_commit = subprocess.check_output(head_cmd, text=True).strip()
    _LOG.debug("Current git branch for %s: %s %s", git_root, git_repo, git_commit)
    rel_path = os.path.relpath(abspath, os.path.abspath(git_root)).replace("\\", "/")
    # TODO: return the branch too?
    return (git_repo, git_commit, rel_path, abspath)
419+
420+
# TODO: Add support for checking out the branch locally.
304421

305422

306423
# Note: to avoid circular imports, we don't specify TunableValue here.

0 commit comments

Comments
 (0)