From 85be702aa69f76f1011e813ae3e4e417188856ba Mon Sep 17 00:00:00 2001 From: Jaya Venkatesh Date: Fri, 19 Sep 2025 13:04:13 -0700 Subject: [PATCH 1/4] added registration check for plugins Signed-off-by: Jaya Venkatesh --- distributed/client.py | 51 ++++++++++++++++++++++++++++++++++++++++ distributed/scheduler.py | 36 ++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index 01fe47fa9c..499e484ce6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5456,6 +5456,57 @@ def unregister_worker_plugin(self, name, nanny=None): """ return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny) + def has_plugin( + self, + name: str | list[str], + plugin_type: str = "worker" + ) -> bool | dict[str, bool]: + """Check if plugin(s) are registered + + Checks whether plugin(s) are registered in the scheduler's plugin registry. + This only verifies registration - not whether plugins are actually running + or functioning correctly. + + Parameters + ---------- + name : str or list[str] + Plugin name(s) to check + plugin_type : str, optional + Type of plugin: 'worker', 'scheduler', or 'nanny'. Defaults to 'worker'. + + Returns + ------- + bool or dict[str, bool] + If name is str: True if plugin is registered, False otherwise + If name is list: dict mapping names to registration status + + See Also + -------- + register_plugin + unregister_worker_plugin + """ + if isinstance(name, str): + result = self.sync( + self._get_plugin_registration_status, + names=[name], + plugin_type=plugin_type + ) + return result[name] + else: + return self.sync( + self._get_plugin_registration_status, + names=name, + plugin_type=plugin_type + ) + + async def _get_plugin_registration_status( + self, names: list[str], plugin_type: str + ) -> dict[str, bool]: + """Async implementation for checking plugin registration""" + return await self.scheduler.get_plugin_registration_status( + names=names, plugin_type=plugin_type + ) + @property def amm(self): """Convenience accessors for the :doc:`active_memory_manager`""" diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2d5ee2c8cf..3d44ccebcc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4039,6 +4039,7 @@ async def post(self) -> None: "unregister_worker_plugin": self.unregister_worker_plugin, "register_nanny_plugin": self.register_nanny_plugin, "unregister_nanny_plugin": self.unregister_nanny_plugin, + "get_plugin_registration_status": self.get_plugin_registration_status, "adaptive_target": self.adaptive_target, "workers_to_close": self.workers_to_close, "subscribe_worker_status": self.subscribe_worker_status, @@ -8696,6 +8697,41 @@ async def get_worker_monitor_info( ) return dict(zip(self.workers, results)) + async def get_plugin_registration_status( + self, names: list[str], plugin_type: str = "worker" + ) -> dict[str, bool]: + """Check if plugins are registered + + Parameters + ---------- + names : list[str] + List of plugin names to check + plugin_type : str, optional + Type of plugin to check: 'worker', 'scheduler', or 'nanny' + + Returns + ------- + dict[str, bool] + Dict mapping plugin names to their registration status + + Raises + ------ + ValueError + If plugin_type is not one of 'worker', 'scheduler', 'nanny' + """ + if plugin_type == "worker": + plugin_dict = self.worker_plugins + elif plugin_type == "scheduler": + plugin_dict = self.plugins + elif plugin_type == "nanny": + plugin_dict = self.nanny_plugins + else: + raise ValueError( + f"plugin_type must be 'worker', 'scheduler', or 'nanny', got {plugin_type!r}" + ) + + return {name: name in plugin_dict for name in names} + ########### # Cleanup # ########### From 4e88de0ac07e3f592200a412b2619dfbb6b1e594 Mon Sep 17 00:00:00 2001 From: Jaya Venkatesh Date: Tue, 23 Sep 2025 22:47:18 -0700 Subject: [PATCH 2/4] added function to check plugin status Signed-off-by: Jaya Venkatesh --- distributed/client.py | 78 +++++++++++++++++++++++----------------- distributed/scheduler.py | 61 +++++++++++++++++-------------- 2 files changed, 81 insertions(+), 58 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 499e484ce6..632ed72f25 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5457,55 +5457,69 @@ def unregister_worker_plugin(self, name, nanny=None): return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny) def has_plugin( - self, - name: str | list[str], - plugin_type: str = "worker" + self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | list ) -> bool | dict[str, bool]: """Check if plugin(s) are registered - - Checks whether plugin(s) are registered in the scheduler's plugin registry. - This only verifies registration - not whether plugins are actually running - or functioning correctly. - + Parameters ---------- - name : str or list[str] - Plugin name(s) to check - plugin_type : str, optional - Type of plugin: 'worker', 'scheduler', or 'nanny'. Defaults to 'worker'. - + plugin : str | plugin object | list + Plugin to check. You can use the plugin object directly or the plugin name. For plugin objects, they must have a 'name' attribute. You can also pass a list of plugin objects or names. + Returns ------- bool or dict[str, bool] If name is str: True if plugin is registered, False otherwise If name is list: dict mapping names to registration status - - See Also + + Examples -------- - register_plugin - unregister_worker_plugin + >>> logging_plugin = LoggingConfigPlugin() # Has name = "logging-config" + >>> client.register_plugin(logging_plugin) + >>> client.has_plugin(logging_plugin) + True + + >>> client.has_plugin('logging-config') + True + + >>> client.has_plugin([logging_plugin, 'other-plugin']) + {'logging-config': True, 'other-plugin': False} """ - if isinstance(name, str): + if isinstance(plugin, str): + result = self.sync(self._get_plugin_registration_status, names=[plugin]) + return result[plugin] + + elif isinstance(plugin, (WorkerPlugin, SchedulerPlugin, NannyPlugin)): + plugin_name = getattr(plugin, "name", None) + if plugin_name is None: + raise ValueError( + f"Plugin {funcname(type(plugin))} has no 'name' attribute. " + "Please add a 'name' attribute to your plugin class." + ) result = self.sync( - self._get_plugin_registration_status, - names=[name], - plugin_type=plugin_type - ) - return result[name] - else: - return self.sync( - self._get_plugin_registration_status, - names=name, - plugin_type=plugin_type + self._get_plugin_registration_status, names=[plugin_name] ) + return result[plugin_name] + + elif isinstance(plugin, list): + names_to_check = [] + for p in plugin: + if isinstance(p, str): + names_to_check.append(p) + else: + plugin_name = getattr(p, "name", None) + if plugin_name is None: + raise ValueError( + f"Plugin {funcname(type(p))} has no 'name' attribute" + ) + names_to_check.append(plugin_name) + return self.sync(self._get_plugin_registration_status, names=names_to_check) async def _get_plugin_registration_status( - self, names: list[str], plugin_type: str + self, names: list[str] ) -> dict[str, bool]: """Async implementation for checking plugin registration""" - return await self.scheduler.get_plugin_registration_status( - names=names, plugin_type=plugin_type - ) + return await self.scheduler.get_plugin_registration_status(names=names) @property def amm(self): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3d44ccebcc..d717e8e4aa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8697,41 +8697,50 @@ async def get_worker_monitor_info( ) return dict(zip(self.workers, results)) - async def get_plugin_registration_status( - self, names: list[str], plugin_type: str = "worker" - ) -> dict[str, bool]: - """Check if plugins are registered - + async def get_plugin_registration_status(self, names: list[str]) -> dict[str, bool]: + """Check if plugins are registered in any plugin registry + + Checks all plugin registries (worker, scheduler, nanny) and returns True + if the plugin is found in any of them. + Parameters ---------- names : list[str] List of plugin names to check - plugin_type : str, optional - Type of plugin to check: 'worker', 'scheduler', or 'nanny' - + Returns ------- dict[str, bool] - Dict mapping plugin names to their registration status - - Raises - ------ - ValueError - If plugin_type is not one of 'worker', 'scheduler', 'nanny' + Dict mapping plugin names to their registration status across all registries """ - if plugin_type == "worker": - plugin_dict = self.worker_plugins - elif plugin_type == "scheduler": - plugin_dict = self.plugins - elif plugin_type == "nanny": - plugin_dict = self.nanny_plugins - else: - raise ValueError( - f"plugin_type must be 'worker', 'scheduler', or 'nanny', got {plugin_type!r}" + result = {} + for name in names: + # Check if plugin exists in any registry + result[name] = ( + name in self.worker_plugins + or name in self.plugins + or name in self.nanny_plugins ) - - return {name: name in plugin_dict for name in names} - + return result + + async def get_worker_plugin_registration_status( + self, names: list[str] + ) -> dict[str, bool]: + """Check if worker plugins are registered""" + return {name: name in self.worker_plugins for name in names} + + async def get_scheduler_plugin_registration_status( + self, names: list[str] + ) -> dict[str, bool]: + """Check if scheduler plugins are registered""" + return {name: name in self.plugins for name in names} + + async def get_nanny_plugin_registration_status( + self, names: list[str] + ) -> dict[str, bool]: + """Check if nanny plugins are registered""" + return {name: name in self.nanny_plugins for name in names} + ########### # Cleanup # ########### From 0ce0ffd3dc4a615babb1f40332f65f7cea8126df Mon Sep 17 00:00:00 2001 From: Jaya Venkatesh Date: Tue, 30 Sep 2025 20:20:12 -0700 Subject: [PATCH 3/4] refactor to move functions inside async Signed-off-by: Jaya Venkatesh --- distributed/client.py | 53 ++++--- .../diagnostics/tests/test_nanny_plugin.py | 147 ++++++++++++++++++ .../tests/test_scheduler_plugin.py | 138 ++++++++++++++++ .../diagnostics/tests/test_worker_plugin.py | 103 ++++++++++++ 4 files changed, 421 insertions(+), 20 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 632ed72f25..bfbc82c4f3 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5457,20 +5457,22 @@ def unregister_worker_plugin(self, name, nanny=None): return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny) def has_plugin( - self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | list + self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence ) -> bool | dict[str, bool]: """Check if plugin(s) are registered Parameters ---------- - plugin : str | plugin object | list - Plugin to check. You can use the plugin object directly or the plugin name. For plugin objects, they must have a 'name' attribute. You can also pass a list of plugin objects or names. + plugin : str | plugin object | Sequence + Plugin to check. You can use the plugin object directly or the plugin name. + For plugin objects, they must have a 'name' attribute. You can also pass + a sequence of plugin objects or names. Returns ------- bool or dict[str, bool] If name is str: True if plugin is registered, False otherwise - If name is list: dict mapping names to registration status + If name is Sequence: dict mapping names to registration status Examples -------- @@ -5485,10 +5487,17 @@ def has_plugin( >>> client.has_plugin([logging_plugin, 'other-plugin']) {'logging-config': True, 'other-plugin': False} """ - if isinstance(plugin, str): - result = self.sync(self._get_plugin_registration_status, names=[plugin]) - return result[plugin] + return self.sync(self._has_plugin_async, plugin=plugin) + async def _has_plugin_async( + self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence + ) -> bool | dict[str, bool]: + """Async implementation for checking plugin registration""" + + # Convert plugin to list of names + if isinstance(plugin, str): + names_to_check = [plugin] + return_single = True elif isinstance(plugin, (WorkerPlugin, SchedulerPlugin, NannyPlugin)): plugin_name = getattr(plugin, "name", None) if plugin_name is None: @@ -5496,12 +5505,9 @@ def has_plugin( f"Plugin {funcname(type(plugin))} has no 'name' attribute. " "Please add a 'name' attribute to your plugin class." ) - result = self.sync( - self._get_plugin_registration_status, names=[plugin_name] - ) - return result[plugin_name] - - elif isinstance(plugin, list): + names_to_check = [plugin_name] + return_single = True + elif isinstance(plugin, Sequence): names_to_check = [] for p in plugin: if isinstance(p, str): @@ -5513,13 +5519,20 @@ def has_plugin( f"Plugin {funcname(type(p))} has no 'name' attribute" ) names_to_check.append(plugin_name) - return self.sync(self._get_plugin_registration_status, names=names_to_check) - - async def _get_plugin_registration_status( - self, names: list[str] - ) -> dict[str, bool]: - """Async implementation for checking plugin registration""" - return await self.scheduler.get_plugin_registration_status(names=names) + return_single = False + else: + raise TypeError( + f"plugin must be a plugin object, name string, or Sequence. Got {type(plugin)}" + ) + + # Get status from scheduler + result = await self.scheduler.get_plugin_registration_status(names=names_to_check) + + # Return single bool or dict based on input + if return_single: + return result[names_to_check[0]] + else: + return result @property def amm(self): diff --git a/distributed/diagnostics/tests/test_nanny_plugin.py b/distributed/diagnostics/tests/test_nanny_plugin.py index 3c481dce26..8c0e1c19c0 100644 --- a/distributed/diagnostics/tests/test_nanny_plugin.py +++ b/distributed/diagnostics/tests/test_nanny_plugin.py @@ -217,3 +217,150 @@ async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s): logs = caplog.getvalue() assert "TestPlugin1 failed to teardown" in logs assert "test error" in logs + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_by_name(c, s, a): + """Test checking if nanny plugin is registered using string name""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.foo = 123 + + def teardown(self, nanny): + pass + + # Check non-existent plugin + assert not await c.has_plugin("duck-plugin") + + # Register plugin + await c.register_plugin(DuckPlugin()) + assert a.foo == 123 + + # Check using string name + assert await c.has_plugin("duck-plugin") + + # Unregister and check again + await c.unregister_worker_plugin("duck-plugin", nanny=True) + assert not await c.has_plugin("duck-plugin") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_by_object(c, s, a): + """Test checking if nanny plugin is registered using plugin object""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.bar = 456 + + def teardown(self, nanny): + pass + + plugin = DuckPlugin() + + # Check before registration + assert not await c.has_plugin(plugin) + + # Register and check + await c.register_plugin(plugin) + assert a.bar == 456 + assert await c.has_plugin(plugin) + + # Unregister and check + await c.unregister_worker_plugin("duck-plugin", nanny=True) + assert not await c.has_plugin(plugin) + + +@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_multiple_nannies(c, s, a, b): + """Test checking nanny plugin with multiple nannies""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.multi = "setup" + + def teardown(self, nanny): + pass + + # Check before registration + assert not await c.has_plugin("duck-plugin") + + # Register plugin (should propagate to all nannies) + await c.register_plugin(DuckPlugin()) + + # Verify both nannies have the plugin + assert a.multi == "setup" + assert b.multi == "setup" + + # Check plugin is registered + assert await c.has_plugin("duck-plugin") + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_custom_name_override(c, s, a): + """Test nanny plugin registered with custom name different from class name""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.custom = "test" + + def teardown(self, nanny): + pass + + plugin = DuckPlugin() + + # Register with custom name (overriding the class name attribute) + await c.register_plugin(plugin, name="custom-override") + + # Check with custom name works + assert await c.has_plugin("custom-override") + + # Original name won't work since we overrode it + assert not await c.has_plugin("duck-plugin") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_list_check(c, s, a): + """Test checking multiple nanny plugins at once""" + + class IdempotentPlugin(NannyPlugin): + name = "idempotentplugin" + + def setup(self, nanny): + pass + + def teardown(self, nanny): + pass + + class NonIdempotentPlugin(NannyPlugin): + name = "nonidempotentplugin" + + def setup(self, nanny): + pass + + def teardown(self, nanny): + pass + + # Check multiple before registration + result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin", "nonexistent"]) + assert result == { + "idempotentplugin": False, + "nonidempotentplugin": False, + "nonexistent": False, + } + + # Register first plugin + await c.register_plugin(IdempotentPlugin()) + result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"]) + assert result == {"idempotentplugin": True, "nonidempotentplugin": False} + + # Register second plugin + await c.register_plugin(NonIdempotentPlugin()) + result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"]) + assert result == {"idempotentplugin": True, "nonidempotentplugin": True} diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index d520524291..a8aba5da5d 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -753,3 +753,141 @@ def __init__(self, instance=None): await s.register_scheduler_plugin(plugin=dumps(third)) assert "nonidempotentplugin" in s.plugins assert s.plugins["nonidempotentplugin"].instance == "third" + +@gen_cluster(client=True) +async def test_has_scheduler_plugin_by_name(c, s, a, b): + """Test checking if scheduler plugin is registered using string name""" + + class Dummy1(SchedulerPlugin): + name = "Dummy1" + + def start(self, scheduler): + scheduler.foo = "bar" + + # Check non-existent plugin + assert not await c.has_plugin("Dummy1") + + # Register plugin + await c.register_plugin(Dummy1()) + assert s.foo == "bar" + + # Check using string name + assert await c.has_plugin("Dummy1") + + # Unregister and check again + await c.unregister_scheduler_plugin("Dummy1") + assert not await c.has_plugin("Dummy1") + + +@gen_cluster(client=True) +async def test_has_scheduler_plugin_by_object(c, s, a, b): + """Test checking if scheduler plugin is registered using plugin object""" + + class Dummy2(SchedulerPlugin): + name = "Dummy2" + + def start(self, scheduler): + scheduler.check_value = 42 + + plugin = Dummy2() + + # Check before registration + assert not await c.has_plugin(plugin) + + # Register and check + await c.register_plugin(plugin) + assert s.check_value == 42 + assert await c.has_plugin(plugin) + + # Unregister and check + await c.unregister_scheduler_plugin("Dummy2") + assert not await c.has_plugin(plugin) + + +@gen_cluster(client=True) +async def test_has_plugin_mixed_scheduler_and_worker_types(c, s, a, b): + """Test checking scheduler and worker plugins together""" + from distributed import WorkerPlugin + + class MyPlugin(SchedulerPlugin): + name = "MyPlugin" + + def start(self, scheduler): + scheduler.my_value = "scheduler" + + class MyWorkerPlugin(WorkerPlugin): + name = "MyWorkerPlugin" + + def setup(self, worker): + worker.my_value = "worker" + + sched_plugin = MyPlugin() + work_plugin = MyWorkerPlugin() + + # Register both types + await c.register_plugin(sched_plugin) + await c.register_plugin(work_plugin) + + # Verify both registered + assert s.my_value == "scheduler" + assert a.my_value == "worker" + assert b.my_value == "worker" + + # Check both with list of names + result = await c.has_plugin(["MyPlugin", "MyWorkerPlugin"]) + assert result == {"MyPlugin": True, "MyWorkerPlugin": True} + + # Check both with objects + assert await c.has_plugin(sched_plugin) + assert await c.has_plugin(work_plugin) + + # Check non-existent alongside real ones + result = await c.has_plugin(["MyPlugin", "nonexistent", "MyWorkerPlugin"]) + assert result == { + "MyPlugin": True, + "nonexistent": False, + "MyWorkerPlugin": True + } + + +@gen_cluster(client=True, nthreads=[]) +async def test_has_scheduler_plugin_no_workers(c, s): + """Test checking scheduler plugin when no workers exist""" + + class Plugin(SchedulerPlugin): + name = "plugin" + + def start(self, scheduler): + scheduler.no_worker_test = True + + # Check before registration + assert not await c.has_plugin("plugin") + + # Register plugin when no workers exist + await c.register_plugin(Plugin()) + assert s.no_worker_test is True + + # Check after registration + assert await c.has_plugin("plugin") + + +@gen_cluster(client=True) +async def test_has_scheduler_plugin_custom_name_override(c, s, a, b): + """Test scheduler plugin registered with custom name different from class name""" + + class Dummy3(SchedulerPlugin): + name = "Dummy3" + + def start(self, scheduler): + scheduler.name_test = "custom" + + plugin = Dummy3() + + # Register with custom name (overriding the class name attribute) + await c.register_plugin(plugin, name="custom-override") + + # Check with custom name works + assert await c.has_plugin("custom-override") + + # Original name won't work since we overrode it + assert not await c.has_plugin("Dummy3") diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 001576afe3..05e2295bc7 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -479,3 +479,106 @@ async def test_plugin_with_broken_teardown_logs_on_close(c, s): logs = caplog.getvalue() assert "TestPlugin1 failed to teardown" in logs assert "test error" in logs + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_worker_plugin_by_name(c, s, a): + """Test checking if worker plugin is registered using string name""" + + class MyPlugin(WorkerPlugin): + name = "MyPlugin" + + def __init__(self, data, expected_notifications=None): + self.data = data + self.expected_notifications = expected_notifications + + # Check non-existent plugin + assert not await c.has_plugin("MyPlugin") # ← await + + # Register plugin + await c.register_plugin(MyPlugin(123, None)) + + # Check using string name + assert await c.has_plugin("MyPlugin") # ← await + + # Unregister and check again + await c.unregister_worker_plugin("MyPlugin") + assert not await c.has_plugin("MyPlugin") # ← await + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_worker_plugin_by_object(c, s, a): + """Test checking if worker plugin is registered using plugin object""" + plugin = MyPlugin(456) + + # Check before registration + assert not await c.has_plugin(plugin) # ← await + + # Register and check + await c.register_plugin(plugin) + assert await c.has_plugin(plugin) # ← await + + # Unregister and check + await c.unregister_worker_plugin("MyPlugin") + assert not await c.has_plugin(plugin) # ← await + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_plugin_list(c, s, a): + """Test checking multiple plugins at once""" + plugin1 = MyPlugin(1) + + class AnotherPlugin(WorkerPlugin): + name = "AnotherPlugin" + + plugin2 = AnotherPlugin() + + # Check multiple plugins before registration + result = await c.has_plugin(["MyPlugin", "AnotherPlugin", "NonExistent"]) # ← await + assert result == { + "MyPlugin": False, + "AnotherPlugin": False, + "NonExistent": False, + } + + # Register first plugin + await c.register_plugin(plugin1) + result = await c.has_plugin(["MyPlugin", "AnotherPlugin"]) # ← await + assert result == {"MyPlugin": True, "AnotherPlugin": False} + + # Register second plugin + await c.register_plugin(plugin2) + result = await c.has_plugin(["MyPlugin", "AnotherPlugin"]) # ← await + assert result == {"MyPlugin": True, "AnotherPlugin": True} + + # Can also pass list of objects + result = await c.has_plugin([plugin1, plugin2]) # ← await + assert result == {"MyPlugin": True, "AnotherPlugin": True} + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_plugin_without_name_attribute(c, s, a): + """Test error when plugin has no name attribute""" + + class PluginWithoutName(WorkerPlugin): + pass # No name attribute + + plugin = PluginWithoutName() + + # Should raise error when checking + with pytest.raises(ValueError, match="has no 'name' attribute"): + await c.has_plugin(plugin) # ← await + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_plugin_custom_name(c, s, a): + """Test plugin registered with custom name""" + plugin = MyPlugin(789) + + # Register with custom name + await c.register_plugin(plugin, name="custom-name") + + # Check with custom name + assert await c.has_plugin("custom-name") # ← await + + # Original name won't work + assert not await c.has_plugin("MyPlugin") # ← await \ No newline at end of file From e5cf6331e1da91c49d898ed151aedb8b0aed3c27 Mon Sep 17 00:00:00 2001 From: Jaya Venkatesh Date: Tue, 30 Sep 2025 20:21:24 -0700 Subject: [PATCH 4/4] precommit Signed-off-by: Jaya Venkatesh --- distributed/client.py | 14 +-- .../diagnostics/tests/test_nanny_plugin.py | 80 +++++++-------- .../tests/test_scheduler_plugin.py | 97 +++++++++---------- .../diagnostics/tests/test_worker_plugin.py | 45 ++++----- 4 files changed, 120 insertions(+), 116 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index bfbc82c4f3..83bfaa336a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5464,8 +5464,8 @@ def has_plugin( Parameters ---------- plugin : str | plugin object | Sequence - Plugin to check. You can use the plugin object directly or the plugin name. - For plugin objects, they must have a 'name' attribute. You can also pass + Plugin to check. You can use the plugin object directly or the plugin name. + For plugin objects, they must have a 'name' attribute. You can also pass a sequence of plugin objects or names. Returns @@ -5493,7 +5493,7 @@ async def _has_plugin_async( self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence ) -> bool | dict[str, bool]: """Async implementation for checking plugin registration""" - + # Convert plugin to list of names if isinstance(plugin, str): names_to_check = [plugin] @@ -5524,10 +5524,12 @@ async def _has_plugin_async( raise TypeError( f"plugin must be a plugin object, name string, or Sequence. Got {type(plugin)}" ) - + # Get status from scheduler - result = await self.scheduler.get_plugin_registration_status(names=names_to_check) - + result = await self.scheduler.get_plugin_registration_status( + names=names_to_check + ) + # Return single bool or dict based on input if return_single: return result[names_to_check[0]] diff --git a/distributed/diagnostics/tests/test_nanny_plugin.py b/distributed/diagnostics/tests/test_nanny_plugin.py index 8c0e1c19c0..256772542e 100644 --- a/distributed/diagnostics/tests/test_nanny_plugin.py +++ b/distributed/diagnostics/tests/test_nanny_plugin.py @@ -218,29 +218,30 @@ async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s): assert "TestPlugin1 failed to teardown" in logs assert "test error" in logs + @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_has_nanny_plugin_by_name(c, s, a): """Test checking if nanny plugin is registered using string name""" - + class DuckPlugin(NannyPlugin): name = "duck-plugin" - + def setup(self, nanny): nanny.foo = 123 - + def teardown(self, nanny): pass - + # Check non-existent plugin assert not await c.has_plugin("duck-plugin") - + # Register plugin await c.register_plugin(DuckPlugin()) assert a.foo == 123 - + # Check using string name assert await c.has_plugin("duck-plugin") - + # Unregister and check again await c.unregister_worker_plugin("duck-plugin", nanny=True) assert not await c.has_plugin("duck-plugin") @@ -249,26 +250,26 @@ def teardown(self, nanny): @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_has_nanny_plugin_by_object(c, s, a): """Test checking if nanny plugin is registered using plugin object""" - + class DuckPlugin(NannyPlugin): name = "duck-plugin" - + def setup(self, nanny): nanny.bar = 456 - + def teardown(self, nanny): pass - + plugin = DuckPlugin() - + # Check before registration assert not await c.has_plugin(plugin) - + # Register and check await c.register_plugin(plugin) assert a.bar == 456 assert await c.has_plugin(plugin) - + # Unregister and check await c.unregister_worker_plugin("duck-plugin", nanny=True) assert not await c.has_plugin(plugin) @@ -277,50 +278,51 @@ def teardown(self, nanny): @gen_cluster(client=True, nthreads=[("", 1), ("", 1)], Worker=Nanny) async def test_has_nanny_plugin_multiple_nannies(c, s, a, b): """Test checking nanny plugin with multiple nannies""" - + class DuckPlugin(NannyPlugin): name = "duck-plugin" - + def setup(self, nanny): nanny.multi = "setup" - + def teardown(self, nanny): pass - + # Check before registration assert not await c.has_plugin("duck-plugin") - + # Register plugin (should propagate to all nannies) await c.register_plugin(DuckPlugin()) - + # Verify both nannies have the plugin assert a.multi == "setup" assert b.multi == "setup" - + # Check plugin is registered assert await c.has_plugin("duck-plugin") + @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_has_nanny_plugin_custom_name_override(c, s, a): """Test nanny plugin registered with custom name different from class name""" - + class DuckPlugin(NannyPlugin): name = "duck-plugin" - + def setup(self, nanny): nanny.custom = "test" - + def teardown(self, nanny): pass - + plugin = DuckPlugin() - + # Register with custom name (overriding the class name attribute) await c.register_plugin(plugin, name="custom-override") - + # Check with custom name works assert await c.has_plugin("custom-override") - + # Original name won't work since we overrode it assert not await c.has_plugin("duck-plugin") @@ -328,38 +330,40 @@ def teardown(self, nanny): @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_has_nanny_plugin_list_check(c, s, a): """Test checking multiple nanny plugins at once""" - + class IdempotentPlugin(NannyPlugin): name = "idempotentplugin" - + def setup(self, nanny): pass - + def teardown(self, nanny): pass - + class NonIdempotentPlugin(NannyPlugin): name = "nonidempotentplugin" - + def setup(self, nanny): pass - + def teardown(self, nanny): pass - + # Check multiple before registration - result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin", "nonexistent"]) + result = await c.has_plugin( + ["idempotentplugin", "nonidempotentplugin", "nonexistent"] + ) assert result == { "idempotentplugin": False, "nonidempotentplugin": False, "nonexistent": False, } - + # Register first plugin await c.register_plugin(IdempotentPlugin()) result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"]) assert result == {"idempotentplugin": True, "nonidempotentplugin": False} - + # Register second plugin await c.register_plugin(NonIdempotentPlugin()) result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"]) diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index a8aba5da5d..dd580ed09f 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -754,140 +754,137 @@ def __init__(self, instance=None): assert "nonidempotentplugin" in s.plugins assert s.plugins["nonidempotentplugin"].instance == "third" + @gen_cluster(client=True) async def test_has_scheduler_plugin_by_name(c, s, a, b): """Test checking if scheduler plugin is registered using string name""" - + class Dummy1(SchedulerPlugin): name = "Dummy1" - + def start(self, scheduler): scheduler.foo = "bar" - + # Check non-existent plugin - assert not await c.has_plugin("Dummy1") - + assert not await c.has_plugin("Dummy1") + # Register plugin await c.register_plugin(Dummy1()) assert s.foo == "bar" - + # Check using string name - assert await c.has_plugin("Dummy1") - + assert await c.has_plugin("Dummy1") + # Unregister and check again await c.unregister_scheduler_plugin("Dummy1") - assert not await c.has_plugin("Dummy1") + assert not await c.has_plugin("Dummy1") @gen_cluster(client=True) async def test_has_scheduler_plugin_by_object(c, s, a, b): """Test checking if scheduler plugin is registered using plugin object""" - + class Dummy2(SchedulerPlugin): name = "Dummy2" - + def start(self, scheduler): scheduler.check_value = 42 - + plugin = Dummy2() - + # Check before registration - assert not await c.has_plugin(plugin) - + assert not await c.has_plugin(plugin) + # Register and check await c.register_plugin(plugin) assert s.check_value == 42 - assert await c.has_plugin(plugin) - + assert await c.has_plugin(plugin) + # Unregister and check await c.unregister_scheduler_plugin("Dummy2") - assert not await c.has_plugin(plugin) + assert not await c.has_plugin(plugin) @gen_cluster(client=True) async def test_has_plugin_mixed_scheduler_and_worker_types(c, s, a, b): """Test checking scheduler and worker plugins together""" from distributed import WorkerPlugin - + class MyPlugin(SchedulerPlugin): name = "MyPlugin" - + def start(self, scheduler): scheduler.my_value = "scheduler" - + class MyWorkerPlugin(WorkerPlugin): name = "MyWorkerPlugin" - + def setup(self, worker): worker.my_value = "worker" - + sched_plugin = MyPlugin() work_plugin = MyWorkerPlugin() - + # Register both types await c.register_plugin(sched_plugin) await c.register_plugin(work_plugin) - + # Verify both registered assert s.my_value == "scheduler" assert a.my_value == "worker" assert b.my_value == "worker" - + # Check both with list of names - result = await c.has_plugin(["MyPlugin", "MyWorkerPlugin"]) + result = await c.has_plugin(["MyPlugin", "MyWorkerPlugin"]) assert result == {"MyPlugin": True, "MyWorkerPlugin": True} - + # Check both with objects - assert await c.has_plugin(sched_plugin) - assert await c.has_plugin(work_plugin) - + assert await c.has_plugin(sched_plugin) + assert await c.has_plugin(work_plugin) + # Check non-existent alongside real ones result = await c.has_plugin(["MyPlugin", "nonexistent", "MyWorkerPlugin"]) - assert result == { - "MyPlugin": True, - "nonexistent": False, - "MyWorkerPlugin": True - } + assert result == {"MyPlugin": True, "nonexistent": False, "MyWorkerPlugin": True} @gen_cluster(client=True, nthreads=[]) async def test_has_scheduler_plugin_no_workers(c, s): """Test checking scheduler plugin when no workers exist""" - + class Plugin(SchedulerPlugin): name = "plugin" - + def start(self, scheduler): scheduler.no_worker_test = True - + # Check before registration - assert not await c.has_plugin("plugin") - + assert not await c.has_plugin("plugin") + # Register plugin when no workers exist await c.register_plugin(Plugin()) assert s.no_worker_test is True - + # Check after registration - assert await c.has_plugin("plugin") + assert await c.has_plugin("plugin") @gen_cluster(client=True) async def test_has_scheduler_plugin_custom_name_override(c, s, a, b): """Test scheduler plugin registered with custom name different from class name""" - + class Dummy3(SchedulerPlugin): name = "Dummy3" - + def start(self, scheduler): scheduler.name_test = "custom" - + plugin = Dummy3() - + # Register with custom name (overriding the class name attribute) await c.register_plugin(plugin, name="custom-override") - + # Check with custom name works - assert await c.has_plugin("custom-override") - + assert await c.has_plugin("custom-override") + # Original name won't work since we overrode it assert not await c.has_plugin("Dummy3") diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 05e2295bc7..83eac3bbb9 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -480,26 +480,27 @@ async def test_plugin_with_broken_teardown_logs_on_close(c, s): assert "TestPlugin1 failed to teardown" in logs assert "test error" in logs + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_has_worker_plugin_by_name(c, s, a): """Test checking if worker plugin is registered using string name""" - + class MyPlugin(WorkerPlugin): name = "MyPlugin" - + def __init__(self, data, expected_notifications=None): self.data = data self.expected_notifications = expected_notifications - + # Check non-existent plugin assert not await c.has_plugin("MyPlugin") # ← await - + # Register plugin await c.register_plugin(MyPlugin(123, None)) - + # Check using string name assert await c.has_plugin("MyPlugin") # ← await - + # Unregister and check again await c.unregister_worker_plugin("MyPlugin") assert not await c.has_plugin("MyPlugin") # ← await @@ -509,14 +510,14 @@ def __init__(self, data, expected_notifications=None): async def test_has_worker_plugin_by_object(c, s, a): """Test checking if worker plugin is registered using plugin object""" plugin = MyPlugin(456) - + # Check before registration assert not await c.has_plugin(plugin) # ← await - + # Register and check await c.register_plugin(plugin) assert await c.has_plugin(plugin) # ← await - + # Unregister and check await c.unregister_worker_plugin("MyPlugin") assert not await c.has_plugin(plugin) # ← await @@ -526,12 +527,12 @@ async def test_has_worker_plugin_by_object(c, s, a): async def test_has_plugin_list(c, s, a): """Test checking multiple plugins at once""" plugin1 = MyPlugin(1) - + class AnotherPlugin(WorkerPlugin): name = "AnotherPlugin" - + plugin2 = AnotherPlugin() - + # Check multiple plugins before registration result = await c.has_plugin(["MyPlugin", "AnotherPlugin", "NonExistent"]) # ← await assert result == { @@ -539,17 +540,17 @@ class AnotherPlugin(WorkerPlugin): "AnotherPlugin": False, "NonExistent": False, } - + # Register first plugin await c.register_plugin(plugin1) result = await c.has_plugin(["MyPlugin", "AnotherPlugin"]) # ← await assert result == {"MyPlugin": True, "AnotherPlugin": False} - + # Register second plugin await c.register_plugin(plugin2) result = await c.has_plugin(["MyPlugin", "AnotherPlugin"]) # ← await assert result == {"MyPlugin": True, "AnotherPlugin": True} - + # Can also pass list of objects result = await c.has_plugin([plugin1, plugin2]) # ← await assert result == {"MyPlugin": True, "AnotherPlugin": True} @@ -558,12 +559,12 @@ class AnotherPlugin(WorkerPlugin): @gen_cluster(client=True, nthreads=[("", 1)]) async def test_has_plugin_without_name_attribute(c, s, a): """Test error when plugin has no name attribute""" - + class PluginWithoutName(WorkerPlugin): pass # No name attribute - + plugin = PluginWithoutName() - + # Should raise error when checking with pytest.raises(ValueError, match="has no 'name' attribute"): await c.has_plugin(plugin) # ← await @@ -573,12 +574,12 @@ class PluginWithoutName(WorkerPlugin): async def test_has_plugin_custom_name(c, s, a): """Test plugin registered with custom name""" plugin = MyPlugin(789) - + # Register with custom name await c.register_plugin(plugin, name="custom-name") - + # Check with custom name assert await c.has_plugin("custom-name") # ← await - + # Original name won't work - assert not await c.has_plugin("MyPlugin") # ← await \ No newline at end of file + assert not await c.has_plugin("MyPlugin") # ← await