diff --git a/src/plugins/nonebot_hk_reporter/platform/arknights.py b/src/plugins/nonebot_hk_reporter/platform/arknights.py index 5239d52..c7de7e5 100644 --- a/src/plugins/nonebot_hk_reporter/platform/arknights.py +++ b/src/plugins/nonebot_hk_reporter/platform/arknights.py @@ -22,8 +22,7 @@ class Arknights(NewMessage, NoTargetMixin): schedule_type = 'interval' schedule_kw = {'seconds': 30} - @staticmethod - async def get_target_name(_: Target) -> str: + async def get_target_name(self, _: Target) -> str: return '明日方舟游戏内公告' async def get_sub_list(self, _) -> list[RawPost]: diff --git a/src/plugins/nonebot_hk_reporter/platform/bilibili.py b/src/plugins/nonebot_hk_reporter/platform/bilibili.py index 02924cc..227da47 100644 --- a/src/plugins/nonebot_hk_reporter/platform/bilibili.py +++ b/src/plugins/nonebot_hk_reporter/platform/bilibili.py @@ -24,8 +24,7 @@ class Bilibili(NewMessage, TargetMixin): schedule_kw = {'seconds': 10} name = 'B站' - @staticmethod - async def get_target_name(target: Target) -> Optional[str]: + async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.get('https://api.bilibili.com/x/space/acc/info', params={'mid': target}) res_data = json.loads(res.text) diff --git a/src/plugins/nonebot_hk_reporter/platform/platform.py b/src/plugins/nonebot_hk_reporter/platform/platform.py index 5a83c0a..25f7e96 100644 --- a/src/plugins/nonebot_hk_reporter/platform/platform.py +++ b/src/plugins/nonebot_hk_reporter/platform/platform.py @@ -1,5 +1,7 @@ from abc import abstractmethod, ABC +from collections import defaultdict from dataclasses import dataclass +from functools import reduce import time from typing import Any, Collection, Optional, Literal @@ -230,9 +232,8 @@ class Platform(PlatformNameMixin, UserCustomFilterMixin, base=True): enabled: bool name: str - @staticmethod @abstractmethod - async def get_target_name(target: Target) -> Optional[str]: + async def get_target_name(self, target: Target) -> Optional[str]: ... @abstractmethod @@ -299,3 +300,50 @@ class StatusChange( except httpx.RequestError as err: logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) return [] + +class NoTargetGroup( + Platform, + NoTargetMixin, + UserCustomFilterMixin, + abstract=True + ): + enable_tag = False + DUMMY_STR = '_DUMMY' + enabled = True + + class PlatformProto(Platform, NoTargetMixin, UserCustomFilterMixin, abstract=True): + ... + + def __init__(self, platform_list: list[PlatformProto]): + self.platform_list = platform_list + name = self.DUMMY_STR + self.categories = {} + categories_keys = set() + self.schedule_type = platform_list[0].schedule_type + self.schedule_kw = platform_list[0].schedule_kw + for platform in platform_list: + if name == self.DUMMY_STR: + name = platform.name + elif name != platform.name: + raise RuntimeError('Platform name for {} not fit'.format(self.platform_name)) + platform_category_key_set = set(platform.categories.keys()) + if platform_category_key_set & categories_keys: + raise RuntimeError('Platform categories for {} duplicate'.format(self.platform_name)) + categories_keys |= platform_category_key_set + self.categories.update(platform.categories) + if platform.schedule_kw != self.schedule_kw or platform.schedule_type != self.schedule_type: + raise RuntimeError('Platform scheduler for {} not fit'.format(self.platform_name)) + self.name = name + self.is_common = platform_list[0].is_common + super().__init__() + + async def get_target_name(self, _): + return await self.platform_list[0].get_target_name(_) + + async def fetch_new_post(self, target, users): + res = defaultdict(list) + for platform in self.platform_list: + platform_res = await platform.fetch_new_post(target, users) + for user, posts in platform_res: + res[user].extend(posts) + return [[key, val] for key, val in res.items()] diff --git a/src/plugins/nonebot_hk_reporter/platform/rss.py b/src/plugins/nonebot_hk_reporter/platform/rss.py index 3037ccb..1874d87 100644 --- a/src/plugins/nonebot_hk_reporter/platform/rss.py +++ b/src/plugins/nonebot_hk_reporter/platform/rss.py @@ -20,8 +20,7 @@ class Rss(NewMessage, TargetMixin): schedule_type = 'interval' schedule_kw = {'seconds': 30} - @staticmethod - async def get_target_name(target: Target) -> Optional[str]: + async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.get(target, timeout=10.0) feed = feedparser.parse(res.text) diff --git a/src/plugins/nonebot_hk_reporter/platform/weibo.py b/src/plugins/nonebot_hk_reporter/platform/weibo.py index 57d39ae..b4517cd 100644 --- a/src/plugins/nonebot_hk_reporter/platform/weibo.py +++ b/src/plugins/nonebot_hk_reporter/platform/weibo.py @@ -27,8 +27,7 @@ class Weibo(NewMessage, TargetMixin): schedule_type = 'interval' schedule_kw = {'seconds': 10} - @staticmethod - async def get_target_name(target: Target) -> Optional[str]: + async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: param = {'containerid': '100505' + target} res = await client.get('https://m.weibo.cn/api/container/getIndex', params=param) diff --git a/src/plugins/nonebot_hk_reporter/types.py b/src/plugins/nonebot_hk_reporter/types.py index 9e7b08f..089b2b0 100644 --- a/src/plugins/nonebot_hk_reporter/types.py +++ b/src/plugins/nonebot_hk_reporter/types.py @@ -6,7 +6,7 @@ Target = NewType('Target', str) Category = NewType('Category', int) Tag = NewType('Tag', str) -@dataclass +@dataclass(eq=True, frozen=True) class User: user: str user_type: str diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index 468eb6d..5f9c2f9 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -131,7 +131,8 @@ def mock_platform_no_target(plugin_module: 'nonebot_hk_reporter'): name = 'Mock Platform' enabled = True is_common = True - schedule_interval = 10 + schedule_type = 'interval' + schedule_kw = {'seconds': 30} enable_tag = True categories = { 1: '转发', @@ -172,6 +173,62 @@ def mock_platform_no_target(plugin_module: 'nonebot_hk_reporter'): return MockPlatform() +@pytest.fixture +def mock_platform_no_target_2(plugin_module: 'nonebot_hk_reporter'): + class MockPlatform(plugin_module.platform.platform.NewMessage, + plugin_module.platform.platform.NoTargetMixin): + + platform_name = 'mock_platform' + name = 'Mock Platform' + enabled = True + schedule_type = 'interval' + schedule_kw = {'seconds': 30} + is_common = True + enable_tag = True + categories = { + 4: 'leixing4', + 5: 'leixing5', + } + def __init__(self): + self.sub_index = 0 + super().__init__() + + @staticmethod + async def get_target_name(_: 'Target'): + return 'MockPlatform' + + def get_id(self, post: 'RawPost') -> Any: + return post['id'] + + def get_date(self, raw_post: 'RawPost') -> float: + return raw_post['date'] + + def get_tags(self, raw_post: 'RawPost') -> list['Tag']: + return raw_post['tags'] + + def get_category(self, raw_post: 'RawPost') -> 'Category': + return raw_post['category'] + + async def parse(self, raw_post: 'RawPost') -> 'Post': + return plugin_module.post.Post('mock_platform_2', raw_post['text'], 'http://t.tt/' + str(self.get_id(raw_post)), target_name='Mock') + + async def get_sub_list(self, _: 'Target'): + list_1 = [ + {'id': 5, 'text': 'p5', 'date': now, 'tags': ['tag1'], 'category': 4} + ] + + list_2 = list_1 + [ + {'id': 6, 'text': 'p6', 'date': now, 'tags': ['tag1'], 'category': 4}, + {'id': 7, 'text': 'p7', 'date': now, 'tags': ['tag2'], 'category': 5}, + ] + if self.sub_index == 0: + self.sub_index += 1 + return list_1 + else: + return list_2 + + return MockPlatform() + @pytest.fixture def mock_status_change(plugin_module: 'nonebot_hk_reporter'): class MockPlatform(plugin_module.platform.platform.StatusChange, @@ -300,3 +357,17 @@ async def test_status_change(mock_status_change, user_info_factory): assert(len(res3[1][1]) == 0) res4 = await mock_status_change.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) assert(len(res4) == 0) + +@pytest.mark.asyncio +async def test_group(plugin_module: 'nonebot_hk_reporter', mock_platform_no_target, mock_platform_no_target_2, user_info_factory): + group_platform = plugin_module.platform.platform.NoTargetGroup([mock_platform_no_target, mock_platform_no_target_2]) + res1 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])]) + assert(len(res1) == 0) + res2 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])]) + assert(len(res2) == 1) + posts = res2[0][1] + assert(len(posts) == 2) + id_set_2 = set(map(lambda x: x.text, posts)) + assert('p2' in id_set_2 and 'p6' in id_set_2) + res3 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])]) + assert(len(res3) == 0)