diff --git a/README.rst b/README.rst index 4a1a17d..9958666 100644 --- a/README.rst +++ b/README.rst @@ -26,10 +26,13 @@ An example for when subscribing to the on_member_kick event. import discord from discord.ext import commands, events - from discord.ext.events import member_kick + from discord.ext.events import EventsMixin, CustomEventDispatcher class MyBot(commands.Bot, events.EventsMixin): + dispatcher = CustomEventDispatcher([ + 'member_kick', + ]) async def on_ready(self): print('Logged in!') diff --git a/discord/ext/events/__init__.py b/discord/ext/events/__init__.py index 39f3c55..39ab3ab 100644 --- a/discord/ext/events/__init__.py +++ b/discord/ext/events/__init__.py @@ -17,6 +17,7 @@ from collections import namedtuple from .mixins import EventsMixin +from .dispatcher import CustomEventDispatcher from . import utils diff --git a/discord/ext/events/_events.py b/discord/ext/events/_events.py deleted file mode 100644 index 4329ba9..0000000 --- a/discord/ext/events/_events.py +++ /dev/null @@ -1,5 +0,0 @@ -from typing import Callable, Dict - -_ALL: Dict[str, Callable] = { - # This is populated by subscribed events at runtime -} diff --git a/discord/ext/events/custom_events/__init__.py b/discord/ext/events/custom_events/__init__.py new file mode 100644 index 0000000..1bd3e9d --- /dev/null +++ b/discord/ext/events/custom_events/__init__.py @@ -0,0 +1,5 @@ +from typing import Callable, Dict + +# Filled at run time +_ALL: Dict[str, Callable] = {} + diff --git a/discord/ext/events/member_kick.py b/discord/ext/events/custom_events/member_kick.py similarity index 63% rename from discord/ext/events/member_kick.py rename to discord/ext/events/custom_events/member_kick.py index 966e1c3..034c4f2 100644 --- a/discord/ext/events/member_kick.py +++ b/discord/ext/events/custom_events/member_kick.py @@ -1,18 +1,12 @@ import discord -from ._events import _ALL -from .utils import fetch_recent_audit_log_entry, listens_for - - -EVENT_NAME = 'member_kick' +from ..utils import fetch_recent_audit_log_entry, listens_for @listens_for('member_remove') async def check_member_kick(client: discord.Client, member: discord.Member): guild = member.guild - print('!RAN') - if not guild.me.guild_permissions.view_audit_log: return @@ -20,7 +14,4 @@ async def check_member_kick(client: discord.Client, member: discord.Member): if entry is None: return - client.dispatch(EVENT_NAME, member, entry) - - -_ALL[EVENT_NAME] = check_member_kick + return member, entry diff --git a/discord/ext/events/dispatcher.py b/discord/ext/events/dispatcher.py new file mode 100644 index 0000000..af7711c --- /dev/null +++ b/discord/ext/events/dispatcher.py @@ -0,0 +1,27 @@ +import asyncio +from typing import List, Optional + +import discord + +from .custom_events import _ALL +from .errors import InvalidEvent + + +class CustomEventDispatcher: + def __init__(self, listening_to: Optional[List[str]]=None): + valid_handlers = _ALL + + if listening_to: + try: + valid_handlers = {name: _ALL[name] for name in listening_to} + except KeyError as e: + raise InvalidEvent('no registered handler for {!r}'.format(e.args[0])) + + self.valid_handlers = valid_handlers + + def handle(self, client: discord.Client, event: str, *args, **kwargs): + if event in self.valid_handlers: + return + + for event_check in self.valid_handlers.values(): + asyncio.ensure_future(event_check(client, event, *args, **kwargs)) diff --git a/discord/ext/events/errors.py b/discord/ext/events/errors.py new file mode 100644 index 0000000..6d21b69 --- /dev/null +++ b/discord/ext/events/errors.py @@ -0,0 +1,7 @@ +from discord.errors import DiscordException + +class EventsException(DiscordException): + pass + +class InvalidEvent(EventsException): + pass diff --git a/discord/ext/events/mixins.py b/discord/ext/events/mixins.py index 4c63f27..785ac0d 100644 --- a/discord/ext/events/mixins.py +++ b/discord/ext/events/mixins.py @@ -1,16 +1,11 @@ -import asyncio - import discord -from ._events import _ALL +from .dispatcher import CustomEventDispatcher class EventsMixin(discord.Client): - - async def on__event(self, event, *args, **kwargs): - for event_name, event_check in _ALL.items(): - asyncio.ensure_future(event_check(self, event, *args, **kwargs)) + dispatcher = CustomEventDispatcher() def dispatch(self, event, *args, **kwargs): super().dispatch(event, *args, **kwargs) # type: ignore - super().dispatch('_event', event, *args, **kwargs) + self.dispatcher.handle(self, event, *args, **kwargs) diff --git a/discord/ext/events/utils.py b/discord/ext/events/utils.py index 687cdb2..709658b 100644 --- a/discord/ext/events/utils.py +++ b/discord/ext/events/utils.py @@ -1,11 +1,12 @@ import asyncio import datetime - from functools import wraps from typing import Callable, Optional import discord +from .custom_events import _ALL + SLEEP_FOR = 2.5 @@ -68,11 +69,17 @@ def listens_for(*events: str) -> Callable: """ def decorator(func: Callable) -> Callable: + event_name = func.__name__[6:] + _ALL[event_name] = func @wraps(func) async def wrapper(client, event, *args, **kwargs): if event in events: - await func(client, *args, **kwargs) + result = await func(client, *args, **kwargs) + if result is not None: + if not isinstance(result, tuple): + result = (result,) + client.dispatch(event_name, *result) return wrapper diff --git a/setup.py b/setup.py index 15161d9..e35b1e8 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ python_requires=">=3.5.3", url="https://github.com/Ext-Creators/discord-ext-events", version=version, - packages=["discord.ext.events"], + packages=["discord.ext.events", "discord.ext.events.custom_events"], license="Apache Software License", description="Custom events derived from events dispatched by Discord. ", long_description=readme,