@@ -153,7 +153,7 @@ def path_join(*args: str, abs_path: bool = False) -> str:
153153    """ 
154154    path  =  os .path .join (* args )
155155    if  abs_path :
156-         path  =  os .path .abspath (path )
156+         path  =  os .path .realpath (path )
157157    return  os .path .normpath (path ).replace ("\\ " , "/" )
158158
159159
@@ -274,33 +274,150 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s
274274        )
275275
276276
277- def  get_git_info (path : str  =  __file__ ) ->  tuple [ str ,  str ,  str ] :
277+ def  get_git_root (path : str  =  __file__ ) ->  str :
278278    """ 
279-     Get the git repository, commit hash, and local path of the given file. 
279+     Get the root dir of the git repository. 
280+ 
281+     Parameters 
282+     ---------- 
283+     path : Optional[str] 
284+         Path to the file in git repository. 
285+ 
286+     Raises 
287+     ------ 
288+     subprocess.CalledProcessError 
289+         If the path is not a git repository or the command fails. 
290+ 
291+     Returns 
292+     ------- 
293+     str 
294+         The absolute path to the root directory of the git repository. 
295+     """ 
296+     abspath  =  path_join (path , abs_path = True )
297+     if  not  os .path .exists (abspath ) or  not  os .path .isdir (abspath ):
298+         dirname  =  os .path .dirname (abspath )
299+     else :
300+         dirname  =  abspath 
301+     git_root  =  subprocess .check_output (
302+         ["git" , "-C" , dirname , "rev-parse" , "--show-toplevel" ], text = True 
303+     ).strip ()
304+     return  path_join (git_root , abs_path = True )
305+ 
306+ 
307+ def  get_git_remote_info (path : str , remote : str ) ->  str :
308+     """ 
309+     Gets the remote URL for the given remote name in the git repository. 
280310
281311    Parameters 
282312    ---------- 
283313    path : str 
284314        Path to the file in git repository. 
315+     remote : str 
316+         The name of the remote (e.g., "origin"). 
317+ 
318+     Raises 
319+     ------ 
320+     subprocess.CalledProcessError 
321+         If the command fails or the remote does not exist. 
285322
286323    Returns 
287324    ------- 
288-     (git_repo, git_commit, git_path) : tuple[ str, str, str]  
289-         Git repository  URL, last commit hash, and relative file path . 
325+     str 
326+         The  URL of the remote repository . 
290327    """ 
291-     dirname  =  os .path .dirname (path )
292-     git_repo  =  subprocess .check_output (
293-         ["git" , "-C" , dirname , "remote" , "get-url" , "origin" ], text = True 
328+     return  subprocess .check_output (
329+         ["git" , "-C" , path , "remote" , "get-url" , remote ], text = True 
294330    ).strip ()
331+ 
332+ 
333+ def  get_git_repo_info (path : str ) ->  str :
334+     """ 
335+     Get the git repository URL for the given git repo. 
336+ 
337+     Tries to get the upstream branch URL, falling back to the "origin" remote 
338+     if the upstream branch is not set or does not exist. If that also fails, 
339+     it returns a file URL pointing to the local path. 
340+ 
341+     Parameters 
342+     ---------- 
343+     path : str 
344+         Path to the git repository. 
345+ 
346+     Raises 
347+     ------ 
348+     subprocess.CalledProcessError 
349+         If the command fails or the git repository does not exist. 
350+ 
351+     Returns 
352+     ------- 
353+     str 
354+         The upstream URL of the git repository. 
355+     """ 
356+     # In case "origin" remote is not set, or this branch has a different 
357+     # upstream, we should handle it gracefully. 
358+     # (e.g., fallback to the first one we find?) 
359+     path  =  path_join (path , abs_path = True )
360+     cmd  =  ["git" , "-C" , path , "rev-parse" , "--abbrev-ref" , "--symbolic-full-name" , "HEAD@{u}" ]
361+     try :
362+         git_remote  =  subprocess .check_output (cmd , text = True ).strip ()
363+         git_remote  =  git_remote .split ("/" , 1 )[0 ]
364+         git_repo  =  get_git_remote_info (path , git_remote )
365+     except  subprocess .CalledProcessError :
366+         git_remote  =  "origin" 
367+         _LOG .warning (
368+             "Failed to get the upstream branch for %s. Falling back to '%s' remote." ,
369+             path ,
370+             git_remote ,
371+         )
372+         try :
373+             git_repo  =  get_git_remote_info (path , git_remote )
374+         except  subprocess .CalledProcessError :
375+             git_repo  =  "file://"  +  path 
376+             _LOG .warning (
377+                 "Failed to get the upstream branch for %s. Falling back to '%s'." ,
378+                 path ,
379+                 git_repo ,
380+             )
381+     return  git_repo 
382+ 
383+ 
384+ def  get_git_info (path : str  =  __file__ ) ->  tuple [str , str , str , str ]:
385+     """ 
386+     Get the git repository, commit hash, and local path of the given file. 
387+ 
388+     Parameters 
389+     ---------- 
390+     path : str 
391+         Path to the file in git repository. 
392+ 
393+     Raises 
394+     ------ 
395+     subprocess.CalledProcessError 
396+         If the path is not a git repository or the command fails. 
397+ 
398+     Returns 
399+     ------- 
400+     (git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str] 
401+         Git repository URL, last commit hash, and relative file path and current 
402+         absolute path. 
403+     """ 
404+     abspath  =  path_join (path , abs_path = True )
405+     if  os .path .exists (abspath ) and  os .path .isdir (abspath ):
406+         dirname  =  abspath 
407+     else :
408+         dirname  =  os .path .dirname (abspath )
409+     git_root  =  get_git_root (path = abspath )
410+     git_repo  =  get_git_repo_info (git_root )
295411    git_commit  =  subprocess .check_output (
296412        ["git" , "-C" , dirname , "rev-parse" , "HEAD" ], text = True 
297413    ).strip ()
298-     git_root  =  subprocess .check_output (
299-         ["git" , "-C" , dirname , "rev-parse" , "--show-toplevel" ], text = True 
300-     ).strip ()
301-     _LOG .debug ("Current git branch: %s %s" , git_repo , git_commit )
302-     rel_path  =  os .path .relpath (os .path .abspath (path ), os .path .abspath (git_root ))
303-     return  (git_repo , git_commit , rel_path .replace ("\\ " , "/" ))
414+     _LOG .debug ("Current git branch for %s: %s %s" , git_root , git_repo , git_commit )
415+     rel_path  =  os .path .relpath (abspath , os .path .abspath (git_root ))
416+     # TODO: return the branch too? 
417+     return  (git_repo , git_commit , rel_path .replace ("\\ " , "/" ), abspath )
418+ 
419+ 
420+ # TODO: Add support for checking out the branch locally. 
304421
305422
306423# Note: to avoid circular imports, we don't specify TunableValue here. 
0 commit comments