diff --git a/src/plugins/nonebot_bison/config.py b/src/plugins/nonebot_bison/config.py index 770b090..0712fcb 100644 --- a/src/plugins/nonebot_bison/config.py +++ b/src/plugins/nonebot_bison/config.py @@ -1,7 +1,7 @@ from collections import defaultdict from os import path import os -from typing import DefaultDict, Mapping +from typing import DefaultDict, Literal, Mapping, TypedDict import nonebot from nonebot import logger @@ -34,6 +34,18 @@ class NoSuchUserException(Exception): class NoSuchSubscribeException(Exception): pass +class SubscribeContent(TypedDict): + target: str + target_type: str + target_name: str + cats: list[int] + tags: list[str] + +class ConfigContent(TypedDict): + user: str + user_type: Literal["group", "private"] + subs: list[SubscribeContent] + class Config(metaclass=Singleton): migrate_version = 2 @@ -64,7 +76,7 @@ class Config(metaclass=Singleton): }) self.update_send_cache() - def list_subscribe(self, user, user_type): + def list_subscribe(self, user, user_type) -> list[SubscribeContent]: query = Query() if user_sub := self.user_target.get((query.user == user) & (query.user_type ==user_type)): return user_sub['subs'] diff --git a/src/plugins/nonebot_bison/platform/__init__.py b/src/plugins/nonebot_bison/platform/__init__.py index bc66380..028f002 100644 --- a/src/plugins/nonebot_bison/platform/__init__.py +++ b/src/plugins/nonebot_bison/platform/__init__.py @@ -14,7 +14,7 @@ async def check_sub_target(target_type, target): return await platform_manager[target_type].get_target_name(target) _platform_list = defaultdict(list) -for _platform in Platform.registory: +for _platform in Platform.registry: if not _platform.enabled: continue _platform_list[_platform.platform_name].append(_platform) diff --git a/src/plugins/nonebot_bison/platform/arknights.py b/src/plugins/nonebot_bison/platform/arknights.py index 6015a02..fb25729 100644 --- a/src/plugins/nonebot_bison/platform/arknights.py +++ b/src/plugins/nonebot_bison/platform/arknights.py @@ -7,10 +7,10 @@ import httpx from ..post import Post from ..types import Category, RawPost, Target from ..utils import Render -from .platform import CategoryNotSupport, NewMessage, NoTargetMixin, StatusChange +from .platform import CategoryNotSupport, NewMessage, StatusChange -class Arknights(NewMessage, NoTargetMixin): +class Arknights(NewMessage): categories = {1: '游戏公告'} platform_name = 'arknights' @@ -20,6 +20,7 @@ class Arknights(NewMessage, NoTargetMixin): is_common = False schedule_type = 'interval' schedule_kw = {'seconds': 30} + has_target = False async def get_target_name(self, _: Target) -> str: return '明日方舟游戏信息' @@ -60,7 +61,7 @@ class Arknights(NewMessage, NoTargetMixin): raise CategoryNotSupport() return Post('arknights', text=text, url='', target_name="明日方舟游戏内公告", pics=pics, compress=True, override_use_pic=False) -class AkVersion(NoTargetMixin, StatusChange): +class AkVersion(StatusChange): categories = {2: '更新信息'} platform_name = 'arknights' @@ -70,6 +71,7 @@ class AkVersion(NoTargetMixin, StatusChange): is_common = False schedule_type = 'interval' schedule_kw = {'seconds': 30} + has_target = False async def get_target_name(self, _: Target) -> str: return '明日方舟游戏信息' @@ -104,7 +106,7 @@ class AkVersion(NoTargetMixin, StatusChange): async def parse(self, raw_post): return raw_post -class MonsterSiren(NewMessage, NoTargetMixin): +class MonsterSiren(NewMessage): categories = {3: '塞壬唱片新闻'} platform_name = 'arknights' @@ -114,6 +116,7 @@ class MonsterSiren(NewMessage, NoTargetMixin): is_common = False schedule_type = 'interval' schedule_kw = {'seconds': 30} + has_target = False async def get_target_name(self, _: Target) -> str: return '明日方舟游戏信息' diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index b46d364..8980eb1 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -5,9 +5,9 @@ import httpx from ..post import Post from ..types import Category, RawPost, Tag, Target -from .platform import NewMessage, TargetMixin, CategoryNotSupport +from .platform import NewMessage, CategoryNotSupport -class Bilibili(NewMessage, TargetMixin): +class Bilibili(NewMessage): categories = { 1: "一般动态", @@ -24,6 +24,7 @@ class Bilibili(NewMessage, TargetMixin): schedule_type = 'interval' schedule_kw = {'seconds': 10} name = 'B站' + has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: diff --git a/src/plugins/nonebot_bison/platform/ncm_artist.py b/src/plugins/nonebot_bison/platform/ncm_artist.py index 28664c9..e230f65 100644 --- a/src/plugins/nonebot_bison/platform/ncm_artist.py +++ b/src/plugins/nonebot_bison/platform/ncm_artist.py @@ -3,9 +3,9 @@ from typing import Any, Optional import httpx from ..post import Post from ..types import RawPost, Target -from .platform import TargetMixin, NewMessage +from .platform import NewMessage -class NcmArtist(TargetMixin, NewMessage): +class NcmArtist(NewMessage): categories = {} platform_name = 'ncm-artist' @@ -15,6 +15,7 @@ class NcmArtist(TargetMixin, NewMessage): schedule_type = 'interval' schedule_kw = {'minutes': 1} name = "网易云-歌手" + has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: diff --git a/src/plugins/nonebot_bison/platform/ncm_radio.py b/src/plugins/nonebot_bison/platform/ncm_radio.py index 246cf05..6fae725 100644 --- a/src/plugins/nonebot_bison/platform/ncm_radio.py +++ b/src/plugins/nonebot_bison/platform/ncm_radio.py @@ -3,9 +3,9 @@ from typing import Any, Optional import httpx from ..post import Post from ..types import RawPost, Target -from .platform import TargetMixin, NewMessage +from .platform import NewMessage -class NcmRadio(TargetMixin, NewMessage): +class NcmRadio(NewMessage): categories = {} platform_name = 'ncm-radio' @@ -15,6 +15,7 @@ class NcmRadio(TargetMixin, NewMessage): schedule_type = 'interval' schedule_kw = {'minutes': 10} name = "网易云-电台" + has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index b8afa72..19dd5de 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -24,35 +24,55 @@ class RegistryMeta(type): def __init__(cls, name, bases, namespace, **kwargs): if kwargs.get('base'): # this is the base class - cls.registory = [] + cls.registry = [] elif not kwargs.get('abstract'): # this is the subclass - cls.registory.append(cls) + cls.registry.append(cls) super().__init__(name, bases, namespace, **kwargs) class RegistryABCMeta(RegistryMeta, ABC): ... -class StorageMixinProto(metaclass=RegistryABCMeta, abstract=True): +class Platform(metaclass=RegistryABCMeta, base=True): + schedule_type: Literal['date', 'interval', 'cron'] + schedule_kw: dict + is_common: bool + enabled: bool + name: str has_target: bool + categories: dict[Category, str] + enable_tag: bool + store: dict[Target, Any] + platform_name: str @abstractmethod - def get_stored_data(self, target: Target) -> Any: + async def get_target_name(self, target: Target) -> Optional[str]: ... @abstractmethod - def set_stored_data(self, target: Target, data: Any): + async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: ... -class TargetMixin(StorageMixinProto, abstract=True): + @abstractmethod + async def parse(self, raw_post: RawPost) -> Post: + ... - has_target = True + async def do_parse(self, raw_post: RawPost) -> Post: + "actually function called" + return await self.parse(raw_post) def __init__(self): super().__init__() - self.store: dict[Target, Any] = dict() + self.reverse_category = {} + for key, val in self.categories.items(): + self.reverse_category[val] = key + self.store = dict() + + @abstractmethod + def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]: + "Return Tag list of given RawPost" def get_stored_data(self, target: Target) -> Any: return self.store.get(target) @@ -60,39 +80,43 @@ class TargetMixin(StorageMixinProto, abstract=True): def set_stored_data(self, target: Target, data: Any): self.store[target] = data + async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: + res: list[RawPost] = [] + for raw_post in raw_post_list: + if self.categories: + cat = self.get_category(raw_post) + if cats and cat not in cats: + continue + if self.enable_tag and tags: + flag = False + post_tags = self.get_tags(raw_post) + for tag in post_tags or []: + if tag in tags: + flag = True + break + if not flag: + continue + res.append(raw_post) + return res -class NoTargetMixin(StorageMixinProto, abstract=True): - - has_target = False - - def __init__(self): - super().__init__() - self.store = None - - def get_stored_data(self, _: Target) -> Any: - return self.store - - def set_stored_data(self, _: Target, data: Any): - self.store = data - -class PlatformNameMixin(metaclass=RegistryABCMeta, abstract=True): - platform_name: str - -class CategoryMixin(metaclass=RegistryABCMeta, abstract=True): + async def dispatch_user_post(self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + res: list[tuple[User, list[Post]]] = [] + for user, category_getter, tag_getter in users: + required_tags = tag_getter(target) if self.enable_tag else [] + cats = category_getter(target) + user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) + user_post: list[Post] = [] + for raw_post in user_raw_post: + user_post.append(await self.do_parse(raw_post)) + res.append((user, user_post)) + return res @abstractmethod def get_category(self, post: RawPost) -> Optional[Category]: "Return category of given Rawpost" raise NotImplementedError() -class ParsePostMixin(metaclass=RegistryABCMeta, abstract=True): - - @abstractmethod - async def parse(self, raw_post: RawPost) -> Post: - "parse RawPost into post" - ... - -class MessageProcessMixin(PlatformNameMixin, CategoryMixin, ParsePostMixin, abstract=True): +class MessageProcess(Platform, abstract=True): "General message process fetch, parse, filter progress" def __init__(self): @@ -104,7 +128,7 @@ class MessageProcessMixin(PlatformNameMixin, CategoryMixin, ParsePostMixin, abst "Get post id of given RawPost" - async def _parse_with_cache(self, raw_post: RawPost) -> Post: + async def do_parse(self, raw_post: RawPost) -> Post: post_id = self.get_id(raw_post) if post_id not in self.parse_cache: retry_times = 3 @@ -144,8 +168,8 @@ class MessageProcessMixin(PlatformNameMixin, CategoryMixin, ParsePostMixin, abst res.append(raw_post) return res -class NewMessageProcessMixin(StorageMixinProto, MessageProcessMixin, abstract=True): - "General message process, fetch, parse, filter, and only returns the new Post" +class NewMessage(MessageProcess, abstract=True): + "Fetch a list of messages, filter the new messages, dispatch it to different users" @dataclass class MessageStorage(): @@ -173,79 +197,6 @@ class NewMessageProcessMixin(StorageMixinProto, MessageProcessMixin, abstract=Tr self.set_stored_data(target, store) return res -class UserCustomFilterMixin(CategoryMixin, ParsePostMixin, abstract=True): - - categories: dict[Category, str] - enable_tag: bool - - def __init__(self): - super().__init__() - self.reverse_category = {} - for key, val in self.categories.items(): - self.reverse_category[val] = key - - @abstractmethod - def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]: - "Return Tag list of given RawPost" - - async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: - res: list[RawPost] = [] - for raw_post in raw_post_list: - if self.categories: - cat = self.get_category(raw_post) - if cats and cat not in cats: - continue - if self.enable_tag and tags: - flag = False - post_tags = self.get_tags(raw_post) - for tag in post_tags or []: - if tag in tags: - flag = True - break - if not flag: - continue - res.append(raw_post) - return res - - async def dispatch_user_post(self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: - res: list[tuple[User, list[Post]]] = [] - for user, category_getter, tag_getter in users: - required_tags = tag_getter(target) if self.enable_tag else [] - cats = category_getter(target) - user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) - user_post: list[Post] = [] - for raw_post in user_raw_post: - if isinstance(self, MessageProcessMixin): - user_post.append(await self._parse_with_cache(raw_post)) - else: - user_post.append(await self.parse(raw_post)) - res.append((user, user_post)) - return res - -class Platform(PlatformNameMixin, UserCustomFilterMixin, base=True): - - # schedule_interval: int - schedule_type: Literal['date', 'interval', 'cron'] - schedule_kw: dict - is_common: bool - enabled: bool - name: str - - @abstractmethod - async def get_target_name(self, target: Target) -> Optional[str]: - ... - - @abstractmethod - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: - ... - -class NewMessage( - Platform, - NewMessageProcessMixin, - UserCustomFilterMixin, - abstract=True - ): - "Fetch a list of messages, filter the new messages, dispatch it to different users" async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: try: @@ -266,12 +217,7 @@ class NewMessage( logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) return [] -class StatusChange( - Platform, - StorageMixinProto, - UserCustomFilterMixin, - abstract=True - ): +class StatusChange(Platform, abstract=True): "Watch a status, and fire a post when status changes" @abstractmethod @@ -305,13 +251,7 @@ class StatusChange( logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) return [] -class SimplePost( - Platform, - MessageProcessMixin, - UserCustomFilterMixin, - StorageMixinProto, - abstract=True - ): +class SimplePost(MessageProcess, abstract=True): "Fetch a list of messages, dispatch it to different users" async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: @@ -332,20 +272,12 @@ class SimplePost( logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) return [] -class NoTargetGroup( - Platform, - NoTargetMixin, - UserCustomFilterMixin, - abstract=True - ): +class NoTargetGroup(Platform, abstract=True): enable_tag = False DUMMY_STR = '_DUMMY' enabled = True - class PlatformProto(Platform, NoTargetMixin, UserCustomFilterMixin, abstract=True): - ... - - def __init__(self, platform_list: list[PlatformProto]): + def __init__(self, platform_list: list[Platform]): self.platform_list = platform_list name = self.DUMMY_STR self.categories = {} @@ -353,6 +285,8 @@ class NoTargetGroup( self.schedule_type = platform_list[0].schedule_type self.schedule_kw = platform_list[0].schedule_kw for platform in platform_list: + if platform.has_target: + raise RuntimeError('Platform {} should have no target'.format(platform.name)) if name == self.DUMMY_STR: name = platform.name elif name != platform.name: @@ -381,3 +315,4 @@ class NoTargetGroup( for user, posts in platform_res: res[user].extend(posts) return [[key, val] for key, val in res.items()] + diff --git a/src/plugins/nonebot_bison/platform/rss.py b/src/plugins/nonebot_bison/platform/rss.py index 1874d87..4cc18cc 100644 --- a/src/plugins/nonebot_bison/platform/rss.py +++ b/src/plugins/nonebot_bison/platform/rss.py @@ -7,9 +7,9 @@ import httpx from ..post import Post from ..types import RawPost, Target -from .platform import NewMessage, TargetMixin +from .platform import NewMessage -class Rss(NewMessage, TargetMixin): +class Rss(NewMessage): categories = {} enable_tag = False @@ -19,6 +19,7 @@ class Rss(NewMessage, TargetMixin): is_common = True schedule_type = 'interval' schedule_kw = {'seconds': 30} + has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: diff --git a/src/plugins/nonebot_bison/platform/weibo.py b/src/plugins/nonebot_bison/platform/weibo.py index e3a27c4..19d8703 100644 --- a/src/plugins/nonebot_bison/platform/weibo.py +++ b/src/plugins/nonebot_bison/platform/weibo.py @@ -9,9 +9,9 @@ from nonebot import logger from ..post import Post from ..types import * -from .platform import NewMessage, TargetMixin +from .platform import NewMessage -class Weibo(NewMessage, TargetMixin): +class Weibo(NewMessage): categories = { 1: '转发', @@ -26,6 +26,7 @@ class Weibo(NewMessage, TargetMixin): is_common = True schedule_type = 'interval' schedule_kw = {'seconds': 3} + has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: diff --git a/tests/platforms/test_arknights.py b/tests/platforms/test_arknights.py index d105ac4..1058019 100644 --- a/tests/platforms/test_arknights.py +++ b/tests/platforms/test_arknights.py @@ -2,7 +2,6 @@ import pytest import typing import respx from httpx import Response -import feedparser if typing.TYPE_CHECKING: import sys diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index c8c661b..f708863 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -38,8 +38,7 @@ def user_info_factory(plugin_module: 'nonebot_bison', dummy_user): @pytest.fixture def mock_platform_without_cats_tags(plugin_module: 'nonebot_bison'): - class MockPlatform(plugin_module.platform.platform.NewMessage, - plugin_module.platform.platform.TargetMixin): + class MockPlatform(plugin_module.platform.platform.NewMessage): platform_name = 'mock_platform' name = 'Mock Platform' @@ -48,6 +47,7 @@ def mock_platform_without_cats_tags(plugin_module: 'nonebot_bison'): schedule_interval = 10 enable_tag = False categories = {} + has_target = True def __init__(self): self.sub_index = 0 @@ -77,8 +77,7 @@ def mock_platform_without_cats_tags(plugin_module: 'nonebot_bison'): @pytest.fixture def mock_platform(plugin_module: 'nonebot_bison'): - class MockPlatform(plugin_module.platform.platform.NewMessage, - plugin_module.platform.platform.TargetMixin): + class MockPlatform(plugin_module.platform.platform.NewMessage): platform_name = 'mock_platform' name = 'Mock Platform' @@ -86,6 +85,7 @@ def mock_platform(plugin_module: 'nonebot_bison'): is_common = True schedule_interval = 10 enable_tag = True + has_target = True categories = { 1: '转发', 2: '视频', @@ -124,8 +124,7 @@ def mock_platform(plugin_module: 'nonebot_bison'): @pytest.fixture def mock_platform_no_target(plugin_module: 'nonebot_bison'): - class MockPlatform(plugin_module.platform.platform.NewMessage, - plugin_module.platform.platform.NoTargetMixin): + class MockPlatform(plugin_module.platform.platform.NewMessage): platform_name = 'mock_platform' name = 'Mock Platform' @@ -134,6 +133,7 @@ def mock_platform_no_target(plugin_module: 'nonebot_bison'): schedule_type = 'interval' schedule_kw = {'seconds': 30} enable_tag = True + has_target = False categories = { 1: '转发', 2: '视频', @@ -175,8 +175,7 @@ def mock_platform_no_target(plugin_module: 'nonebot_bison'): @pytest.fixture def mock_platform_no_target_2(plugin_module: 'nonebot_bison'): - class MockPlatform(plugin_module.platform.platform.NewMessage, - plugin_module.platform.platform.NoTargetMixin): + class MockPlatform(plugin_module.platform.platform.NewMessage): platform_name = 'mock_platform' name = 'Mock Platform' @@ -185,6 +184,7 @@ def mock_platform_no_target_2(plugin_module: 'nonebot_bison'): schedule_kw = {'seconds': 30} is_common = True enable_tag = True + has_target = False categories = { 4: 'leixing4', 5: 'leixing5', @@ -231,8 +231,7 @@ def mock_platform_no_target_2(plugin_module: 'nonebot_bison'): @pytest.fixture def mock_status_change(plugin_module: 'nonebot_bison'): - class MockPlatform(plugin_module.platform.platform.StatusChange, - plugin_module.platform.platform.NoTargetMixin): + class MockPlatform(plugin_module.platform.platform.StatusChange): platform_name = 'mock_platform' name = 'Mock Platform' @@ -241,6 +240,7 @@ def mock_status_change(plugin_module: 'nonebot_bison'): enable_tag = False schedule_type = 'interval' schedule_kw = {'seconds': 10} + has_target = False categories = { 1: '转发', 2: '视频',