diff --git a/cogs/poll_controls.py b/cogs/poll_controls.py index 283fd9e..2605724 100644 --- a/cogs/poll_controls.py +++ b/cogs/poll_controls.py @@ -637,6 +637,7 @@ async def route(poll): await poll.set_options_reaction(ctx) await poll.set_survey_flags(ctx) await poll.set_multiple_choice(ctx) + await poll.set_ranked_choice_voting(ctx) await poll.set_hide_vote_count(ctx) await poll.set_roles(ctx) await poll.set_weights(ctx) @@ -662,6 +663,7 @@ async def route(poll): await poll.set_options_reaction(ctx) await poll.set_survey_flags(ctx) await poll.set_multiple_choice(ctx) + await poll.set_ranked_choice_voting(ctx) await poll.set_hide_vote_count(ctx) await poll.set_roles(ctx) await poll.set_weights(ctx) @@ -687,6 +689,7 @@ async def route(poll): await poll.set_options_reaction(ctx) await poll.set_survey_flags(ctx, force='0') await poll.set_multiple_choice(ctx) + await poll.set_ranked_choice_voting(ctx) await poll.set_hide_vote_count(ctx, force='no') await poll.set_roles(ctx, force='all') await poll.set_weights(ctx, force='none') diff --git a/models/poll.py b/models/poll.py index ef412fb..9db2bf0 100644 --- a/models/poll.py +++ b/models/poll.py @@ -88,6 +88,8 @@ def __init__(self, bot, ctx=None, server=None, channel=None, load=False): self.votes = {} self.wizard_messages = [] + + self.rcv = False @staticmethod def get_preset_options(number): @@ -437,6 +439,56 @@ async def get_valid(in_reply): except OutOfRange: await self.add_error(message, '**You can\'t have more choices than options.**') + async def set_ranked_choice_voting(self, ctx, force=None): + """Determine if poll is ranked choice voting.""" + async def get_valid(in_reply): + if self.multiple_choice == 1: + return False + if not in_reply: + raise InvalidInput + is_true = ['yes', '1'] + is_false = ['no', '0'] + in_reply = self.sanitize_string(in_reply) + if not in_reply: + raise InvalidInput + elif in_reply.lower() in is_true: + return True + elif in_reply.lower() in is_false: + return False + else: + raise InvalidInput + + try: + self.rcv = await get_valid(force) + return + except InputError: + pass + + text = ("Next you need to decide: **Do you want your poll to be ranked-choice voting?**\n" + "\n" + "`0 - No`\n" + "`1 - Yes`\n" + "\n" + "🔶 An ranked-choice voting has the following effects:\n" + "🔶 After chosing multiple choices, only first-preference votes are counted\n" + "🔶 If none of the choices is voted by a majority, the choice with the fewest first-preference votes is eliminated.\n" + "🔶 All first-preference votes for the failed choice are eliminated, lifting the second-preference choices indicated by those users.\n" + "🔶 The process repeats until choice wins a majority of votes") + message = await self.wizard_says(ctx, text) + + while True: + try: + if force: + reply = force + force = None + else: + reply = await self.get_user_reply(ctx) + self.rcv = await get_valid(reply) + await self.add_vaild(message, f'{"Yes" if self.rcv else "No"}') + break + except InvalidInput: + await self.add_error(message, '**You can only answer with `yes` | `1` or `no` | `0`!**') + async def set_options_reaction(self, ctx, force=None): """Set the answers / options of the Poll.""" async def get_valid(in_reply): @@ -929,6 +981,7 @@ async def to_dict(self): 'multiple_choice': self.multiple_choice, 'options_reaction': self.options_reaction, 'reaction_default': self.options_reaction_default, + 'rcv': self.rcv, #'options_traditional': self.options_traditional, 'survey_flags': self.survey_flags, 'roles': self.roles, @@ -948,7 +1001,10 @@ async def to_export(self): """Create report and return string""" # load all votes from database await self.load_full_votes() - await self.load_vote_counts() + if self.rcv: + await self.load_vote_counts_rcv() + else: + await self.load_vote_counts() await self.load_unique_participants() # build string for weights weight_str = 'No weights' @@ -1131,6 +1187,7 @@ async def from_dict(self, d): self.name = d['name'] self.short = d['short'] self.anonymous = d['anonymous'] + self.rcv = d['rcv'] # backwards compatibility if 'hide_count' in d.keys(): @@ -1220,6 +1277,29 @@ async def load_vote_counts(self): self.vote_counts_weighted[v.choice] = self.vote_counts_weighted.get(v.choice, 0) + v.weight else: self.vote_counts_weighted = self.vote_counts + + async def load_vote_counts_rcv(self): + if not self.vote_counts: + self.vote_counts = await Vote.load_vote_counts_for_poll(self.bot, self.id) + user_counts = await Vote.load_votes_for_rcv_poll(self.bot, self.id) + while user_counts: + weights_count = 0 + for votes in user_counts.values(): + if len(votes) > 0: + weights_count += votes[0]['weight'] + vote_counts = {} + for votes in user_counts.values(): + if len(votes) > 0: + vote = votes[0] + vote_counts[vote['choice']] = vote_counts.get(vote['choice'], 0) + vote['weight'] + if any(count > weights_count / 2 for count in vote_counts.values()): + self.vote_counts_weighted = vote_counts + break + eliminated = min(vote_counts.values()) + for i in user_counts.keys(): + for j in list(user_counts[i]): + if vote_counts.get(j['choice']) == eliminated: + user_counts[i].remove(j) async def load_full_votes(self): if not self.full_votes: @@ -1281,7 +1361,10 @@ async def generate_embed(self): embed = self.add_field_custom(name='**Deadline**', value=await self.get_poll_status(), embed=embed) # embed = self.add_field_custom(name='**Author**', value=self.author.name, embed=embed) - await self.load_vote_counts() + if self.rcv: + await self.load_vote_counts_rcv() + else: + await self.load_vote_counts() if self.options_reaction_default: if await self.is_open(): text = f'**Score** ' diff --git a/models/vote.py b/models/vote.py index 2d86ba8..3900581 100644 --- a/models/vote.py +++ b/models/vote.py @@ -53,6 +53,18 @@ async def load_vote_counts_for_poll(bot, poll_id: ObjectId,): result[q['_id']] = q['count'] return result + @staticmethod + async def load_votes_for_rcv_poll(bot, poll_id: ObjectId,): + pipeline = [ + {"$match": {'poll_id': poll_id}}, + {"$group": {"_id": "$user_id", "choice": {"$push": {"choice": "$choice", "weight": "$weight"}}}} + ] + query = bot.db.votes.aggregate(pipeline) + result = {} + async for q in query: + result[q['_id']] = q['choice'] + return result + @staticmethod async def load_votes_for_poll_and_user(bot, poll_id: ObjectId, user_id): user_id = str(user_id) diff --git a/requirements.txt b/requirements.txt index ef8605a..d16c0b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ aiohttp async-timeout -asyncpg==0.21.0 +asyncpg==0.26.0 attrs==20.2.0 certifi==2020.6.20 chardet==3.0.4 cycler==0.10.0 dateparser==0.7.4 dblpy -discord.py==1.5.1 +discord.py==1.7.3 idna==2.10 idna-ssl==1.1.0 kiwisolver==1.3.0