Skip to content

Commit 095e961

Browse files
committed
Refactor @cmd.declare_command
1 parent 33c5c7c commit 095e961

File tree

3 files changed

+213
-161
lines changed

3 files changed

+213
-161
lines changed

modules/pymol/commanding.py

Lines changed: 119 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
if True:
2121
import _thread as thread
2222
import urllib.request as urllib2
23-
from io import FileIO as file
23+
from io import FileIO as file, BytesIO
2424

25+
import builtins
2526
import inspect
2627
import glob
2728
import shlex
29+
import tokenize
2830
from enum import Enum
2931
from functools import wraps
3032
from pathlib import Path
3133
from textwrap import dedent
32-
from typing import List
34+
from typing import Tuple, Iterable, get_args, Optional, Union, Any, NewType, List, get_origin
35+
3336

3437
import re
3538
import os
@@ -599,45 +602,117 @@ def get_state_list(states_str):
599602
states_list = sorted(set(map(int, output)))
600603
return _cmd.delete_states(_self._COb, name, states_list)
601604

602-
class Selection(str):
603-
pass
604-
605-
606-
def _parse_bool(value: str):
607-
if isinstance(value, str):
605+
def _into_types(type, value):
606+
if repr(type) == 'typing.Any':
607+
return value
608+
elif type is bool:
609+
if isinstance(value, bool):
610+
return value
608611
if value.lower() in ["yes", "1", "true", "on", "y"]:
609612
return True
610613
elif value.lower() in ["no", "0", "false", "off", "n"]:
611614
return False
612615
else:
613-
raise Exception("Invalid boolean value: %s" % value)
614-
elif isinstance(value, bool):
615-
return value
616-
else:
617-
raise Exception(f"Unsuported boolean flag {value}")
618-
619-
def _parse_list_str(value):
620-
return shlex.split(value)
621-
622-
def _parse_list_int(value):
623-
return list(map(int, shlex.split(value)))
616+
raise pymol.CmdException("Invalid boolean value: %s" % value)
617+
618+
elif isinstance(type, builtins.type):
619+
return type(value)
620+
621+
if origin := get_origin(type):
622+
if not repr(origin).startswith('typing.') and issubclass(origin, tuple):
623+
args = get_args(type)
624+
new_values = []
625+
for i, new_value in enumerate(shlex.split(value)):
626+
new_values.append(_into_types(args[i], new_value))
627+
return tuple(new_values)
628+
629+
elif origin == Union:
630+
args = get_args(type)
631+
found = False
632+
for i, arg in enumerate(args):
633+
try:
634+
found = True
635+
return _into_types(arg, value)
636+
except:
637+
found = False
638+
if not found:
639+
raise pymol.CmdException(f"Union was not able to cast %s" % value)
640+
641+
elif issubclass(list, origin):
642+
args = get_args(type)
643+
if len(args) > 0:
644+
f = args[0]
645+
else:
646+
f = lambda x: x
647+
return [f(i) for i in shlex.split(value)]
648+
649+
# elif value is None:
650+
# origin = get_origin(type)
651+
# if origin is None:
652+
# return None
653+
# else:
654+
# return _into_types(origin)
655+
# for arg in get_args(origin):
656+
# return _into_types(get_args(origin), value)
657+
658+
elif isinstance(type, str):
659+
return str(value)
660+
661+
raise pymol.CmdException(f"Unsupported argument type {type}")
662+
663+
def parse_documentation(func):
664+
source = inspect.getsource(func)
665+
tokens = tokenize.tokenize(BytesIO(source.encode('utf-8')).readline)
666+
tokens = list(tokens)
667+
comments = []
668+
params = {}
669+
i = -1
670+
started = False
671+
while True:
672+
i += 1
673+
if tokens[i].string == "def":
674+
while tokens[i].string == "(":
675+
i += 1
676+
started = True
677+
continue
678+
if not started:
679+
continue
680+
if tokens[i].string == "->":
681+
break
682+
if tokens[i].type == tokenize.NEWLINE:
683+
break
684+
if tokens[i].string == ")":
685+
break
686+
if tokens[i].type == tokenize.COMMENT:
687+
comments.append(tokens[i].string)
688+
continue
689+
if tokens[i].type == tokenize.NAME and tokens[i+1].string == ":":
690+
name = tokens[i].string
691+
name_line = tokens[i].line
692+
i += 1
693+
while not (tokens[i].type == tokenize.NAME and tokens[i+1].string == ":"):
694+
if tokens[i].type == tokenize.COMMENT and tokens[i].line == name_line:
695+
comments.append(tokens[i].string)
696+
break
697+
elif tokens[i].type == tokenize.NEWLINE:
698+
break
699+
i += 1
700+
else:
701+
i -= 3
702+
docs = ' '.join(c[1:].strip() for c in comments)
703+
params[name] = docs
704+
comments = []
705+
return params
624706

625-
def _parse_list_float(value):
626-
return list(map(float, shlex.split(value)))
627707

628708
def declare_command(name, function=None, _self=cmd):
709+
629710
if function is None:
630711
name, function = name.__name__, name
631712

632-
# new style commands should have annotations
633-
annotations = [a for a in function.__annotations__ if a != "return"]
634-
if function.__code__.co_argcount != len(annotations):
635-
raise Exception("Messy annotations")
636-
637713
# docstring text, if present, should be dedented
638714
if function.__doc__ is not None:
639-
function.__doc__ = dedent(function.__doc__).strip()
640-
715+
function.__doc__ = dedent(function.__doc__)
641716

642717
# Analysing arguments
643718
spec = inspect.getfullargspec(function)
@@ -658,37 +733,32 @@ def declare_command(name, function=None, _self=cmd):
658733
def inner(*args, **kwargs):
659734
frame = traceback.format_stack()[-2]
660735
caller = frame.split("\"", maxsplit=2)[1]
661-
662736
# It was called from command line or pml script, so parse arguments
663737
if caller.endswith("pymol/parser.py"):
664-
kwargs = {**kwargs_, **kwargs, **dict(zip(args2_, args))}
738+
kwargs = {**kwargs, **dict(zip(args2_, args))}
665739
kwargs.pop("_self", None)
666-
for arg in kwargs.copy():
667-
if funcs[arg] == bool:
668-
funcs[arg] = _parse_bool
669-
elif funcs[arg] == List[str]:
670-
funcs[arg] = _parse_list_str
671-
elif funcs[arg] == List[int]:
672-
funcs[arg] = _parse_list_int
673-
elif funcs[arg] == List[float]:
674-
funcs[arg] = _parse_list_float
675-
else:
676-
# Assume it's a literal supported type
677-
pass
678-
# Convert the argument to the correct type
679-
kwargs[arg] = funcs[arg](kwargs[arg])
680-
return function(**kwargs)
740+
new_kwargs = {}
741+
for var, type in funcs.items():
742+
if var in kwargs:
743+
value = kwargs[var]
744+
new_kwargs[var] = _into_types(type, value)
745+
final_kwargs = {}
746+
for k, v in kwargs_.items():
747+
final_kwargs[k] = v
748+
for k, v in new_kwargs.items():
749+
if k not in final_kwargs:
750+
final_kwargs[k] = v
751+
return function(**final_kwargs)
681752

682753
# It was called from Python, so pass the arguments as is
683754
else:
684755
return function(*args, **kwargs)
756+
inner.__arg_docs = parse_documentation(function)
685757

686-
name = function.__name__
687-
_self.keyword[name] = [inner, 0, 0, ",", parsing.STRICT]
688-
_self.kwhash.append(name)
689-
_self.help_sc.append(name)
758+
_self.keyword[name] = [inner, 0,0,',',parsing.STRICT]
690759
return inner
691760

761+
692762
def extend(name, function=None, _self=cmd):
693763

694764
'''

0 commit comments

Comments
 (0)