Skip to content

Commit cb2f9f4

Browse files
committed
Refactor command handling and improve type hints across multiple modules
1 parent ef4e505 commit cb2f9f4

File tree

9 files changed

+75
-52
lines changed

9 files changed

+75
-52
lines changed

src/iop/_cli.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _determine_command_type(self) -> CommandType:
9898
def _handle_default(self) -> None:
9999
if self.args.default == 'not_set':
100100
print(_Director.get_default_production())
101-
else:
101+
elif self.args.default is not None:
102102
_Director.set_default_production(self.args.default)
103103

104104
def _handle_list(self) -> None:
@@ -144,15 +144,16 @@ def _handle_export(self) -> None:
144144

145145
def _handle_migrate(self) -> None:
146146
migrate_path = self.args.migrate
147-
if not os.path.isabs(migrate_path):
148-
migrate_path = os.path.join(os.getcwd(), migrate_path)
149-
_Utils.migrate(migrate_path)
147+
if migrate_path is not None:
148+
if not os.path.isabs(migrate_path):
149+
migrate_path = os.path.join(os.getcwd(), migrate_path)
150+
_Utils.migrate(migrate_path)
150151

151152
def _handle_log(self) -> None:
152153
if self.args.log == 'not_set':
153154
print(_Director.log_production())
154-
else:
155-
print(_Director.log_production_top(self.args.log))
155+
elif self.args.log is not None:
156+
print(_Director.log_production_top(int(self.args.log)))
156157

157158
def _handle_init(self) -> None:
158159
_Utils.setup(None)

src/iop/_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def _get_info(cls) -> List[str]:
106106
super_class = classname[1:-1]
107107
adapter = cls.get_adapter_type()
108108
if adapter is None:
109-
adapter = cls.getAdapterType() # For backwards compatibility
109+
# for retro-compatibility
110+
adapter = cls.getAdapterType() # type: ignore
110111
break
111112
elif classname in ["'iop.BusinessProcess'","'iop.DuplexProcess'","'iop.InboundAdapter'","'iop.OutboundAdapter'",
112113
"'grongier.pex.BusinessProcess'","'grongier.pex.DuplexProcess'","'grongier.pex.InboundAdapter'","'grongier.pex.OutboundAdapter'"] :

src/iop/_dispatch.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from inspect import signature
2-
from typing import Any
1+
from inspect import signature, Parameter
2+
from typing import Any, List, Tuple, Callable
33

44
from ._serialization import serialize_message, serialize_pickle_message, deserialize_message, deserialize_pickle_message
55
from ._message_validator import is_message_instance, is_pickle_message_instance, is_iris_object_instance
@@ -59,7 +59,7 @@ def dispatch_deserializer(serial: Any) -> Any:
5959
else:
6060
return serial
6161

62-
def dispach_message(host, request: Any) -> Any:
62+
def dispach_message(host: Any, request: Any) -> Any:
6363
"""Dispatches the message to the appropriate method.
6464
6565
Args:
@@ -79,23 +79,43 @@ def dispach_message(host, request: Any) -> Any:
7979

8080
return getattr(host, call)(request)
8181

82-
def create_dispatch(host) -> None:
83-
"""Creates a list of tuples, where each tuple contains the name of a class and the name of a method
84-
that takes an instance of that class as its only argument.
82+
def create_dispatch(host: Any) -> None:
83+
"""Creates a dispatch table mapping class names to their handler methods.
84+
The dispatch table consists of tuples of (fully_qualified_class_name, method_name).
85+
Only methods that take a single typed parameter are considered as handlers.
8586
"""
86-
if len(host.DISPATCH) == 0:
87-
method_list = [func for func in dir(host) if callable(getattr(host, func)) and not func.startswith("_")]
88-
for method in method_list:
89-
try:
90-
param = signature(getattr(host, method)).parameters
91-
except ValueError as e:
92-
param = ''
93-
if (len(param) == 1):
94-
annotation = str(param[list(param)[0]].annotation)
95-
i = annotation.find("'")
96-
j = annotation.rfind("'")
97-
if j == -1:
98-
j = None
99-
classname = annotation[i+1:j]
100-
host.DISPATCH.append((classname, method))
101-
return
87+
if len(host.DISPATCH) > 0:
88+
return
89+
90+
for method_name in get_callable_methods(host):
91+
handler_info = get_handler_info(host, method_name)
92+
if handler_info:
93+
host.DISPATCH.append(handler_info)
94+
95+
def get_callable_methods(host: Any) -> List[str]:
96+
"""Returns a list of callable method names that don't start with underscore."""
97+
return [
98+
func for func in dir(host)
99+
if callable(getattr(host, func)) and not func.startswith("_")
100+
]
101+
102+
def get_handler_info(host: Any, method_name: str) -> Tuple[str, str] | None:
103+
"""Analyzes a method to determine if it's a valid message handler.
104+
Returns a tuple of (fully_qualified_class_name, method_name) if valid,
105+
None otherwise.
106+
"""
107+
try:
108+
params = signature(getattr(host, method_name)).parameters
109+
if len(params) != 1:
110+
return None
111+
112+
param: Parameter = next(iter(params.values()))
113+
annotation = param.annotation
114+
115+
if annotation == Parameter.empty or not isinstance(annotation, type):
116+
return None
117+
118+
return f"{annotation.__module__}.{annotation.__name__}", method_name
119+
120+
except ValueError:
121+
return None

src/iop/_iris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Optional
33

4-
def get_iris(namespace: Optional[str]=None)->'iris':
4+
def get_iris(namespace: Optional[str]=None)->'iris': # type: ignore
55
if namespace:
66
os.environ['IRISNAMESPACE'] = namespace
77
import iris

src/iop/_log_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def emit(self, record: logging.LogRecord) -> None:
7272
Args:
7373
record: The logging record to emit
7474
"""
75-
class_name = record.class_name if hasattr(record, "class_name") else record.name
76-
method_name = record.method_name if hasattr(record, "method_name") else record.funcName
77-
if self.to_console or (hasattr(record, "to_console") and record.to_console):
75+
class_name = record.class_name if hasattr(record, "class_name") else record.name # type: ignore has been added as extra attribute in LogRecord
76+
method_name = record.method_name if hasattr(record, "method_name") else record.funcName # type: ignore has been added as extra attribute in LogRecord
77+
if self.to_console or (hasattr(record, "to_console") and record.to_console): # type: ignore has been added as extra attribute in LogRecord
7878
_iris.get_iris().cls("%SYS.System").WriteToConsoleLog(self.format(record),
7979
0,self.level_map_console.get(record.levelno, 0),class_name+"."+method_name)
8080
else:

src/iop/_private_session_duplex.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._business_host import _BusinessHost
44
from ._decorators import input_deserializer, input_serializer_param, output_serializer, input_serializer, output_deserializer
5+
from ._dispatch import create_dispatch, dispach_message
56

67
class _PrivateSessionDuplex(_BusinessHost):
78

@@ -27,7 +28,7 @@ def on_message(self, request):
2728
@output_serializer
2829
def _dispatch_on_message(self, request):
2930
""" For internal use only. """
30-
return self._dispach_message(request)
31+
return dispach_message(self,request)
3132

3233
def _set_iris_handles(self, handle_current, handle_partner):
3334
""" For internal use only. """

src/iop/_private_session_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def _dispatch_on_document(self, host_object,source_config_name, request):
1212
self._save_persistent_properties(host_object)
1313
return return_object
1414

15-
def on_document(source_config_name,request):
15+
def on_document(self,source_config_name,request):
1616
pass
1717

1818

src/iop/_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _parse_classname(classname: str) -> tuple[str, str]:
107107
raise SerializationError(f"Classname must include a module: {classname}")
108108
return classname[:j], classname[j+1:]
109109

110-
def dataclass_from_dict(klass: Type, dikt: Dict) -> Any:
110+
def dataclass_from_dict(klass: Type | Any, dikt: Dict) -> Any:
111111
"""Converts a dictionary to a dataclass instance.
112112
Handles non attended fields and nested dataclasses."""
113113

src/iop/_utils.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import importlib.util
77
import importlib.resources
88
import json
9+
from typing import Any, Dict, Optional, Union
910

1011
import xmltodict
1112
from pydantic import TypeAdapter
@@ -25,13 +26,12 @@ def raise_on_error(sc):
2526
raise RuntimeError(_iris.get_iris().system.Status.GetOneStatusText(sc))
2627

2728
@staticmethod
28-
def setup(path:str = None):
29+
def setup(path:Optional[str] = None):
2930

3031
if path is None:
3132
# get the path of the data folder with importlib.resources
3233
try:
33-
path = importlib.resources.files('iop').joinpath('cls')
34-
path = str(path)
34+
path = str(importlib.resources.files('iop').joinpath('cls'))
3535
except ModuleNotFoundError:
3636
path = None
3737

@@ -40,29 +40,28 @@ def setup(path:str = None):
4040

4141
# for retrocompatibility load grongier.pex
4242
try:
43-
path = importlib.resources.files('grongier').joinpath('cls')
44-
path = str(path)
43+
path = str(importlib.resources.files('grongier').joinpath('cls'))
4544
except ModuleNotFoundError:
4645
path = None
4746

4847
if path:
4948
_Utils.raise_on_error(_iris.get_iris().cls('%SYSTEM.OBJ').LoadDir(path,'cubk',"*.cls",1))
5049

5150
@staticmethod
52-
def register_message_schema(cls):
51+
def register_message_schema(msg_cls: type):
5352
"""
5453
It takes a class and registers the schema
5554
5655
:param cls: The class to register
5756
"""
58-
if issubclass(cls,_PydanticMessage):
59-
schema = cls.model_json_schema()
60-
elif issubclass(cls,_Message):
61-
type_adapter = TypeAdapter(cls)
57+
if issubclass(msg_cls,_PydanticMessage):
58+
schema = msg_cls.model_json_schema()
59+
elif issubclass(msg_cls,_Message):
60+
type_adapter = TypeAdapter(msg_cls)
6261
schema = type_adapter.json_schema()
6362
else:
6463
raise ValueError("The class must be a subclass of _Message or _PydanticMessage")
65-
schema_name = cls.__module__ + '.' + cls.__name__
64+
schema_name = msg_cls.__module__ + '.' + msg_cls.__name__
6665
schema_str = json.dumps(schema)
6766
categories = schema_name
6867
_Utils.register_schema(schema_name,schema_str,categories)
@@ -172,10 +171,11 @@ def _register_file(filename:str,path:str,overwrite:int=1,iris_package_name:str='
172171
for klass in classes:
173172
extend = ''
174173
if len(klass.bases) == 1:
175-
if hasattr(klass.bases[0],'id'):
176-
extend = klass.bases[0].id
177-
else:
178-
extend = klass.bases[0].attr
174+
base = klass.bases[0]
175+
if isinstance(base, ast.Name):
176+
extend = base.id
177+
elif isinstance(base, ast.Attribute):
178+
extend = base.attr
179179
if extend in ('BusinessOperation','BusinessProcess','BusinessService','DuplexService','DuplexProcess','DuplexOperation','InboundAdapter','OutboundAdapter'):
180180
module = _Utils.filename_to_module(filename)
181181
iris_class_name = f"{iris_package_name}.{module}.{klass.name}"
@@ -283,7 +283,7 @@ def import_module_from_path(module_name, file_path):
283283
raise ValueError("The file path must be absolute")
284284

285285
spec = importlib.util.spec_from_file_location(module_name, file_path)
286-
if spec is None:
286+
if spec is None or spec.loader is None:
287287
raise ImportError(f"Cannot find module named {module_name} at {file_path}")
288288

289289
module = importlib.util.module_from_spec(spec)

0 commit comments

Comments
 (0)