diff --git a/bot.py b/bot.py index 7ab3b9a..7d884c9 100644 --- a/bot.py +++ b/bot.py @@ -9,7 +9,6 @@ import datetime import logging import logging.handlers -import re import subprocess import traceback import uuid @@ -17,6 +16,7 @@ import aiohttp import discord +import re2 from beanie import init_beanie from discord import app_commands from discord.ext import commands @@ -285,8 +285,8 @@ async def on_message(self, message: discord.Message) -> None: spam: commands.Cog | SpamManager = self.get_cog("SpamManager") await spam.store_and_validate(message) - legacy_command: list[str] = re.findall( - rf"^{re.escape(BOT_PREFIX)}\s*(\w+)", + legacy_command: list[str] = re2.findall( + rf"^{re2.escape(BOT_PREFIX)}\s*(\w+)", message.content, ) if message.content and len(legacy_command): diff --git a/migrations/20250819210945_re2_pings.py b/migrations/20250819210945_re2_pings.py new file mode 100644 index 0000000..a0aa004 --- /dev/null +++ b/migrations/20250819210945_re2_pings.py @@ -0,0 +1,51 @@ +from typing import Annotated + +from beanie import Document, Indexed, iterative_migration +from pydantic import BaseModel + + +class PhrasePing(BaseModel): + phrase: str + is_re2: bool + + +class NewPing(Document): + user_id: Annotated[int, Indexed()] + pings: list[PhrasePing] + dnd: bool + + class Settings: + name = "pings" + + +class OldPing(Document): + user_id: Annotated[int, Indexed()] + word_pings: list[str] + dnd: bool + + class Settings: + name = "pings" + + +class Forward: + @iterative_migration() + async def str_to_word_ping_model( + self, + input_document: OldPing, + output_document: NewPing, + ): + output_document.pings = [ + PhrasePing(phrase=word, is_re2=False) for word in input_document.word_pings + ] + + +class Backward: + @iterative_migration() + async def str_to_word_ping_model( + self, + input_document: NewPing, + output_document: OldPing, + ): + output_document.word_pings = [ + ping.phrase for ping in input_document.pings if not ping.is_re2 + ] diff --git a/requirements.txt b/requirements.txt index a4dc083..d1560c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,4 @@ rich==13.7.0 pydantic<2.7,>=2.6 pydantic-settings==2.2.1 beanie==1.26.0 +google-re2==1.1.20250805 diff --git a/src/discord/censor.py b/src/discord/censor.py index ee0732b..3a206ac 100644 --- a/src/discord/censor.py +++ b/src/discord/censor.py @@ -2,15 +2,16 @@ Contains all functionality related to censoring users' actions in the Scioly.org Discord server. """ + from __future__ import annotations import asyncio import contextlib import logging -import re from typing import TYPE_CHECKING import discord +import re2 from discord.ext import commands import src.discord.globals @@ -93,10 +94,12 @@ async def on_message(self, message: discord.Message) -> None: def word_present(self, content: str) -> bool: with contextlib.suppress(asyncio.CancelledError): for word in src.discord.globals.CENSOR.words: - if re.findall(rf"\b({word})\b", content, re.I): + flags = re2.Options() + flags.case_sensitive = False + if re2.findall(rf"\b({word})\b", content, flags): return True for emoji in src.discord.globals.CENSOR.emojis: - if len(re.findall(emoji, content)): + if len(re2.findall(emoji, content)): return True return False @@ -119,11 +122,13 @@ def discord_invite_censor_needed(self, content: str) -> bool: Determines whether the Discord invite link censor is needed. In other words, whether this content contains a Discord invite link. """ + flags = re2.Options() + flags.case_sensitive = False if not any( ending for ending in DISCORD_INVITE_ENDINGS if ending in content ) and ( - len(re.findall("discord.gg", content, re.I)) > 0 - or len(re.findall("discord.com/invite", content, re.I)) > 0 + len(re2.findall("discord.gg", content, flags)) > 0 + or len(re2.findall("discord.com/invite", content, flags)) > 0 ): return True return False @@ -141,15 +146,17 @@ async def __censor(self, message: discord.Message): author = message.author.nick or message.author.name # Actually replace content found on the censored words/emojis list + flags = re2.Options() + flags.case_sensitive = False for word in src.discord.globals.CENSOR.words: - content = re.sub( + content = re2.sub( rf"\b({word})\b", "", content, - flags=re.IGNORECASE, + options=flags, ) for emoji in src.discord.globals.CENSOR.emojis: - content = re.sub(emoji, "", content, flags=re.I) + content = re2.sub(emoji, "", content, options=flags) reply = ( (message.reference.resolved or message.reference.cached_message) diff --git a/src/discord/embed.py b/src/discord/embed.py index 2f3b82c..0141e17 100644 --- a/src/discord/embed.py +++ b/src/discord/embed.py @@ -7,10 +7,10 @@ import asyncio import contextlib import json -import re from typing import TYPE_CHECKING, Any import discord +import re2 import webcolors from discord import app_commands from discord.ext import commands @@ -480,7 +480,7 @@ async def callback(self, interaction: discord.Interaction) -> None: # If the user is attempting to update the color of the embed, but doesn't # pass a color, deny them if self.update_value == "color" and not ( - re.findall(r"#[0-9a-f]{6}", response_message.content.lower()) + re2.findall(r"#[0-9a-f]{6}", response_message.content.lower()) ): help_message = await self.embed_view.channel.send( "The color you provide must be a hex code. For example, `#abbb02` " diff --git a/src/discord/membercommands.py b/src/discord/membercommands.py index 2614e31..315cbe2 100644 --- a/src/discord/membercommands.py +++ b/src/discord/membercommands.py @@ -6,10 +6,10 @@ import datetime import random -import re from typing import TYPE_CHECKING, Literal import discord +import re2 import wikipedia as wikip from aioify import aioify from discord import app_commands @@ -164,10 +164,12 @@ async def profile( text = text.decode("utf-8") description = "" - total_posts_matches = re.search( + flags = re2.Options() + flags.one_line = False + total_posts_matches = re2.search( r"(?:
Total posts:<\/dt>\s+
)(\d+)", text, - re.MULTILINE, + flags, ) if total_posts_matches is None: return await interaction.response.send_message( @@ -176,13 +178,13 @@ async def profile( else: description += f"**Total Posts:** `{total_posts_matches.group(1)} posts`\n" - has_thanked_matches = re.search(r"Has thanked: (\d+)", text, re.MULTILINE) + has_thanked_matches = re2.search(r"Has thanked: (\d+)", text, flags) description += f"**Has Thanked:** `{has_thanked_matches.group(1)} times`\n" - been_thanked_matches = re.search( + been_thanked_matches = re2.search( r"Been(?: )?thanked: (\d+)", text, - re.MULTILINE, + flags, ) description += f"**Been Thanked:** `{been_thanked_matches.group(1)} times`\n" @@ -192,7 +194,7 @@ async def profile( ] for pattern in date_regexes: try: - matches = re.search(pattern["regex"], text, re.MULTILINE) + matches = re2.search(pattern["regex"], text, flags) raw_dt_string = matches.group(1) raw_dt_string = raw_dt_string.replace("st", "") raw_dt_string = raw_dt_string.replace("nd", "") @@ -211,18 +213,18 @@ async def profile( pass for i in range(1, 7): - stars_matches = re.search( + stars_matches = re2.search( rf" {rule}") diff --git a/src/discord/ping.py b/src/discord/ping.py index acd86b3..8f04046 100644 --- a/src/discord/ping.py +++ b/src/discord/ping.py @@ -9,10 +9,10 @@ import contextlib import datetime import logging -import re from typing import TYPE_CHECKING import discord +import re2 from discord import app_commands from discord.app_commands import AppCommandContext from discord.ext import commands @@ -106,7 +106,9 @@ async def on_message(self, message: discord.Message): pings = [rf"\b({ping})\b" for ping in user_pings.word_pings] for ping in pings: try: - if len(re.findall(ping, message.content, re.I)): + flags = re2.Options() + flags.case_sensitive = False + if len(re2.findall(ping, message.content, flags)): ping_count += 1 except Exception as e: logger.error( @@ -148,7 +150,9 @@ def format_text( for expression in pings: try: - text = re.sub(rf"{expression}", r"**\1**", text, flags=re.I) + flags = re2.Options() + flags.case_sensitive = False + text = re2.sub(rf"{expression}", r"**\1**", text, options=flags) except Exception as e: logger.warn(f"Could not bold ping due to unfavored RegEx. Error: {e}") @@ -300,7 +304,9 @@ async def pingadd(self, interaction: discord.Interaction, word: str): # User already has an object in the PING_INFO dictionary pings = user.word_pings try: - re.findall(word, "test phrase") + flags = re2.Options() + flags.case_sensitive = False + re2.compile(word, flags) except Exception: return await interaction.response.send_message( f"Ignoring adding the `{word}` ping because it uses illegal characters.", @@ -310,7 +316,7 @@ async def pingadd(self, interaction: discord.Interaction, word: str): f"Ignoring adding the `{word}` ping because you already have a ping currently set as that.", ) else: - logger.debug(f"adding word: {re.escape(word)}") + logger.debug(f"adding word: {re2.escape(word)}") # relevant_doc = next( # doc # for doc in src.discord.globals.PING_INFO @@ -379,12 +385,14 @@ async def pingtest(self, interaction: discord.Interaction, test: str): response = "" for ping in user_pings: + flags = re2.Options() + flags.case_sensitive = False if isinstance(ping, dict): - if len(re.findall(ping["new"], test, re.I)) > 0: + if len(re2.findall(ping["new"], test, flags)) > 0: response += f"Your ping `{ping['original']}` matches `{test}`.\n" matched = True else: - if len(re.findall(ping, test, re.I)) > 0: + if len(re2.findall(ping, test, flags)) > 0: response += f"Your ping `{ping}` matches `{test}`.\n" matched = True diff --git a/src/discord/staff/events.py b/src/discord/staff/events.py index 3dc5079..f3b5c1d 100644 --- a/src/discord/staff/events.py +++ b/src/discord/staff/events.py @@ -1,9 +1,9 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING, Literal import discord +import re2 from discord import app_commands from discord.ext import commands @@ -65,7 +65,7 @@ async def event_add( # and local storage aliases_array = [] if event_aliases: - aliases_array = re.findall(r"\w+", event_aliases) + aliases_array = re2.findall(r"\w+", event_aliases) new_dict = Event(name=event_name, aliases=aliases_array, emoji=None) # Add dict into events container diff --git a/src/discord/staff/invitationals.py b/src/discord/staff/invitationals.py index e49dbd0..71d8063 100644 --- a/src/discord/staff/invitationals.py +++ b/src/discord/staff/invitationals.py @@ -1,10 +1,10 @@ from __future__ import annotations import datetime -import re from typing import TYPE_CHECKING, Literal import discord +import re2 from beanie.odm.operators.update.general import Inc from discord import Emoji, Guild, app_commands from discord.ext import commands @@ -534,7 +534,7 @@ async def invitational_delete( await r.delete() # Delete the invitational emoji - search = re.findall(r"<:.*:\d+>", invitational.emoji) + search = re2.findall(r"<:.*:\d+>", invitational.emoji) if len(search): emoji = self.bot.get_emoji(search[0]) if emoji: diff --git a/src/discord/staffcommands.py b/src/discord/staffcommands.py index 0c0a5f8..4b45b83 100644 --- a/src/discord/staffcommands.py +++ b/src/discord/staffcommands.py @@ -4,11 +4,11 @@ import contextlib import datetime import logging -import re from typing import TYPE_CHECKING, Literal import discord import matplotlib.pyplot as plt +import re2 from beanie.odm.operators.update.general import Set from discord import app_commands from discord.ext import commands @@ -306,8 +306,8 @@ def __init__(self, docs: list[Cron], bot: PiBot): async def callback(self, interaction: discord.Interaction): value = self.values[0] - num = re.findall(r"\(#(\d*)", value) - value = re.sub(r" \(#\d*\)", "", value) + num = re2.findall(r"\(#(\d*)", value) + value = re2.sub(r" \(#\d*\)", "", value) relevant_doc = [ d for d in self.docs if f"{d.cron_type.title()} {d.tag}" == value ] diff --git a/src/mongo/models.py b/src/mongo/models.py index 000f099..801b2d3 100644 --- a/src/mongo/models.py +++ b/src/mongo/models.py @@ -25,9 +25,14 @@ class Settings: use_cache = False +class PhrasePing(BaseModel): + phrase: str + is_re2: bool + + class Ping(Document): user_id: Annotated[int, Indexed()] - word_pings: list[str] + pings: list[PhrasePing] dnd: bool class Settings: diff --git a/src/wiki/wiki.py b/src/wiki/wiki.py index 24c7c20..6e74c3b 100644 --- a/src/wiki/wiki.py +++ b/src/wiki/wiki.py @@ -1,8 +1,8 @@ import asyncio import logging -import re import pywikibot +import re2 import wikitextparser as wtp from aioify import aioify @@ -84,7 +84,7 @@ async def implement_command(action, page_title): pt = wtp.parse(rf"{text}").plain_text() title = await page.title() link = site.base_url(site.article_path + title.replace(" ", "_")) - return re.split(r"(?!", ]